report-detect/cma_extraction_template_pri...

583 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CMA Code Extraction using Template Matching (Primary Method)
This module uses template matching to locate the CMA logo, then extracts
the CMA code from the region around the logo using OCR.
This is the PRIMARY method for CMA extraction, with fallback to full-page OCR.
Author: Claude Code
Date: 2025-02-16
"""
import os
import re
import cv2
import numpy as np
import logging
from pathlib import Path
logger = logging.getLogger(__name__)

# CMA code patterns.
# NOTE(review): neither pattern is referenced by the functions visible in this
# module (the extraction code uses its own inline 12-/11-digit regexes). They
# may be imported elsewhere — confirm before changing or removing them.
PATTERN_PRIMARY = r'2[0-9]{10}'  # 11 digits starting with 2
PATTERN_FALLBACK = r'[0-9]{11}'  # any 11 digits
def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """Read an image whose path may contain non-ASCII characters.

    This is a drop-in replacement for ``cv2.imread``. The file's raw bytes are
    loaded with NumPy and decoded in memory with ``cv2.imdecode``. This avoids
    handing a non-ASCII path (e.g. one containing Chinese characters) to OpenCV.

    Args:
        path: Image file path (``str`` or path-like; converted with ``str``).
        flags: A ``cv2.IMREAD_*`` flag, forwarded to ``cv2.imdecode``.

    Returns:
        The decoded image as a NumPy array. Returns ``None`` if reading or
        decoding raised an exception (the error is logged). ``cv2.imdecode``
        may also return ``None`` on its own for undecodable data.
    """
    try:
        raw_bytes = np.fromfile(str(path), dtype=np.uint8)
        decoded = cv2.imdecode(raw_bytes, flags)
    except Exception as e:
        logger.error(f"Failed to read image {path}: {e}")
        return None
    return decoded
def load_cma_template(template_path='template/CMA_Logo.png'):
    """Load the CMA logo template as a grayscale image.

    The file is read through ``imread_unicode``, so template paths containing
    non-ASCII (e.g. Chinese) characters work the same way page images do.
    Plain ``cv2.imread`` does not support such paths, which is why
    ``imread_unicode`` exists.

    Args:
        template_path: Path to the template image file (``str`` or path-like).

    Returns:
        A ``(template, template_vis)`` tuple.

        - ``template`` is the grayscale template array.
        - ``template_vis`` is the *same* grayscale array, returned a second
          time for backward compatibility with callers that unpack two values.
          It is not an RGB copy, despite what earlier docs said.
        - Both elements are ``None`` when the file is missing or cannot be
          decoded.
    """
    if not os.path.exists(template_path):
        logger.error(f"模板文件不存在: {template_path}")
        return None, None

    # Matching runs on grayscale page images, so load the template as grayscale.
    template = imread_unicode(template_path, cv2.IMREAD_GRAYSCALE)
    if template is None or template.size == 0:
        logger.error(f"无法读取模板文件: {template_path}")
        return None, None

    logger.debug(f"加载模板: {template_path}, 尺寸: {template.shape}")
    return template, template
def match_template(page_img, template, method=cv2.TM_CCOEFF_NORMED):
    """Locate ``template`` inside ``page_img`` using ``cv2.matchTemplate``.

    Args:
        page_img: Page image. Accepted layouts are grayscale (H, W) or
            (H, W, 1), BGR (H, W, 3), and BGRA (H, W, 4).
        template: Grayscale CMA logo template.
        method: ``cv2.TM_*`` comparison method (default ``TM_CCOEFF_NORMED``).

    Returns:
        A dict with these keys, or ``None`` when matching is impossible
        (template larger than the page) or OpenCV raises:

        - ``max_val``: match score, higher is better.
        - ``top_left``: (x, y) of the best match.
        - ``center``: (x, y) centre of the matched region.
        - ``template_size``: (w, h) of the template.

        For ``TM_SQDIFF``/``TM_SQDIFF_NORMED`` the best match is the minimum,
        and the score is reported as ``1 - min_val``. That is a [0, 1]
        similarity only for the normalised variant.
    """
    # The template is single-channel, so reduce the page to one channel too.
    if page_img.ndim == 3:
        n_channels = page_img.shape[2]
        if n_channels == 1:
            page_gray = page_img[:, :, 0]
        elif n_channels == 4:
            page_gray = cv2.cvtColor(page_img, cv2.COLOR_BGRA2GRAY)
        else:
            page_gray = cv2.cvtColor(page_img, cv2.COLOR_BGR2GRAY)
    else:
        page_gray = page_img

    template_h, template_w = template.shape[:2]
    page_h, page_w = page_gray.shape[:2]
    # cv2.matchTemplate does not return None on bad input. With a template
    # larger than the image it either raises or produces a result that does
    # not describe the template's position, so reject that case up front.
    if template_h > page_h or template_w > page_w:
        logger.warning(
            f"模板匹配失败: template {template_w}x{template_h} "
            f"larger than image {page_w}x{page_h}"
        )
        return None

    try:
        result = cv2.matchTemplate(page_gray, template, method=method)
    except cv2.error as e:
        logger.warning(f"模板匹配失败: {e}")
        return None

    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
    # For the squared-difference methods the minimum is the best match.
    if method in (cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED):
        top_left = min_loc
        match_value = 1 - min_val  # convert distance to a similarity
    else:
        top_left = max_loc
        match_value = max_val

    center_x = top_left[0] + template_w // 2
    center_y = top_left[1] + template_h // 2

    logger.info(f"[TM] Match confidence: {match_value:.3f} (threshold: 0.4)")
    logger.info(f"[TM] Logo detected at center ({center_x}, {center_y}) in image {page_gray.shape[1]}x{page_gray.shape[0]}")
    return {
        'max_val': float(match_value),
        'top_left': top_left,
        'center': (center_x, center_y),
        'template_size': (template_w, template_h)
    }
def _polygon_to_xyxy(box):
    """Convert an OCR polygon into ``[x1, y1, x2, y2]``; return other values unchanged.

    Accepts a list/tuple of at least four ``(x, y)`` points, or a 2-D NumPy
    array of points. Anything else (``None``, an already-flat box, ...) is
    passed through as-is.
    """
    if isinstance(box, np.ndarray) and box.ndim == 2:
        box = box.tolist()
    if (isinstance(box, (list, tuple)) and len(box) >= 4
            and all(isinstance(p, (list, tuple)) and len(p) >= 2 for p in box)):
        xs = [p[0] for p in box]
        ys = [p[1] for p in box]
        return [min(xs), min(ys), max(xs), max(ys)]
    return box


def _normalize_ocr_output(raw_result, debug_prefix=""):
    """Flatten PaddleOCR output into parallel ``(texts, scores, boxes)`` lists.

    Two layouts are understood:

    * ``predict()`` style: ``raw_result[0]`` is dict-like, with ``rec_texts``,
      ``rec_scores`` and ``rec_boxes`` entries.
    * legacy ``ocr()`` style: ``raw_result[0]`` is a list of
      ``[polygon, (text, score)]`` items. Polygons are reduced to
      ``[x1, y1, x2, y2]``.

    Returns ``None`` for any other layout, after logging a warning.
    """
    first = raw_result[0]
    if isinstance(first, dict):
        texts = list(first.get('rec_texts', []))
        scores = list(first.get('rec_scores', []))
        boxes = list(first.get('rec_boxes', []))
        logger.info(f"{debug_prefix}Using predict() API format, found {len(texts)} lines")
        return texts, scores, boxes
    if isinstance(first, list):
        texts, scores, boxes = [], [], []
        for item in first:
            if not item or len(item) < 2:
                continue
            box, text_info = item[0], item[1]
            if not text_info or len(text_info) < 2:
                continue
            texts.append(text_info[0])
            scores.append(text_info[1])
            boxes.append(_polygon_to_xyxy(box))
        logger.info(f"{debug_prefix}Using legacy ocr() API format, found {len(texts)} lines")
        return texts, scores, boxes
    logger.warning(f"{debug_prefix}Unknown OCR result format: {type(first)}")
    return None


def extract_cma_from_roi(roi_img, ocr_engine, output_dir=None, debug_prefix=""):
    """Run OCR on an ROI image and pick the most likely CMA code in it.

    Args:
        roi_img: ROI image (NumPy array) to OCR.
        ocr_engine: OCR engine. ``predict(img)`` is used when the engine has
            that method, otherwise ``ocr(img)``.
        output_dir: If given and a code is found, a visualisation with the
            winning box drawn is written here (best effort; failures are
            logged, not raised).
        debug_prefix: Prefix prepended to every log message and to the
            visualisation filename.

    Returns:
        dict with keys:

        - ``code``: str or ``None``.
        - ``confidence``: OCR score of the winning line.
        - ``raw_text``: the OCR line the code came from.
        - ``position``: box centre in ROI coordinates, or ``(0, 0)`` if unknown.
        - ``box``: ``[x1, y1, x2, y2]`` in ROI coordinates, or ``None``.
        - ``success``: bool.
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False
    }
    if roi_img is None or roi_img.size == 0:
        logger.error(f"{debug_prefix}Invalid ROI image")
        return result

    h, w = roi_img.shape[:2]
    logger.info(f"{debug_prefix}ROI: (0, 0) -> ({w}, {h})")
    logger.info(f"{debug_prefix}ROI size: {w}x{h}")

    # Engines exposing predict() (the original note mentions PaddleOCRVL) are
    # preferred; otherwise use the legacy ocr() entry point.
    try:
        if hasattr(ocr_engine, 'predict'):
            raw_result = ocr_engine.predict(roi_img)
        else:
            raw_result = ocr_engine.ocr(roi_img)
        if raw_result is None or len(raw_result) == 0:
            logger.error(f"{debug_prefix}OCR returned empty result")
            return result
    except Exception as e:
        logger.error(f"{debug_prefix}OCR failed: {e}")
        return result

    parsed = _normalize_ocr_output(raw_result, debug_prefix)
    if parsed is None:
        return result
    rec_texts, rec_scores, rec_boxes = parsed
    if not rec_texts:
        logger.warning(f"{debug_prefix}No text recognized in ROI")
        return result

    logger.info(f"{debug_prefix}OCR found {len(rec_texts)} text lines")
    for i, (text, score) in enumerate(zip(rec_texts, rec_scores)):
        logger.info(f"{debug_prefix}Line {i}: '{text}' (score: {score:.2f})")

    # Collect CMA code candidates from every recognised line.
    cma_candidates = []
    for i, text in enumerate(rec_texts):
        if not text:
            continue
        text = str(text)
        # Prefer 12-digit runs; only look for 11-digit runs when a line has none.
        numbers = re.findall(r'\d{12}', text) or re.findall(r'\d{11}', text)
        if not numbers:
            continue
        box = rec_boxes[i] if i < len(rec_boxes) else None
        score = rec_scores[i] if i < len(rec_scores) else 0.5
        if box is not None and len(box) >= 4:
            position = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
        else:
            position = (0, 0)
        for num in numbers:
            cma_candidates.append({
                'code': num,
                'confidence': score,
                'text': text,
                'position': position,
                'box': box,
            })

    if not cma_candidates:
        logger.warning(f"{debug_prefix}No CMA code candidates found in ROI text")
        return result

    def _rank(cand):
        x, y = cand['position']
        code = cand['code']
        return (
            cand['confidence'] * 100
            + (30 if x > w / 3 and y < h / 3 else 0)  # upper-right of the ROI
            + (10 if len(code) == 11 else 0)
            # Codes starting with '2' are the primary pattern (PATTERN_PRIMARY)
            # and the full-page fallback also rewards them. The old code
            # subtracted this bonus, which ranked the expected codes last.
            + (20 if code.startswith('2') else 0)
        )

    cma_candidates.sort(key=_rank, reverse=True)
    best = cma_candidates[0]
    result['code'] = best['code']
    result['confidence'] = best['confidence']
    result['raw_text'] = best['text']
    result['position'] = best['position']
    result['box'] = best['box']
    result['success'] = True
    logger.info(f"{debug_prefix}Best CMA candidate: {best['code']} (conf: {best['confidence']:.2f})")

    # Optional debug visualisation. This is best effort and must not discard
    # an already extracted code.
    box = result['box']
    if output_dir and box is not None:
        try:
            os.makedirs(output_dir, exist_ok=True)
            vis_roi = roi_img.copy()
            if len(box) >= 4:
                # box is [x1, y1, x2, y2]
                cv2.rectangle(vis_roi, (int(box[0]), int(box[1])),
                              (int(box[2]), int(box[3])), (0, 255, 0), 2)
                # Label just above the box, clamped to stay inside the image.
                text_pos = (int(box[0]), max(10, int(box[1]) - 10))
                cv2.putText(vis_roi, f"CMA: {result['code']}", text_pos,
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
            vis_path = os.path.join(output_dir, f"{debug_prefix.strip()}cma_roi_extraction.png")
            # cv2.imwrite has the same non-ASCII path limitation that
            # imread_unicode works around for reading, so encode in memory
            # and let NumPy write the bytes.
            ok, encoded = cv2.imencode('.png', vis_roi)
            if ok:
                encoded.tofile(vis_path)
                logger.info(f"{debug_prefix}Saved ROI extraction visualization")
            else:
                logger.warning(f"{debug_prefix}Failed to encode ROI visualization")
        except (OSError, cv2.error) as e:
            logger.warning(f"{debug_prefix}Failed to save ROI visualization: {e}")
    return result
def extract_cma_code_fullpage(page_img, ocr_engine, template_path='template/CMA_Logo.png',
                              output_dir=None, use_template_matching=True, *,
                              match_threshold=0.4, roi_max_width=600):
    """Extract the CMA code from a page using template matching, then full-page OCR.

    Steps:

    1. Locate the CMA logo with ``match_template``.
    2. If the match score is at least ``match_threshold``, OCR an ROI that
       starts at the logo centre and extends up to ``roi_max_width`` pixels to
       the right, where the CMA number is normally printed.
    3. If any step fails, run ``extract_cma_fullpage_fallback`` on the page.

    Args:
        page_img: Page image as a NumPy array, or an image file path
            (read via ``imread_unicode``).
        ocr_engine: OCR engine passed through to the extraction helpers.
        template_path: Path of the CMA logo template image.
        output_dir: Optional directory for debug visualisations.
        use_template_matching: If False, go straight to full-page OCR.
        match_threshold: Keyword-only. Minimum template-match score for
            trusting the logo location (default 0.4).
        roi_max_width: Keyword-only. Maximum ROI width in pixels to the right
            of the logo centre (default 600).

    Returns:
        dict with keys ``code``, ``confidence``, ``raw_text``, ``position``,
        ``box``, ``success`` and ``method``.

        - ``method`` is ``'template_matching'``, ``'fullpage_fallback'``, or
          ``'none'`` when the image could not be loaded.
        - ``position`` and ``box`` are in page coordinates for both methods.
        - On a template-matching success, the ROI used is also returned as
          ``roi`` = ``(x1, y1, x2, y2)``.
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False,
        'method': 'none'
    }

    # Load the image.
    if isinstance(page_img, str):
        image = imread_unicode(page_img, cv2.IMREAD_COLOR)
    elif isinstance(page_img, np.ndarray):
        image = page_img
    else:
        logger.error(f"Invalid image type: {type(page_img)}")
        return result
    if image is None or image.size == 0:
        logger.error("Failed to load image or empty image")
        return result
    h, w = image.shape[:2]

    # Load the template; without it we can only do full-page OCR.
    template = None
    if use_template_matching:
        template, _ = load_cma_template(template_path)
        if template is None:
            logger.warning("Cannot load template, falling back to full-page OCR")
            use_template_matching = False

    # Method 1: template matching + ROI OCR.
    if use_template_matching:
        logger.info("[TM] Starting template matching extraction...")
        match_result = match_template(image, template)
        if match_result is None:
            logger.warning("[TM] Template matching failed")
        elif match_result['max_val'] < match_threshold:
            logger.warning(f"[TM] Match confidence too low: {match_result['max_val']:.3f}")
        else:
            center_x, center_y = match_result['center']
            template_w, template_h = match_result['template_size']
            # The CMA number sits to the right of the logo, so the ROI starts
            # at the logo centre and extends rightwards rather than being
            # centred on the logo. Vertically it spans the logo plus one extra
            # template height below it.
            roi_x1 = max(0, center_x)
            roi_y1 = max(0, center_y - template_h // 2)
            roi_x2 = min(w, center_x + roi_max_width)
            roi_y2 = min(h, center_y + template_h // 2 + template_h)
            logger.info(f"[TM] ROI: ({roi_x1}, {roi_y1}) -> ({roi_x2}, {roi_y2})")
            roi_img = image[roi_y1:roi_y2, roi_x1:roi_x2]

            roi_result = extract_cma_from_roi(roi_img, ocr_engine, output_dir, debug_prefix="[TM] ")
            if roi_result['success']:
                roi_result['method'] = 'template_matching'
                roi_result['roi'] = (roi_x1, roi_y1, roi_x2, roi_y2)
                box = roi_result['box']
                if box is not None and len(box) >= 4:
                    # extract_cma_from_roi reports ROI-relative coordinates.
                    # Shift them so both methods return page coordinates.
                    roi_result['box'] = [box[0] + roi_x1, box[1] + roi_y1,
                                         box[2] + roi_x1, box[3] + roi_y1]
                    px, py = roi_result['position']
                    roi_result['position'] = (px + roi_x1, py + roi_y1)
                logger.info(f"[TM] Template matching SUCCESS: {roi_result['code']} (conf: {roi_result['confidence']:.2f})")
                return roi_result
            logger.warning("[TM] Template matching found logo, but OCR failed to extract CMA code")
        logger.info("[FALLBACK] Template matching failed, trying full-page OCR...")
    else:
        logger.info("[FALLBACK] Template matching disabled or unavailable, running full-page OCR...")

    # Method 2: full-page OCR fallback.
    fallback = extract_cma_fullpage_fallback(image, ocr_engine, output_dir)
    fallback['method'] = 'fullpage_fallback'
    return fallback
def extract_cma_fullpage_fallback(page_img, ocr_engine, output_dir=None):
    """Full-page OCR fallback, used when template matching fails.

    Args:
        page_img: Page image (NumPy array) or image file path
            (read via ``imread_unicode``).
        ocr_engine: OCR engine. ``predict(img)`` is used when the engine has
            it, otherwise ``ocr(img)``; this is the same dispatch as the ROI
            extractor.
        output_dir: Accepted for API symmetry; currently unused.

    Returns:
        dict with keys ``code``, ``confidence``, ``raw_text``, ``position``
        (box centre, page coordinates), ``box`` (``[x1, y1, x2, y2]`` or
        ``None``) and ``success``.
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False
    }
    if isinstance(page_img, str):
        image = imread_unicode(page_img, cv2.IMREAD_COLOR)
    elif isinstance(page_img, np.ndarray):
        image = page_img
    else:
        logger.error(f"Invalid image type: {type(page_img)}")
        return result
    if image is None or image.size == 0:
        logger.error("Failed to load image or empty image")
        return result
    h, w = image.shape[:2]

    logger.info("[FALLBACK] Running full-page OCR...")
    try:
        if hasattr(ocr_engine, 'predict'):
            raw_result = ocr_engine.predict(image)
        else:
            raw_result = ocr_engine.ocr(image)
    except Exception as e:
        logger.error(f"[FALLBACK] OCR failed: {e}")
        return result

    def _to_xyxy(box):
        # Reduce a polygon (list/tuple of points or 2-D array) to
        # [x1, y1, x2, y2]; leave anything else untouched.
        if isinstance(box, np.ndarray) and box.ndim == 2:
            box = box.tolist()
        if (isinstance(box, (list, tuple)) and len(box) >= 4
                and all(isinstance(p, (list, tuple)) and len(p) >= 2 for p in box)):
            xs = [p[0] for p in box]
            ys = [p[1] for p in box]
            return [min(xs), min(ys), max(xs), max(ys)]
        return box

    # Flatten the OCR output into parallel lists.
    rec_texts, rec_scores, rec_boxes = [], [], []
    if raw_result is not None and len(raw_result) > 0:
        first = raw_result[0]
        if isinstance(first, dict):
            # predict()-style result: dict-like with rec_* entries.
            rec_texts = list(first.get('rec_texts', []))
            rec_scores = list(first.get('rec_scores', []))
            rec_boxes = list(first.get('rec_boxes', []))
        elif isinstance(first, list):
            # Legacy ocr() result: [[polygon, (text, score)], ...].
            for item in first:
                if not item or len(item) < 2:
                    continue
                box, text_info = item[0], item[1]
                if not text_info or len(text_info) < 2:
                    continue
                rec_texts.append(text_info[0])
                rec_scores.append(text_info[1])
                rec_boxes.append(_to_xyxy(box))
    logger.info(f"[FALLBACK] Found {len(rec_texts)} text lines")

    # Collect CMA code candidates.
    cma_candidates = []
    for i, text in enumerate(rec_texts):
        if not text:
            continue
        text = str(text)
        # Prefer 12-digit runs; only look for 11-digit runs when a line has none.
        numbers = re.findall(r'\d{12}', text) or re.findall(r'\d{11}', text)
        if not numbers:
            continue
        box = rec_boxes[i] if i < len(rec_boxes) else None
        score = rec_scores[i] if i < len(rec_scores) else 0.5
        if box is not None and len(box) >= 4:
            position = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
        else:
            position = (0, 0)
        for num in numbers:
            cma_candidates.append({
                'code': num,
                'confidence': score,
                'text': text,
                'position': position,
                'box': box,
            })

    if not cma_candidates:
        logger.warning("[FALLBACK] No CMA code candidates found")
        return result

    # Rank candidates. Codes starting with '2' are favoured, then codes in the
    # upper-right of the page, then 11-digit codes.
    cma_candidates.sort(key=lambda x: (
        x['confidence'] * 100
        + (50 if x['code'].startswith('2') else 0)
        + (30 if x['position'][0] > w / 2 and x['position'][1] < h / 3 else 0)
        + (10 if len(x['code']) == 11 else 0)
    ), reverse=True)

    best = cma_candidates[0]
    result['code'] = best['code']
    result['confidence'] = best['confidence']
    result['raw_text'] = best['text']
    result['position'] = best['position']
    result['box'] = best['box']
    result['success'] = True
    logger.info(f"[FALLBACK] CMA extracted: {best['code']} (conf: {best['confidence']:.2f})")
    return result
