posefind/pose_detector.py

303 lines
8.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
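"""
Pose estimation demo.

Uses MediaPipe to detect human body poses. Supports real-time detection from
a webcam as well as detection on a single image file.
"""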
import cv2
import mediapipe as mp
import sys
from pathlib import Path
class PoseDetector:
    """Wrapper around two MediaPipe ``Pose`` solutions (stills and video).

    ``pose_static`` is created with ``static_image_mode=True`` and is used by
    :meth:`detect_poses`; ``pose_stream`` is created with
    ``static_image_mode=False`` plus landmark smoothing and is used by
    :meth:`detect_pose`.

    NOTE: each ``process`` call produces a single result object, so the
    list-based helpers (:meth:`detect_poses`, :meth:`draw_poses`,
    :meth:`get_poses_info`) currently work on a list holding at most one
    entry. They keep a list interface so callers do not have to change if
    real multi-person detection is added later.

    The two ``Pose`` objects should be released when the detector is no
    longer needed: call :meth:`close` or use the detector in a ``with``
    statement. The detector must not be used for detection after closing.
    """

    # Landmark index (into ``pose_landmarks.landmark``) for every body part
    # reported by get_pose_info.
    KEY_POINTS = {
        'nose': 0,
        'left_shoulder': 11,
        'right_shoulder': 12,
        'left_elbow': 13,
        'right_elbow': 14,
        'left_wrist': 15,
        'right_wrist': 16,
        'left_hip': 23,
        'right_hip': 24,
        'left_knee': 25,
        'right_knee': 26,
        'left_ankle': 27,
        'right_ankle': 28,
    }

    # Colours cycled through by draw_poses, one per detected person.
    # Names assume BGR channel order (images come from OpenCV).
    SKELETON_COLORS = (
        (0, 255, 0),    # green
        (0, 0, 255),    # red
        (255, 0, 0),    # blue
        (0, 255, 255),  # yellow
        (255, 0, 255),  # purple
        (255, 255, 0),  # cyan
    )

    def __init__(self):
        """Create the static-image and video-stream MediaPipe pose solutions."""
        self.mp_pose = mp.solutions.pose

        # Settings shared by both solutions; only the image mode and the
        # landmark smoothing differ between them.
        shared = dict(
            model_complexity=1,
            enable_segmentation=False,
            smooth_segmentation=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )
        # Static mode: every image is treated independently (no smoothing).
        self.pose_static = self.mp_pose.Pose(
            static_image_mode=True,
            smooth_landmarks=False,
            **shared,
        )
        # Stream mode: consecutive frames are assumed to be related.
        self.pose_stream = self.mp_pose.Pose(
            static_image_mode=False,
            smooth_landmarks=True,
            **shared,
        )

        self.mp_draw = mp.solutions.drawing_utils
        # Landmark pairs that form the skeleton's connecting lines.
        self.connections = self.mp_pose.POSE_CONNECTIONS

    def close(self):
        """Release both pose solutions. Safe to call more than once."""
        for attr in ("pose_static", "pose_stream"):
            solution = getattr(self, attr, None)
            if solution is not None:
                solution.close()
                # Drop the reference so a second close() is a no-op.
                setattr(self, attr, None)

    def __enter__(self):
        """Return the detector itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the pose solutions; exceptions are never suppressed."""
        self.close()
        return False

    def detect_pose(self, image):
        """
        Detect the pose in one frame using the video-stream solution.

        Args:
            image: Input image in BGR channel order.

        Returns:
            The MediaPipe result object returned by ``Pose.process``.
        """
        # MediaPipe expects RGB while OpenCV delivers BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.pose_stream.process(image_rgb)

    def detect_poses(self, image):
        """
        Detect poses in a still image using the static-image solution.

        Args:
            image: Input image in BGR channel order.

        Returns:
            A list of MediaPipe result objects. It is empty when no pose was
            found and otherwise holds exactly one entry, because a single
            ``process`` call is made per image.
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = self.pose_static.process(image_rgb)
        return [results] if results.pose_landmarks else []

    def draw_pose(self, image, results, color=(0, 255, 0)):
        """
        Draw one pose skeleton onto ``image``.

        Args:
            image: Image to draw on; it is passed to MediaPipe's
                ``draw_landmarks`` and returned.
            results: Pose detection result. ``None`` or a result without
                landmarks leaves the image untouched.
            color: Colour used for both the landmarks and the connections.

        Returns:
            The same ``image`` object that was passed in.
        """
        if results is None or not results.pose_landmarks:
            return image

        self.mp_draw.draw_landmarks(
            image,
            results.pose_landmarks,
            self.connections,
            landmark_drawing_spec=self.mp_draw.DrawingSpec(
                color=color,
                thickness=2,
                circle_radius=2,
            ),
            connection_drawing_spec=self.mp_draw.DrawingSpec(
                color=color,
                thickness=2,
            ),
        )
        return image

    def draw_poses(self, image, pose_results):
        """
        Draw every pose in ``pose_results``, cycling through the palette.

        Args:
            image: Image to draw on.
            pose_results: Iterable of pose detection results.

        Returns:
            The image after all skeletons have been drawn.
        """
        palette = self.SKELETON_COLORS
        for idx, results in enumerate(pose_results):
            image = self.draw_pose(image, results, palette[idx % len(palette)])
        return image

    def get_pose_info(self, results):
        """
        Extract the coordinates of the main body parts from one result.

        Args:
            results: Pose detection result (``None`` is accepted).

        Returns:
            A dict mapping each name in ``KEY_POINTS`` to a dict with the
            keys ``x``, ``y``, ``z`` and ``visibility``. The dict is empty
            when ``results`` is ``None`` or carries no landmarks.
        """
        if results is None or not results.pose_landmarks:
            return {}

        landmarks = results.pose_landmarks.landmark
        pose_info = {}
        for name, idx in self.KEY_POINTS.items():
            landmark = landmarks[idx]
            pose_info[name] = {
                'x': landmark.x,
                'y': landmark.y,
                'z': landmark.z,
                'visibility': landmark.visibility,
            }
        return pose_info

    def get_poses_info(self, pose_results):
        """
        Extract body-part coordinates for every result in ``pose_results``.

        Args:
            pose_results: Iterable of pose detection results.

        Returns:
            A list of dicts as produced by :meth:`get_pose_info`; results
            without landmarks are skipped.
        """
        infos = (self.get_pose_info(results) for results in pose_results)
        return [info for info in infos if info]
def run_webcam():
    """Run live pose detection on frames from capture device 0.

    Every frame is passed through the detector, the skeleton is drawn on
    it, and a "Pose Detected" label is added when landmarks are present.
    The loop ends when 'q' is pressed or a frame cannot be read; the capture
    device and all OpenCV windows are released on the way out.
    """
    print("启动摄像头姿态检测...")
    print("'q' 键退出")

    detector = PoseDetector()
    camera = cv2.VideoCapture(0)
    if not camera.isOpened():
        print("错误: 无法打开摄像头")
        return

    window_title = 'Pose Detection'
    quit_key = ord('q')
    label_color = (0, 255, 0)

    try:
        while True:
            grabbed, frame = camera.read()
            if not grabbed:
                print("错误: 无法读取摄像头画面")
                break

            # Detect, then overlay the skeleton on the same frame.
            pose = detector.detect_pose(frame)
            frame = detector.draw_pose(frame, pose)

            # Status label, shown only when a pose was found.
            if pose.pose_landmarks:
                cv2.putText(
                    frame,
                    "Pose Detected",
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    label_color,
                    2,
                )

            cv2.imshow(window_title, frame)

            # Keep only the low byte of the key code before comparing.
            pressed = cv2.waitKey(1) & 0xFF
            if pressed == quit_key:
                break
    finally:
        camera.release()
        cv2.destroyAllWindows()
def run_image(image_path):
    """
    Run pose detection on one image file and show the annotated result.

    The detected key points are printed to stdout and the annotated image
    is shown in a window that stays open until any key is pressed. An error
    message is printed and the function returns early when the file does
    not exist or cannot be decoded.

    Args:
        image_path: Path of the image file to analyse.
    """
    print(f"检测图片: {image_path}")

    file_missing = not Path(image_path).exists()
    if file_missing:
        print(f"错误: 文件不存在 - {image_path}")
        return

    detector = PoseDetector()

    image = cv2.imread(image_path)
    if image is None:
        # imread signals an unreadable file by returning None.
        print(f"错误: 无法读取图片 - {image_path}")
        return

    # Detect, annotate and extract key points.
    pose_results = detector.detect_poses(image)
    image = detector.draw_poses(image, pose_results)
    poses_info = detector.get_poses_info(pose_results)

    cv2.imshow('Pose Detection', image)

    # Text report of what was found.
    if not poses_info:
        print("未检测到姿态")
    else:
        print(f"\n检测到 {len(poses_info)} 个人物:")
        for idx, pose_info in enumerate(poses_info):
            print(f"\n人物 {idx + 1}:")
            for name in pose_info:
                info = pose_info[name]
                print(f" {name}: x={info['x']:.3f}, y={info['y']:.3f}, visibility={info['visibility']:.3f}")

    print("\n按任意键关闭窗口...")
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def main():
    """Entry point: use the webcam unless an image path is given.

    With no command-line argument the live webcam demo runs; otherwise the
    first argument is treated as the path of an image to analyse.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        # No argument: default to the webcam.
        run_webcam()
        return

    # First argument is the image path.
    run_image(cli_args[0])


if __name__ == "__main__":
    main()