303 lines
8.4 KiB
Python
303 lines
8.4 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
姿态识别Demo
|
|||
|
|
使用MediaPipe检测人像姿态,支持摄像头实时检测和图片检测
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import cv2
|
|||
|
|
import mediapipe as mp
|
|||
|
|
import sys
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PoseDetector:
    """Human pose detector built on MediaPipe Pose.

    Holds two MediaPipe ``Pose`` graphs:

    * a static-image graph (no landmark smoothing), used by :meth:`detect_poses`;
    * a stream graph with landmark smoothing, used by :meth:`detect_pose` for
      consecutive video frames.

    Note: MediaPipe Pose reports a single person per image, so the
    "multi-person" helpers (``detect_poses`` / ``draw_poses`` /
    ``get_poses_info``) work on lists that hold at most one result.

    The detector owns native MediaPipe resources. Call :meth:`close` when done,
    or use it as a context manager::

        with PoseDetector() as detector:
            results = detector.detect_pose(frame)
    """

    # Landmark name -> index into ``results.pose_landmarks.landmark``.
    # Insertion order is preserved in the dicts built by get_pose_info.
    KEY_POINTS = {
        'nose': 0,
        'left_shoulder': 11,
        'right_shoulder': 12,
        'left_elbow': 13,
        'right_elbow': 14,
        'left_wrist': 15,
        'right_wrist': 16,
        'left_hip': 23,
        'right_hip': 24,
        'left_knee': 25,
        'right_knee': 26,
        'left_ankle': 27,
        'right_ankle': 28,
    }

    # Skeleton colors cycled per person by draw_poses (BGR order, as OpenCV uses).
    POSE_COLORS = (
        (0, 255, 0),    # green
        (0, 0, 255),    # red
        (255, 0, 0),    # blue
        (0, 255, 255),  # yellow
        (255, 0, 255),  # magenta
        (255, 255, 0),  # cyan
    )

    def __init__(self):
        """Create the MediaPipe Pose graphs and the drawing helpers."""
        self.mp_pose = mp.solutions.pose

        # Static-image graph: every call is treated as an independent image.
        self.pose_static = self.mp_pose.Pose(
            static_image_mode=True,
            model_complexity=1,
            smooth_landmarks=False,
            enable_segmentation=False,
            smooth_segmentation=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )
        # Stream graph: consecutive calls are treated as frames of one video,
        # with landmark smoothing enabled.
        self.pose_stream = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=1,
            smooth_landmarks=True,
            enable_segmentation=False,
            smooth_segmentation=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )
        self.mp_draw = mp.solutions.drawing_utils

        # Skeleton connections passed to draw_landmarks.
        self.connections = self.mp_pose.POSE_CONNECTIONS

    def close(self):
        """Release both MediaPipe graphs.

        Safe to call more than once. The detector must not be used for
        detection afterwards.
        """
        for attr in ('pose_static', 'pose_stream'):
            pose = getattr(self, attr, None)
            if pose is not None:
                pose.close()
                # Drop the reference so a second close() is a no-op.
                setattr(self, attr, None)

    def __enter__(self):
        """Return the detector itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Release the MediaPipe graphs; exceptions are not suppressed."""
        self.close()

    def detect_pose(self, image):
        """Detect the pose in a single video frame (stream/tracking mode).

        Args:
            image: Input image in BGR channel order (as produced by OpenCV).

        Returns:
            The MediaPipe result object. ``results.pose_landmarks`` is falsy
            when no pose was found.
        """
        # MediaPipe expects RGB input, while OpenCV delivers BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.pose_stream.process(image_rgb)

    def detect_poses(self, image):
        """Detect poses in a still image (static-image mode).

        Args:
            image: Input image in BGR channel order (as produced by OpenCV).

        Returns:
            A list of MediaPipe result objects, one per detected person. Since
            MediaPipe Pose is single-person, the list holds at most one entry
            and is empty when nothing was found.
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = self.pose_static.process(image_rgb)
        return [results] if results.pose_landmarks else []

    def draw_pose(self, image, results, color=(0, 255, 0)):
        """Draw one pose skeleton (landmarks and connections) onto *image*.

        Args:
            image: Image to draw on (BGR); it is passed straight to
                ``draw_landmarks`` and returned.
            results: A MediaPipe pose result. Nothing is drawn when its
                ``pose_landmarks`` is falsy.
            color: BGR color used for both landmarks and connection lines.

        Returns:
            The same *image* object.
        """
        if results.pose_landmarks:
            self.mp_draw.draw_landmarks(
                image,
                results.pose_landmarks,
                self.connections,
                landmark_drawing_spec=self.mp_draw.DrawingSpec(
                    color=color,
                    thickness=2,
                    circle_radius=2
                ),
                connection_drawing_spec=self.mp_draw.DrawingSpec(
                    color=color,
                    thickness=2
                )
            )
        return image

    def draw_poses(self, image, pose_results):
        """Draw every pose in *pose_results*, giving each person its own color.

        Args:
            image: Image to draw on (BGR).
            pose_results: Iterable of MediaPipe pose results, e.g. the list
                returned by :meth:`detect_poses`.

        Returns:
            The image with all skeletons drawn.
        """
        palette = self.POSE_COLORS
        for idx, results in enumerate(pose_results):
            # Cycle through the palette when there are more people than colors.
            image = self.draw_pose(image, results, palette[idx % len(palette)])
        return image

    def get_pose_info(self, results):
        """Extract the key landmarks of one pose.

        Args:
            results: A MediaPipe pose result.

        Returns:
            A dict mapping each name in ``KEY_POINTS`` to a dict with keys
            ``'x'``, ``'y'``, ``'z'`` and ``'visibility'``. The dict is empty
            when ``results.pose_landmarks`` is falsy.
        """
        if not results.pose_landmarks:
            return {}

        landmarks = results.pose_landmarks.landmark
        return {
            name: {
                'x': landmarks[idx].x,
                'y': landmarks[idx].y,
                'z': landmarks[idx].z,
                'visibility': landmarks[idx].visibility,
            }
            for name, idx in self.KEY_POINTS.items()
        }

    def get_poses_info(self, pose_results):
        """Extract key landmarks for every pose that has any.

        Args:
            pose_results: Iterable of MediaPipe pose results.

        Returns:
            A list of the non-empty dicts produced by :meth:`get_pose_info`.
        """
        return [info for info in map(self.get_pose_info, pose_results) if info]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_webcam(camera_index=0):
    """Run real-time pose detection on a webcam feed.

    Shows each frame with the detected skeleton in a window. Stops when 'q'
    is pressed or a frame cannot be read.

    Args:
        camera_index: Device index passed to ``cv2.VideoCapture``
            (default 0, the first camera).
    """
    print("启动摄像头姿态检测...")
    print("按 'q' 键退出")

    # Open the camera before loading the MediaPipe models, so a missing
    # camera fails fast without building the graphs.
    cap = cv2.VideoCapture(camera_index)
    if not cap.isOpened():
        print("错误: 无法打开摄像头")
        cap.release()
        return

    detector = None
    try:
        detector = PoseDetector()
        while True:
            ret, frame = cap.read()
            if not ret:
                print("错误: 无法读取摄像头画面")
                break

            # Detect and draw the pose for this frame.
            results = detector.detect_pose(frame)
            frame = detector.draw_pose(frame, results)

            if results.pose_landmarks:
                cv2.putText(frame, "Pose Detected", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            cv2.imshow('Pose Detection', frame)

            # waitKey(1) also lets HighGUI process window events. Mask to the
            # low byte for a portable key code.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
        # Release the native MediaPipe graphs owned by the detector.
        if detector is not None:
            detector.pose_static.close()
            detector.pose_stream.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_image(image_path):
    """Detect and display poses in a still image.

    Prints the key landmarks of every detected person and shows the annotated
    image in a window. Blocks until a key is pressed.

    Args:
        image_path: Path of the image file (``str`` or ``pathlib.Path``).
    """
    print(f"检测图片: {image_path}")

    if not Path(image_path).exists():
        print(f"错误: 文件不存在 - {image_path}")
        return

    # Decode the image before loading the MediaPipe models, so an unreadable
    # file fails fast. cv2.imread expects a str path.
    image = cv2.imread(str(image_path))
    if image is None:
        print(f"错误: 无法读取图片 - {image_path}")
        return

    detector = PoseDetector()
    try:
        pose_results = detector.detect_poses(image)
        image = detector.draw_poses(image, pose_results)
        poses_info = detector.get_poses_info(pose_results)

        cv2.imshow('Pose Detection', image)

        if poses_info:
            print(f"\n检测到 {len(poses_info)} 个人物:")
            for idx, pose_info in enumerate(poses_info):
                print(f"\n人物 {idx + 1}:")
                for name, info in pose_info.items():
                    print(f" {name}: x={info['x']:.3f}, y={info['y']:.3f}, visibility={info['visibility']:.3f}")
        else:
            print("未检测到姿态")

        print("\n按任意键关闭窗口...")
        cv2.waitKey(0)
    finally:
        cv2.destroyAllWindows()
        # Release the native MediaPipe graphs owned by the detector.
        detector.pose_static.close()
        detector.pose_stream.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point.

    Usage: ``script.py [image_path]``. With an argument, the first one is
    treated as an image path and analysed. Without one, the webcam demo runs.
    Any further arguments are ignored.
    """
    cli_args = sys.argv[1:]
    if cli_args:
        run_image(cli_args[0])
    else:
        # No argument given: default to live webcam detection.
        run_webcam()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|