posefind/pose_detector.py

303 lines
8.4 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""
姿态识别Demo
使用MediaPipe检测人像姿态，支持摄像头实时检测和图片检测
"""
import cv2
import mediapipe as mp
import sys
from pathlib import Path
class PoseDetector:
    """Pose estimation helper built on MediaPipe's legacy ``solutions.pose`` API.

    Two estimators are created:

    * ``pose_static`` -- ``static_image_mode=True`` with landmark smoothing off,
      used for independent still images (:meth:`detect_poses`).
    * ``pose_stream`` -- ``static_image_mode=False`` with landmark smoothing on,
      used for consecutive video frames (:meth:`detect_pose`).

    NOTE: a ``Pose.process`` result holds landmarks for a single person, so the
    plural helpers (``detect_poses`` / ``draw_poses`` / ``get_poses_info``) work
    on lists that contain at most one entry per image.

    Both estimators hold native MediaPipe resources; call :meth:`close` (or use
    the detector in a ``with`` block) once you are done with it.
    """

    # Landmark indices in MediaPipe's 33-point pose topology
    # (the same values as ``mp.solutions.pose.PoseLandmark``).
    _KEY_POINTS = {
        'nose': 0,
        'left_shoulder': 11,
        'right_shoulder': 12,
        'left_elbow': 13,
        'right_elbow': 14,
        'left_wrist': 15,
        'right_wrist': 16,
        'left_hip': 23,
        'right_hip': 24,
        'left_knee': 25,
        'right_knee': 26,
        'left_ankle': 27,
        'right_ankle': 28,
    }

    # BGR colours cycled through by draw_poses, one per result.
    _PALETTE = (
        (0, 255, 0),    # green
        (0, 0, 255),    # red
        (255, 0, 0),    # blue
        (0, 255, 255),  # yellow
        (255, 0, 255),  # magenta
        (255, 255, 0),  # cyan
    )

    def __init__(self, *, model_complexity=1, min_detection_confidence=0.5,
                 min_tracking_confidence=0.5):
        """Create the still-image and video-stream MediaPipe pose estimators.

        All arguments are keyword-only and default to the values the demo has
        always used, so ``PoseDetector()`` behaves exactly as before.

        Args:
            model_complexity: MediaPipe pose model variant (0, 1 or 2).
            min_detection_confidence: minimum score for a detection to count.
            min_tracking_confidence: minimum score to keep tracking between
                frames (only meaningful for the stream estimator).
        """
        self.mp_pose = mp.solutions.pose
        # Settings shared by both estimators.
        shared = dict(
            model_complexity=model_complexity,
            enable_segmentation=False,
            smooth_segmentation=False,
            min_detection_confidence=min_detection_confidence,
            min_tracking_confidence=min_tracking_confidence,
        )
        # Still images: detection runs on every call, nothing to smooth across.
        self.pose_static = self.mp_pose.Pose(
            static_image_mode=True, smooth_landmarks=False, **shared)
        # Video: track between frames and smooth landmarks to reduce jitter.
        self.pose_stream = self.mp_pose.Pose(
            static_image_mode=False, smooth_landmarks=True, **shared)
        self.mp_draw = mp.solutions.drawing_utils
        # Skeleton edges (pairs of landmark indices) used when drawing.
        self.connections = self.mp_pose.POSE_CONNECTIONS

    def close(self):
        """Release both MediaPipe estimators. Safe to call more than once.

        The detector must not be used for detection after this call.
        """
        for attr in ('pose_static', 'pose_stream'):
            pose = getattr(self, attr, None)
            if pose is not None:
                pose.close()
                # Drop the reference so a second close() is a no-op.
                setattr(self, attr, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow exceptions

    def detect_pose(self, image):
        """Run the video-stream estimator on one BGR frame (single person).

        ``pose_stream`` keeps tracking state between calls, so feed it
        consecutive frames of one video. Use :meth:`detect_poses` for
        unrelated still images.

        Args:
            image: input image in BGR channel order (as produced by OpenCV).

        Returns:
            The MediaPipe result object; ``results.pose_landmarks`` is falsy
            when no person was found.
        """
        # MediaPipe expects RGB input, OpenCV delivers BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.pose_stream.process(image_rgb)

    def detect_poses(self, image):
        """Run the still-image estimator on a BGR image.

        Args:
            image: input image in BGR channel order.

        Returns:
            A list with the MediaPipe result when a pose was found, otherwise
            an empty list. The list never has more than one entry, because the
            underlying estimator reports a single person.
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = self.pose_static.process(image_rgb)
        # Wrap in a list so callers can use the list-based helpers uniformly.
        return [results] if results.pose_landmarks else []

    def draw_pose(self, image, results, color=(0, 255, 0)):
        """Draw the landmarks and skeleton from *results* onto *image*.

        Drawing happens in place; the same image object is returned for
        convenience. Nothing is drawn when *results* has no landmarks.

        Args:
            image: BGR image to draw on.
            results: a MediaPipe pose result.
            color: BGR colour used for both points and connections.

        Returns:
            The (annotated) *image*.
        """
        if not results.pose_landmarks:
            return image
        self.mp_draw.draw_landmarks(
            image,
            results.pose_landmarks,
            self.connections,
            landmark_drawing_spec=self.mp_draw.DrawingSpec(
                color=color, thickness=2, circle_radius=2),
            connection_drawing_spec=self.mp_draw.DrawingSpec(
                color=color, thickness=2),
        )
        return image

    def draw_poses(self, image, pose_results):
        """Draw every result in *pose_results*, cycling through ``_PALETTE``.

        Args:
            image: BGR image to draw on.
            pose_results: list of MediaPipe pose results (e.g. from detect_poses).

        Returns:
            The (annotated) *image*.
        """
        for idx, results in enumerate(pose_results):
            color = self._PALETTE[idx % len(self._PALETTE)]
            image = self.draw_pose(image, results, color)
        return image

    def get_pose_info(self, results):
        """Extract the ``_KEY_POINTS`` landmarks from one result.

        Args:
            results: a MediaPipe pose result.

        Returns:
            ``{name: {'x', 'y', 'z', 'visibility'}}`` for each key point, copied
            straight from the MediaPipe landmarks, or ``{}`` when no pose was
            detected.
        """
        if not results.pose_landmarks:
            return {}
        landmarks = results.pose_landmarks.landmark
        pose_info = {}
        for name, idx in self._KEY_POINTS.items():
            landmark = landmarks[idx]
            pose_info[name] = {
                'x': landmark.x,
                'y': landmark.y,
                'z': landmark.z,
                'visibility': landmark.visibility,
            }
        return pose_info

    def get_poses_info(self, pose_results):
        """Apply :meth:`get_pose_info` to each result, skipping empty ones.

        Args:
            pose_results: list of MediaPipe pose results.

        Returns:
            List of key-point dicts, one per result that contained a pose.
        """
        return [info for info in map(self.get_pose_info, pose_results) if info]