if __name__ == "__main__":
    # Command-line debug entry point. It renders the first page of a PDF at
    # 300 DPI and runs the template-matching CMA extraction on it.
    import argparse
    import sys  # needed for sys.exit(); not imported at module level

    parser = argparse.ArgumentParser(description='CMA Logo 模板匹配提取')
    # --pdf has no sensible default; without it os.path.exists(None) would raise.
    parser.add_argument('--pdf', required=True, help='PDF 文件路径')
    parser.add_argument('--template', default='template/CMA_Logo.png', help='CMA logo 模板路径')
    parser.add_argument('--output', default='template_match_debug', help='输出目录')
    args = parser.parse_args()

    # Validate inputs before loading the OCR models.
    if not os.path.exists(args.pdf):
        print(f"错误: PDF 文件不存在: {args.pdf}")
        sys.exit(1)
    if not os.path.exists(args.template):
        print(f"错误: 模板文件不存在: {args.template}")
        sys.exit(1)

    # Load the OCR engine. Judging by their names, these variables disable
    # PaddleOCR/PaddleX's model-source check — TODO confirm.
    os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
    os.environ["PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK"] = "True"
    from paddleocr import PaddleOCR
    ocr_engine = PaddleOCR(use_angle_cls=True, lang='ch', use_gpu=False)

    # Render the first page at 300 DPI (PDF user space is 72 units per inch).
    import fitz
    with fitz.open(args.pdf) as doc:
        if doc.page_count == 0:
            print(f"错误: PDF 没有页面: {args.pdf}")
            sys.exit(1)
        pix = doc[0].get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
        # pix.n is the number of components per pixel (3 for RGB without alpha).
        img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
        # PyMuPDF gives RGB; the extraction pipeline expects OpenCV's BGR order.
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    print(f"PDF 尺寸: {pix.width}x{pix.height}")
    print(f"图像尺寸: {img_bgr.shape}")

    # Run the extraction.
    result = extract_cma_code_fullpage(img_bgr, ocr_engine, args.template, args.output)

    # Report the result. 'code' is always present but may be None.
    print()
    print("=" * 80)
    print("CMA 提取结果:")
    print("-" * 80)
    print(f" 方法: {result.get('method', 'unknown')}")
    print(f" CMA码: {result.get('code') or 'N/A'}")
    print(f" 置信度: {result.get('confidence', 0.0):.2f}")
    print(f" 位置: {result.get('position', 'N/A')}")
    print("-" * 80)
    print(f" 提取成功: {result.get('success', False)}")
    print("=" * 80)