583 lines
19 KiB
Python
583 lines
19 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
CMA Code Extraction using Template Matching (Primary Method)
|
||
|
||
This module uses template matching to locate the CMA logo, then extracts
|
||
the CMA code from the region around the logo using OCR.
|
||
|
||
This is the PRIMARY method for CMA extraction, with fallback to full-page OCR.
|
||
|
||
Author: Claude Code
|
||
Date: 2025-02-16
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import cv2
|
||
import numpy as np
|
||
import logging
|
||
from pathlib import Path
|
||
|
||
# Module-level logger named after this module; no handler or level
# configuration is visible in this file.
logger = logging.getLogger(__name__)
|
||
|
||
# CMA code patterns
|
||
# Preferred CMA code shape: the digit 2 followed by ten more digits.
# NOTE(review): not referenced by the extraction functions visible in this
# file, which use inline r'\d{12}' / r'\d{11}' patterns - confirm whether
# another module imports it.
PATTERN_PRIMARY = r'2[0-9]{10}' # 11 digits starting with 2
|
||
# Looser CMA code shape: any run of eleven digits.
# NOTE(review): not referenced by the extraction functions visible in this
# file - confirm whether another module imports it.
PATTERN_FALLBACK = r'[0-9]{11}' # any 11 digits
|
||
|
||
|
||
def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """
    Read an image from a path that may contain non-ASCII characters.

    Acts as a stand-in for ``cv2.imread``: the file's bytes are loaded with
    ``np.fromfile`` and then decoded in memory with ``cv2.imdecode``.

    Args:
        path: Image file path (may contain Chinese characters); it is passed
            through ``str()`` before reading.
        flags: cv2.IMREAD_* flag forwarded to ``cv2.imdecode``.

    Returns:
        Whatever ``cv2.imdecode`` returns for the file's bytes, or None when
        reading or decoding raised an exception (the error is logged).
    """
    decoded = None
    try:
        raw_bytes = np.fromfile(str(path), dtype=np.uint8)
        decoded = cv2.imdecode(raw_bytes, flags)
    except Exception as e:
        logger.error(f"Failed to read image {path}: {e}")
    return decoded
|
||
|
||
|
||
def load_cma_template(template_path='template/CMA_Logo.png'):
    """
    Load the CMA logo template image as a grayscale array.

    Args:
        template_path: Path to the template image file.

    Returns:
        A 2-tuple ``(template, template_rgb)``. Both elements are the SAME
        grayscale array: no colour copy is produced, the second slot is kept
        only so existing callers that unpack two values keep working.
        ``(None, None)`` is returned when the file does not exist or cannot
        be decoded (an error is logged in both cases).
    """
    if not os.path.exists(template_path):
        logger.error(f"模板文件不存在: {template_path}")
        return None, None

    # Read through the module's own Unicode-safe loader instead of
    # cv2.imread so that template paths containing non-ASCII characters are
    # handled the same way as page images.
    template = imread_unicode(template_path, cv2.IMREAD_GRAYSCALE)
    if template is None:
        logger.error(f"无法读取模板文件: {template_path}")
        return None, None

    logger.debug(f"加载模板: {template_path}, 尺寸: {template.shape}")

    return template, template