def run_webcam(camera_index=0):
    """Run live pose detection on a webcam feed until 'q' is pressed.

    Each frame goes through ``PoseDetector.detect_pose`` (the tracking /
    smoothing estimator). The skeleton is drawn onto the frame, and the frame is
    shown in an OpenCV window. The capture device and all windows are released
    on exit.

    Args:
        camera_index: device index passed to ``cv2.VideoCapture``. The default
            ``0`` keeps the original behaviour of opening the first camera.
    """
    print("启动摄像头姿态检测...")
    print("'q' 键退出")
    cap = cv2.VideoCapture(camera_index)
    if not cap.isOpened():
        print("错误: 无法打开摄像头")
        return
    try:
        # Build the MediaPipe models only once the camera is known to work, and
        # inside ``try`` so the camera is still released if construction fails.
        detector = PoseDetector()
        while True:
            ret, frame = cap.read()
            if not ret:
                print("错误: 无法读取摄像头画面")
                break
            results = detector.detect_pose(frame)
            frame = detector.draw_pose(frame, results)
            if results.pose_landmarks:
                cv2.putText(frame, "Pose Detected", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.imshow('Pose Detection', frame)
            # Keep only the low byte of the key code before comparing with 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
def _read_image(image_path):
    """Load *image_path* as a BGR image; return None if it cannot be decoded.

    ``cv2.imread`` cannot open some paths (notably non-ASCII paths on Windows).
    When it returns None, the file bytes are read by Python and decoded in
    memory with ``cv2.imdecode`` as a fallback.
    """
    path_str = str(image_path)
    image = cv2.imread(path_str)
    if image is not None:
        return image
    import numpy as np  # only needed on the fallback path; cv2 already requires numpy
    try:
        buffer = np.fromfile(path_str, dtype=np.uint8)
    except OSError:
        return None
    if buffer.size == 0:
        # cv2.imdecode raises on an empty buffer instead of returning None.
        return None
    return cv2.imdecode(buffer, cv2.IMREAD_COLOR)


def run_image(image_path):
    """Detect the pose in a still image, print key landmarks and show the result.

    ``PoseDetector.detect_poses`` yields at most one result per image, so at
    most one person is reported. The annotated image stays on screen until a
    key is pressed.

    Args:
        image_path: path to the image file.
    """
    print(f"检测图片: {image_path}")
    if not Path(image_path).exists():
        print(f"错误: 文件不存在 - {image_path}")
        return
    image = _read_image(image_path)
    if image is None:
        print(f"错误: 无法读取图片 - {image_path}")
        return
    detector = PoseDetector()
    pose_results = detector.detect_poses(image)
    image = detector.draw_poses(image, pose_results)
    poses_info = detector.get_poses_info(pose_results)
    cv2.imshow('Pose Detection', image)
    if poses_info:
        print(f"\n检测到 {len(poses_info)} 个人物:")
        for idx, pose_info in enumerate(poses_info):
            print(f"\n人物 {idx + 1}:")
            for name, info in pose_info.items():
                print(f" {name}: x={info['x']:.3f}, y={info['y']:.3f}, visibility={info['visibility']:.3f}")
    else:
        print("未检测到姿态")
    print("\n按任意键关闭窗口...")
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def main():
    """Command-line entry point.

    With no arguments, live webcam detection starts. Otherwise the first
    argument is treated as an image path and processed.
    """
    cli_args = sys.argv[1:]
    if cli_args:
        # An image path was supplied on the command line.
        run_image(cli_args[0])
        return
    # No arguments: fall back to the webcam.
    run_webcam()


if __name__ == "__main__":
    main()