#!/usr/bin/env python3
"""
Pose recognition demo.

Detects human body pose with MediaPipe. Supports real-time detection from a
webcam and detection on a single image file.
"""
|
||
|
||
import cv2
|
||
import mediapipe as mp
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
|
||
class PoseDetector:
    """Wrapper around MediaPipe's ``solutions.pose`` detector.

    Two ``Pose`` instances are held:

    * ``pose_static`` -- created with ``static_image_mode=True`` and no
      landmark smoothing; used by :meth:`detect_poses` for still images.
    * ``pose_stream`` -- created with ``static_image_mode=False`` and
      landmark smoothing; used by :meth:`detect_pose` for consecutive
      video frames.

    NOTE: the ``*_poses`` ("multi-person") methods work on lists, but
    :meth:`detect_poses` performs a single ``process()`` call and therefore
    returns a list with at most ONE entry.

    The instance owns native MediaPipe resources. Call :meth:`close` when
    done, or use the detector as a context manager::

        with PoseDetector() as detector:
            results = detector.detect_pose(frame)
    """

    # Colors cycled through when several poses are drawn. Images are BGR,
    # so (0, 0, 255) is red.
    _POSE_COLORS = (
        (0, 255, 0),    # green
        (0, 0, 255),    # red
        (255, 0, 0),    # blue
        (0, 255, 255),  # yellow
        (255, 0, 255),  # purple
        (255, 255, 0),  # cyan
    )

    # Body-part name -> index into ``results.pose_landmarks.landmark``.
    # Insertion order is the order of the keys in get_pose_info() output.
    _KEY_POINTS = {
        'nose': 0,
        'left_shoulder': 11,
        'right_shoulder': 12,
        'left_elbow': 13,
        'right_elbow': 14,
        'left_wrist': 15,
        'right_wrist': 16,
        'left_hip': 23,
        'right_hip': 24,
        'left_knee': 25,
        'right_knee': 26,
        'left_ankle': 27,
        'right_ankle': 28,
    }

    def __init__(self):
        """Create the static-image and video-stream MediaPipe pose detectors."""
        self.mp_pose = mp.solutions.pose

        # Settings common to both detectors.
        shared_options = dict(
            model_complexity=1,
            enable_segmentation=False,
            smooth_segmentation=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )
        # Still images: every image is treated independently, no smoothing.
        self.pose_static = self.mp_pose.Pose(
            static_image_mode=True,
            smooth_landmarks=False,
            **shared_options,
        )
        # Video stream: landmarks are smoothed across frames.
        self.pose_stream = self.mp_pose.Pose(
            static_image_mode=False,
            smooth_landmarks=True,
            **shared_options,
        )
        self.mp_draw = mp.solutions.drawing_utils

        # Landmark pairs that make up the skeleton lines.
        self.connections = self.mp_pose.POSE_CONNECTIONS

    def close(self):
        """Release both MediaPipe detectors.

        Safe to call more than once: each detector attribute is reset to
        ``None`` after it has been closed, and ``None`` entries are skipped.
        The detection methods must not be used after closing.
        """
        for attr in ("pose_static", "pose_stream"):
            detector = getattr(self, attr, None)
            if detector is not None:
                detector.close()
            setattr(self, attr, None)

    def __enter__(self):
        """Return the detector itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the detector on leaving the ``with`` block; never suppresses errors."""
        self.close()

    def detect_pose(self, image):
        """
        Detect the pose in one frame using the video-stream detector.

        Args:
            image: input image (BGR format)

        Returns:
            results: MediaPipe pose detection result
        """
        # MediaPipe is fed RGB, while the input is BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.pose_stream.process(image_rgb)

    def detect_poses(self, image):
        """
        Detect poses in a still image using the static-image detector.

        Args:
            image: input image (BGR format)

        Returns:
            pose_results: list of MediaPipe results. Only one ``process()``
                call is made, so the list is empty (nothing detected) or
                holds exactly one result.
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = self.pose_static.process(image_rgb)
        return [results] if results.pose_landmarks else []

    def draw_pose(self, image, results, color=(0, 255, 0)):
        """
        Draw the skeleton of a single detection result onto an image.

        Args:
            image: image to draw on
            results: pose detection result
            color: color used for both the key points and the connections

        Returns:
            image: the same image object that was passed in (untouched when
                ``results`` carries no landmarks)
        """
        if not results.pose_landmarks:
            return image

        # ``image`` is handed straight to draw_landmarks, so the drawing is
        # presumably done in place -- confirm against MediaPipe's docs.
        self.mp_draw.draw_landmarks(
            image,
            results.pose_landmarks,
            self.connections,
            landmark_drawing_spec=self.mp_draw.DrawingSpec(
                color=color,
                thickness=2,
                circle_radius=2,
            ),
            connection_drawing_spec=self.mp_draw.DrawingSpec(
                color=color,
                thickness=2,
            ),
        )
        return image

    def draw_poses(self, image, pose_results):
        """
        Draw every skeleton in ``pose_results``, each in its own color.

        Args:
            image: image to draw on
            pose_results: list of pose detection results

        Returns:
            image: the image with all skeletons drawn
        """
        palette = self._POSE_COLORS
        for idx, results in enumerate(pose_results):
            # Wrap around when there are more poses than colors.
            image = self.draw_pose(image, results, palette[idx % len(palette)])
        return image

    def get_pose_info(self, results):
        """
        Extract the main body key points from one detection result.

        Args:
            results: pose detection result

        Returns:
            pose_info: dict mapping a body-part name (``'nose'``,
                ``'left_shoulder'``, ... ``'right_ankle'``) to a dict with
                keys ``'x'``, ``'y'``, ``'z'`` and ``'visibility'`` copied
                verbatim from the landmark. Empty when no pose was detected.
        """
        if not results.pose_landmarks:
            return {}

        landmarks = results.pose_landmarks.landmark
        return {
            name: {
                'x': landmarks[idx].x,
                'y': landmarks[idx].y,
                'z': landmarks[idx].z,
                'visibility': landmarks[idx].visibility,
            }
            for name, idx in self._KEY_POINTS.items()
        }

    def get_poses_info(self, pose_results):
        """
        Extract key-point info for every detection result in a list.

        Args:
            pose_results: list of pose detection results

        Returns:
            poses_info: list of dicts as returned by :meth:`get_pose_info`;
                results without a detected pose are left out
        """
        candidates = (self.get_pose_info(results) for results in pose_results)
        return [pose_info for pose_info in candidates if pose_info]
|
||
|
||
|
||
def run_webcam():
    """Run real-time pose detection on the default webcam (device 0).

    Every captured frame is run through ``PoseDetector.detect_pose``, the
    skeleton is drawn on it, and the frame is shown in a window titled
    'Pose Detection'. The loop ends when 'q' is pressed or a frame cannot be
    read. The camera, the OpenCV windows and the detector's MediaPipe
    graphs are released on the way out, including when an exception is
    raised.
    """
    print("启动摄像头姿态检测...")
    print("按 'q' 键退出")

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("错误: 无法打开摄像头")
        return

    # Built only once the camera is known to work, so the models are not
    # loaded for nothing; created inside ``try`` so the camera is still
    # released if construction fails.
    detector = None
    try:
        detector = PoseDetector()

        while True:
            ret, frame = cap.read()
            if not ret:
                print("错误: 无法读取摄像头画面")
                break

            # Detect, then overlay the skeleton.
            results = detector.detect_pose(frame)
            frame = detector.draw_pose(frame, results)

            # On-screen status label while a pose is being tracked.
            if results.pose_landmarks:
                cv2.putText(frame, "Pose Detected", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            cv2.imshow('Pose Detection', frame)

            # Only the low byte of the key code is compared against 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
        if detector is not None:
            # Release the MediaPipe graphs owned by the detector.
            detector.pose_stream.close()
            detector.pose_static.close()
|
||
|
||
|
||
def run_image(image_path):
    """
    Run pose detection on one image file, show it and print the key points.

    The annotated image is shown in a window titled 'Pose Detection'; the
    function then blocks until a key is pressed. Prints an error message and
    returns early if the path does not exist or OpenCV cannot decode the
    file. Windows and the detector's MediaPipe graphs are released on the
    way out, including when an exception is raised.

    Args:
        image_path: path of the image file
    """
    print(f"检测图片: {image_path}")

    if not Path(image_path).exists():
        print(f"错误: 文件不存在 - {image_path}")
        return

    # cv2.imread reports failure by returning None rather than raising.
    image = cv2.imread(image_path)
    if image is None:
        print(f"错误: 无法读取图片 - {image_path}")
        return

    # Built only after the image is known to be usable, so the models are
    # not loaded for nothing.
    detector = PoseDetector()
    try:
        pose_results = detector.detect_poses(image)
        image = detector.draw_poses(image, pose_results)
        poses_info = detector.get_poses_info(pose_results)

        cv2.imshow('Pose Detection', image)

        if poses_info:
            print(f"\n检测到 {len(poses_info)} 个人物:")
            for idx, pose_info in enumerate(poses_info):
                print(f"\n人物 {idx + 1}:")
                for name, info in pose_info.items():
                    print(f" {name}: x={info['x']:.3f}, y={info['y']:.3f}, visibility={info['visibility']:.3f}")
        else:
            print("未检测到姿态")

        print("\n按任意键关闭窗口...")
        # 0 = wait indefinitely for a key press.
        cv2.waitKey(0)
    finally:
        cv2.destroyAllWindows()
        # Release the MediaPipe graphs owned by the detector.
        detector.pose_static.close()
        detector.pose_stream.close()
|
||
|
||
|
||
def main():
    """Entry point: dispatch on the command-line arguments.

    With no argument the webcam demo is started; otherwise the first
    argument is taken as the path of an image to analyse.
    """
    cli_args = sys.argv[1:]
    if cli_args:
        # Image mode: only the first argument is used.
        run_image(cli_args[0])
    else:
        # Default mode: live webcam.
        run_webcam()
|
||
|
||
|
||
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()