|
||
|
||
|
||
def match_template(page_img, template, method=cv2.TM_CCOEFF_NORMED):
    """
    Locate the logo template on a page image with ``cv2.matchTemplate``.

    Args:
        page_img: Page image, single-channel or 3-channel. A 3-channel image
            is converted with ``cv2.COLOR_BGR2GRAY`` before matching.
        template: CMA logo template (grayscale).
        method: A ``cv2.TM_*`` comparison method (default TM_CCOEFF_NORMED).

    Returns:
        ``None`` when matching cannot be performed: an input is missing, the
        template is larger than the page, or OpenCV raises an error.
        Otherwise a dict with:
            max_val: match score as ``float``. For the TM_SQDIFF methods this
                is ``1 - min_val`` so that larger still means better.
            top_left: ``(x, y)`` of the best match's top-left corner.
            center: ``(x, y)`` of the centre of the matched rectangle.
            template_size: ``(width, height)`` of the template.
    """
    if page_img is None or template is None:
        logger.warning("[TM] Missing page image or template, cannot match")
        return None

    # Convert to grayscale when the page is a colour image.
    if len(page_img.shape) == 3:
        page_gray = cv2.cvtColor(page_img, cv2.COLOR_BGR2GRAY)
    else:
        page_gray = page_img

    template_h, template_w = template.shape[:2]
    page_h, page_w = page_gray.shape[:2]

    # cv2.matchTemplate requires the searched image to be at least as large
    # as the template; report a failed match instead of letting it raise.
    if template_h > page_h or template_w > page_w:
        logger.warning(
            f"[TM] Template {template_w}x{template_h} is larger than "
            f"image {page_w}x{page_h}, cannot match"
        )
        return None

    try:
        result = cv2.matchTemplate(page_gray, template, method=method)
    except cv2.error as e:
        logger.warning(f"模板匹配失败: {e}")
        return None

    if result is None:
        logger.warning("模板匹配失败")
        return None

    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)

    # For the squared-difference methods the minimum is the best match.
    if method in [cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED]:
        top_left = min_loc
        # NOTE(review): only a [0, 1] similarity for the normalised variant;
        # plain TM_SQDIFF values are unbounded - confirm intended use.
        match_value = 1 - min_val
    else:
        top_left = max_loc
        match_value = max_val

    # Centre of the matched rectangle.
    center_x = top_left[0] + template_w // 2
    center_y = top_left[1] + template_h // 2

    logger.info(f"[TM] Match confidence: {match_value:.3f} (threshold: 0.4)")
    logger.info(f"[TM] Logo detected at center ({center_x}, {center_y}) in image {page_gray.shape[1]}x{page_gray.shape[0]}")

    return {
        'max_val': float(match_value),
        'top_left': top_left,
        'center': (center_x, center_y),
        'template_size': (template_w, template_h)
    }
|
||
|
||
|
||
def _quad_to_xyxy(box):
    """Collapse a polygon of (x, y) points into ``[x1, y1, x2, y2]``.

    Anything that is not a sequence of at least four points (an already flat
    box, ``None``, ...) is returned unchanged.
    """
    try:
        if len(box) >= 4 and all(len(pt) >= 2 for pt in box):
            xs = [pt[0] for pt in box]
            ys = [pt[1] for pt in box]
            return [min(xs), min(ys), max(xs), max(ys)]
    except TypeError:
        # Not a sequence of points (e.g. flat numbers or None): keep as is.
        pass
    return box


def extract_cma_from_roi(roi_img, ocr_engine, output_dir=None, debug_prefix=""):
    """
    Run OCR on a region of interest and pick the most plausible CMA code.

    Args:
        roi_img: ROI image (numpy array).
        ocr_engine: OCR engine. ``predict(img)`` is used when the engine has
            that attribute, otherwise ``ocr(img)``.
        output_dir: When set and a code is found, an annotated copy of the
            ROI is written there as ``<prefix>cma_roi_extraction.png``.
        debug_prefix: Prefix prepended to every log message (its stripped
            form is also used in the visualisation file name).

    Returns:
        dict with keys ``code``, ``confidence``, ``raw_text``, ``position``,
        ``box`` and ``success``. ``position`` and ``box`` are expressed in
        the coordinate system of ``roi_img``. On any failure the dict keeps
        its defaults (``code`` None, ``success`` False).
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False
    }

    if roi_img is None or roi_img.size == 0:
        logger.error(f"{debug_prefix}Invalid ROI image")
        return result

    h, w = roi_img.shape[:2]
    logger.info(f"{debug_prefix}ROI: (0, 0) -> ({w}, {h})")
    logger.info(f"{debug_prefix}ROI size: {w}x{h}")

    # Run OCR. The engine is third-party code, so any exception it raises is
    # treated as "no result" rather than propagated.
    try:
        if hasattr(ocr_engine, 'predict'):
            raw_result = ocr_engine.predict(roi_img)
        else:
            raw_result = ocr_engine.ocr(roi_img)

        if raw_result is None or len(raw_result) == 0:
            logger.error(f"{debug_prefix}OCR returned empty result")
            return result
    except Exception as e:
        logger.error(f"{debug_prefix}OCR failed: {e}")
        return result

    # Normalise the two supported result layouts into parallel lists.
    rec_texts = []
    rec_scores = []
    rec_boxes = []

    first = raw_result[0]
    if isinstance(first, dict):
        # predict()-style result: dict-like object with rec_* entries.
        rec_texts = list(first.get('rec_texts', []))
        rec_scores = list(first.get('rec_scores', []))
        rec_boxes = list(first.get('rec_boxes', []))
        logger.info(f"{debug_prefix}Using predict() API format, found {len(rec_texts)} lines")
    elif isinstance(first, list):
        # Legacy result: [[box, (text, score)], ...]
        for item in first:
            if not item or len(item) < 2:
                continue
            text_info = item[1]
            if not text_info or len(text_info) < 2:
                continue
            # Four corner points -> axis-aligned [x1, y1, x2, y2].
            rec_boxes.append(_quad_to_xyxy(item[0]))
            rec_texts.append(text_info[0])
            rec_scores.append(text_info[1])
        logger.info(f"{debug_prefix}Using legacy ocr() API format, found {len(rec_texts)} lines")
    else:
        logger.warning(f"{debug_prefix}Unknown OCR result format: {type(raw_result[0])}")
        return result

    if not rec_texts:
        logger.warning(f"{debug_prefix}No text recognized in ROI")
        return result

    logger.info(f"{debug_prefix}OCR found {len(rec_texts)} text lines")

    for i, (text, score) in enumerate(zip(rec_texts, rec_scores)):
        logger.info(f"{debug_prefix}Line {i}: '{text}' (score: {score:.2f})")

    # Collect CMA code candidates.
    cma_candidates = []

    for i, text in enumerate(rec_texts):
        if not text:
            continue

        # Prefer 12-digit runs; fall back to 11-digit runs.
        numbers = re.findall(r'\d{12}', str(text))
        if not numbers:
            numbers = re.findall(r'\d{11}', str(text))

        if numbers:
            logger.debug(f"{debug_prefix}Found numbers in '{text}': {numbers}")

        for num in numbers:
            box = rec_boxes[i] if i < len(rec_boxes) else None
            score = rec_scores[i] if i < len(rec_scores) else 0.5

            # Position is the centre of the bounding box.
            if box is not None and len(box) >= 4:
                position = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
            else:
                position = (0, 0)

            cma_candidates.append({
                'code': num,
                'confidence': score,
                'text': str(text),
                'position': position,
                'box': box,
            })

    def _score(cand):
        """Rank a candidate: OCR confidence plus position/shape bonuses."""
        px, py = cand['position']
        total = cand['confidence'] * 100
        if px > w / 3 and py < h / 3:
            total += 30  # upper-right area of the ROI
        if len(cand['code']) == 11:
            total += 10
        # Codes starting with '2' are the expected shape (see
        # PATTERN_PRIMARY and the full-page fallback), so they are favoured;
        # the previous version subtracted this amount by mistake.
        if cand['code'].startswith('2'):
            total += 20
        return total

    if cma_candidates:
        # max() keeps the first of equally scored candidates, like the
        # stable descending sort it replaces.
        best = max(cma_candidates, key=_score)
        result['code'] = best['code']
        result['confidence'] = best['confidence']
        result['raw_text'] = best['text']
        result['position'] = best['position']
        result['box'] = best['box']
        result['success'] = True

        logger.info(f"{debug_prefix}Best CMA candidate: {best['code']} (conf: {best['confidence']:.2f})")
    else:
        logger.warning(f"{debug_prefix}No CMA code candidates found in ROI text")

    # Save an annotated copy of the ROI for debugging.
    box = result['box']
    if output_dir and result['success'] and box is not None:
        os.makedirs(output_dir, exist_ok=True)
        vis_roi = roi_img.copy()
        if len(box) >= 4:
            # box is [x1, y1, x2, y2]
            cv2.rectangle(vis_roi, (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])), (0, 255, 0), 2)
            # Label just above the box, clamped to stay inside the image.
            text_pos = (int(box[0]), max(10, int(box[1]) - 10))
            cv2.putText(vis_roi, f"CMA: {result['code']}", text_pos,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
        cv2.imwrite(os.path.join(output_dir, f"{debug_prefix.strip()}cma_roi_extraction.png"), vis_roi)
        logger.info(f"{debug_prefix}Saved ROI extraction visualization")

    return result
|
||
|
||
|
||
def extract_cma_code_fullpage(page_img, ocr_engine, template_path='template/CMA_Logo.png',
                              output_dir=None, use_template_matching=True):
    """
    Extract the CMA code from a page: template matching first, then fallback.

    Flow:
      1. Locate the CMA logo with ``match_template``.
      2. If the match score is at least 0.4, OCR a region starting at the
         logo centre and extending to the right (``extract_cma_from_roi``).
      3. If any of that fails (or template matching is disabled), run
         ``extract_cma_fullpage_fallback`` on the whole page.

    Args:
        page_img: Page image as a numpy array, or a path (``str`` or
            ``os.PathLike``) to an image file.
        ocr_engine: OCR engine passed through to the extraction helpers.
        template_path: Path of the CMA logo template.
        output_dir: Optional directory for debug visualisations.
        use_template_matching: When False, go straight to full-page OCR.

    Returns:
        Result dict with ``code``, ``confidence``, ``raw_text``, ``position``,
        ``box``, ``success`` plus:
            method: ``'template_matching'``, ``'fullpage_fallback'`` or
                ``'none'`` (image could not be loaded).
            roi_offset: ``(x, y)`` of the OCR region's top-left corner in the
                page. With ``'template_matching'`` the ``position``/``box``
                values are relative to that region; add this offset to get
                page coordinates. It is ``(0, 0)`` otherwise.
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False,
        'method': 'none',
        'roi_offset': (0, 0)
    }

    # Load the image.
    if isinstance(page_img, (str, os.PathLike)):
        image = imread_unicode(page_img, cv2.IMREAD_COLOR)
    elif isinstance(page_img, np.ndarray):
        image = page_img
    else:
        logger.error(f"Invalid image type: {type(page_img)}")
        return result

    if image is None or image.size == 0:
        logger.error("Failed to load image or empty image")
        return result

    h, w = image.shape[:2]

    # Load the template; without it only the full-page path is possible.
    template = None
    if use_template_matching:
        template, _ = load_cma_template(template_path)
        if template is None:
            logger.warning("Cannot load template, falling back to full-page OCR")
            use_template_matching = False

    # Method 1: template matching + ROI OCR.
    if use_template_matching:
        logger.info("[TM] Starting template matching extraction...")
        match_result = match_template(image, template)

        if match_result is None:
            logger.warning("[TM] Template matching failed")
        elif match_result['max_val'] < 0.4:
            logger.warning(f"[TM] Match confidence too low: {match_result['max_val']:.3f}")
        else:
            center_x, center_y = match_result['center']
            template_w, template_h = match_result['template_size']

            # The code is expected to the RIGHT of the logo, so the ROI
            # starts at the logo centre instead of being centred on it.
            roi_x1 = max(0, center_x)
            roi_y1 = max(0, center_y - template_h // 2)            # aligned with the logo top
            roi_x2 = min(w, center_x + 600)                        # at most 600 px to the right
            roi_y2 = min(h, center_y + template_h // 2 + template_h)  # one logo height below

            logger.info(f"[TM] ROI: ({roi_x1}, {roi_y1}) -> ({roi_x2}, {roi_y2})")

            roi_img = image[roi_y1:roi_y2, roi_x1:roi_x2]

            result = extract_cma_from_roi(roi_img, ocr_engine, output_dir, debug_prefix="[TM] ")

            if result['success']:
                result['method'] = 'template_matching'
                result['roi_offset'] = (roi_x1, roi_y1)
                logger.info(f"[TM] Template matching SUCCESS: {result['code']} (conf: {result['confidence']:.2f})")
                return result

            logger.warning("[TM] Template matching found logo, but OCR failed to extract CMA code")

    # Method 2: full-page OCR fallback.
    logger.info("[FALLBACK] Template matching failed, trying full-page OCR...")
    result = extract_cma_fullpage_fallback(image, ocr_engine, output_dir)
    result['method'] = 'fullpage_fallback'
    result['roi_offset'] = (0, 0)
    return result
|
||
|
||
|
||
def extract_cma_fullpage_fallback(page_img, ocr_engine, output_dir=None):
    """
    Full-page OCR fallback, used when template matching does not yield a code.

    Args:
        page_img: Page image as a numpy array, or a path (``str`` or
            ``os.PathLike``) to an image file.
        ocr_engine: OCR engine. ``predict(img)`` is used when the engine has
            that attribute, otherwise ``ocr(img)`` - the same rule as
            ``extract_cma_from_roi``.
        output_dir: Accepted for signature parity; not used by this function.

    Returns:
        dict with keys ``code``, ``confidence``, ``raw_text``, ``position``,
        ``box`` and ``success``; ``position`` and ``box`` are in page
        coordinates. On any failure the dict keeps its defaults.
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False
    }

    if isinstance(page_img, (str, os.PathLike)):
        image = imread_unicode(page_img, cv2.IMREAD_COLOR)
    elif isinstance(page_img, np.ndarray):
        image = page_img
    else:
        logger.error(f"Invalid image type: {type(page_img)}")
        return result

    if image is None or image.size == 0:
        logger.error("Failed to load image or empty image")
        return result

    h, w = image.shape[:2]

    def _to_xyxy(box):
        """Turn a polygon of >= 4 (x, y) points into [x1, y1, x2, y2]."""
        try:
            if len(box) >= 4 and all(len(pt) >= 2 for pt in box):
                xs = [pt[0] for pt in box]
                ys = [pt[1] for pt in box]
                return [min(xs), min(ys), max(xs), max(ys)]
        except TypeError:
            # Not a sequence of points (flat box, None, ...): keep as is.
            pass
        return box

    # Run OCR on the whole page. The engine is third-party code, so any
    # exception is logged and treated as "nothing found".
    logger.info("[FALLBACK] Running full-page OCR...")
    try:
        if hasattr(ocr_engine, 'predict'):
            raw_result = ocr_engine.predict(image)
        else:
            raw_result = ocr_engine.ocr(image)
    except Exception as e:
        logger.error(f"[FALLBACK] OCR failed: {e}")
        return result

    # Normalise the two supported result layouts into parallel lists.
    rec_texts = []
    rec_scores = []
    rec_boxes = []

    if raw_result is not None and len(raw_result) > 0:
        first = raw_result[0]
        if isinstance(first, dict):
            rec_texts = list(first.get('rec_texts', []))
            rec_scores = list(first.get('rec_scores', []))
            rec_boxes = list(first.get('rec_boxes', []))
        elif isinstance(first, list):
            # Legacy layout: [[box, (text, score)], ...]
            for item in first:
                if not item or len(item) < 2:
                    continue
                text_info = item[1]
                if not text_info or len(text_info) < 2:
                    continue
                rec_boxes.append(_to_xyxy(item[0]))
                rec_texts.append(text_info[0])
                rec_scores.append(text_info[1])

    logger.info(f"[FALLBACK] Found {len(rec_texts)} text lines")

    # Collect CMA code candidates.
    cma_candidates = []

    for i, text in enumerate(rec_texts):
        if not text:
            continue

        # Prefer 12-digit runs; fall back to 11-digit runs.
        numbers = re.findall(r'\d{12}', str(text))
        if not numbers:
            numbers = re.findall(r'\d{11}', str(text))

        for num in numbers:
            box = rec_boxes[i] if i < len(rec_boxes) else None
            score = rec_scores[i] if i < len(rec_scores) else 0.5

            if box is not None and len(box) >= 4:
                position = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
            else:
                position = (0, 0)

            cma_candidates.append({
                'code': num,
                'confidence': score,
                'text': str(text),
                'position': position,
                'box': box,
            })

    if not cma_candidates:
        logger.warning("[FALLBACK] No CMA code candidates found")
        return result

    def _score(cand):
        """Rank a candidate: confidence plus prefix/position/length bonuses."""
        px, py = cand['position']
        total = cand['confidence'] * 100
        if cand['code'].startswith('2'):
            total += 50  # codes starting with 2 are preferred
        if px > w / 2 and py < h / 3:
            total += 30  # upper-right area of the page
        if len(cand['code']) == 11:
            total += 10
        return total

    # max() keeps the first of equally scored candidates, like the stable
    # descending sort it replaces.
    best = max(cma_candidates, key=_score)
    result['code'] = best['code']
    result['confidence'] = best['confidence']
    result['raw_text'] = best['text']
    result['position'] = best['position']
    result['box'] = best['box']
    result['success'] = True

    logger.info(f"[FALLBACK] CMA extracted: {best['code']} (conf: {best['confidence']:.2f})")

    return result
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line debug entry point: render the first page of a PDF at
    # 300 DPI and run the template-matching CMA extraction on it.
    import argparse
    import sys  # needed for sys.exit below; not imported at module level

    parser = argparse.ArgumentParser(description='CMA Logo 模板匹配提取')
    # --pdf is mandatory: without it os.path.exists(None) would raise.
    parser.add_argument('--pdf', required=True, help='PDF 文件路径')
    parser.add_argument('--template', default='template/CMA_Logo.png', help='CMA logo 模板路径')
    parser.add_argument('--output', default='template_match_debug', help='输出目录')

    args = parser.parse_args()

    # Validate input files.
    if not os.path.exists(args.pdf):
        print(f"错误: PDF 文件不存在: {args.pdf}")
        sys.exit(1)

    if not os.path.exists(args.template):
        print(f"错误: 模板文件不存在: {args.template}")
        sys.exit(1)

    # These environment variables are set before paddleocr is imported.
    os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
    os.environ["PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK"] = "True"

    from paddleocr import PaddleOCR
    ocr_engine = PaddleOCR(use_angle_cls=True, lang='ch', use_gpu=False)

    # Render the first page of the PDF.
    import fitz
    doc = fitz.open(args.pdf)
    try:
        if len(doc) == 0:
            print(f"错误: PDF 文件不存在: {args.pdf}")
            sys.exit(1)
        page = doc[0]
        pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72))
        # assumes a 3-channel RGB pixmap (no alpha) - TODO confirm
        img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, 3)
        # OpenCV-style channel order (BGR) for the extraction functions.
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        pdf_width, pdf_height = pix.width, pix.height
    finally:
        doc.close()

    print(f"PDF 尺寸: {pdf_width}x{pdf_height}")
    print(f"图像尺寸: {img_bgr.shape}")

    # Run the extraction.
    result = extract_cma_code_fullpage(img_bgr, ocr_engine, args.template, args.output)

    # Report.
    print()
    print("="*80)
    print("CMA 提取结果:")
    print("-"*80)
    print(f" 方法: {result.get('method', 'unknown')}")
    print(f" CMA码: {result.get('code', 'N/A')}")
    print(f" 置信度: {result.get('confidence', 0.0):.2f}")
    print(f" 位置: {result.get('position', 'N/A')}")
    print("-"*80)
    print(f" 提取成功: {result.get('success', False)}")
    print("="*80)
|