#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CMA Code Extraction using Template Matching (Primary Method)

This module uses template matching to locate the CMA logo, then extracts the
CMA code from the region to the right of the logo using OCR.

This is the PRIMARY method for CMA extraction, with fallback to full-page OCR.

Author: Claude Code
Date: 2025-02-16
"""

import logging
import os
import re
import sys
from pathlib import Path

import cv2
import numpy as np

logger = logging.getLogger(__name__)

# CMA code patterns
PATTERN_PRIMARY = r'2[0-9]{10}'  # 11 digits starting with 2
PATTERN_FALLBACK = r'[0-9]{11}'  # any 11 digits

# Minimum template-matching similarity for the logo to count as found.
MATCH_CONFIDENCE_THRESHOLD = 0.4
# Maximum width (pixels) of the OCR region taken to the right of the logo.
ROI_MAX_WIDTH = 600

# Digit-run patterns used to pull candidate codes out of OCR text lines.
_RE_12_DIGITS = re.compile(r'\d{12}')
_RE_11_DIGITS = re.compile(r'\d{11}')


def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """
    cv2.imread replacement that supports paths with non-ASCII characters.

    The file is read into a byte buffer with numpy and decoded with
    cv2.imdecode, which avoids cv2.imread's path handling.

    Args:
        path: Image file path (may contain Chinese characters)
        flags: cv2.IMREAD_* flags

    Returns:
        Image as numpy array or None if failed
    """
    try:
        data = np.fromfile(str(path), dtype=np.uint8)
        return cv2.imdecode(data, flags)
    except Exception as e:
        logger.error(f"Failed to read image {path}: {e}")
        return None


def _imwrite_unicode(path, img):
    """Write *img* to *path* via imencode so non-ASCII paths work.

    Returns:
        True when the file was written, False otherwise (failure is logged).
    """
    ext = os.path.splitext(str(path))[1] or '.png'
    try:
        ok, buf = cv2.imencode(ext, img)
        if not ok:
            return False
        buf.tofile(str(path))
        return True
    except (OSError, cv2.error) as e:
        logger.error(f"Failed to write image {path}: {e}")
        return False


def load_cma_template(template_path='template/CMA_Logo.png'):
    """
    Load the CMA logo template image.

    Args:
        template_path: Path of the template image.

    Returns:
        (template, template): the grayscale template, returned twice.
        The second element is kept for backward compatibility; it is the
        same grayscale array, not a colour image.
        (None, None) if the file is missing or cannot be decoded.
    """
    if not os.path.exists(template_path):
        logger.error(f"模板文件不存在: {template_path}")
        return None, None

    # Use the unicode-safe reader: plain cv2.imread fails on non-ASCII paths.
    template = imread_unicode(template_path, cv2.IMREAD_GRAYSCALE)
    if template is None:
        logger.error(f"无法读取模板文件: {template_path}")
        return None, None

    logger.debug(f"加载模板: {template_path}, 尺寸: {template.shape}")
    return template, template


def _to_gray(img):
    """Return a single-channel view/copy of *img* (2-D input is returned as is)."""
    if img.ndim == 2:
        return img
    if img.shape[2] == 1:
        return img[:, :, 0]
    return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)


def match_template(page_img, template, method=cv2.TM_CCOEFF_NORMED):
    """
    Locate the template in the page image with cv2.matchTemplate.

    Args:
        page_img: Page image (grayscale or colour).
        template: CMA logo template (grayscale).
        method: Matching method (default TM_CCOEFF_NORMED).

    Returns:
        Dict with keys 'max_val' (similarity as float), 'top_left',
        'center' and 'template_size' (width, height), or None when
        matching cannot be performed.
    """
    page_gray = _to_gray(page_img)
    template = _to_gray(template)

    template_h, template_w = template.shape[:2]
    page_h, page_w = page_gray.shape[:2]
    # cv2.matchTemplate raises if the template does not fit inside the image.
    if template_h > page_h or template_w > page_w:
        logger.warning(
            f"[TM] Template {template_w}x{template_h} is larger than "
            f"image {page_w}x{page_h}; cannot match"
        )
        return None

    try:
        result = cv2.matchTemplate(page_gray, template, method=method)
    except cv2.error as e:
        logger.warning(f"模板匹配失败: {e}")
        return None
    if result is None:
        logger.warning("模板匹配失败")
        return None

    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)

    # For the TM_SQDIFF family the minimum is the best match.
    # NOTE: "1 - min_val" is only a 0..1 similarity for TM_SQDIFF_NORMED;
    # for un-normalised TM_SQDIFF the value is unbounded.
    if method in (cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED):
        top_left = min_loc
        match_value = 1 - min_val
    else:
        top_left = max_loc
        match_value = max_val

    center_x = top_left[0] + template_w // 2
    center_y = top_left[1] + template_h // 2

    logger.info(
        f"[TM] Match confidence: {match_value:.3f} "
        f"(threshold: {MATCH_CONFIDENCE_THRESHOLD})"
    )
    logger.info(
        f"[TM] Logo detected at center ({center_x}, {center_y}) "
        f"in image {page_gray.shape[1]}x{page_gray.shape[0]}"
    )

    return {
        'max_val': float(match_value),
        'top_left': top_left,
        'center': (center_x, center_y),
        'template_size': (template_w, template_h),
    }


def _empty_result():
    """Return a fresh 'nothing extracted' result dict."""
    return {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False,
    }


def _load_page_image(page_img):
    """Resolve *page_img* (path or ndarray) to a non-empty image, else None."""
    if isinstance(page_img, (str, os.PathLike)):
        image = imread_unicode(page_img, cv2.IMREAD_COLOR)
    elif isinstance(page_img, np.ndarray):
        image = page_img
    else:
        logger.error(f"Invalid image type: {type(page_img)}")
        return None

    if image is None or image.size == 0:
        logger.error("Failed to load image or empty image")
        return None
    return image


def _run_ocr(ocr_engine, image):
    """Call the OCR engine, preferring predict() and falling back to ocr().

    The return value is normalised to a list (or None) so callers can use
    len() and indexing even if the engine returns a generator or one dict.
    """
    if hasattr(ocr_engine, 'predict'):
        raw = ocr_engine.predict(image)
    else:
        raw = ocr_engine.ocr(image)

    if isinstance(raw, dict):
        return [raw]
    if raw is not None and not isinstance(raw, (list, tuple)):
        raw = list(raw)
    return raw


def _box_to_xyxy(box):
    """Convert a polygon of >= 4 (x, y) points to [x1, y1, x2, y2].

    Anything that is not such a polygon (already-flat boxes, None, odd
    shapes) is returned unchanged.
    """
    try:
        pts = np.asarray(box).tolist()
    except (TypeError, ValueError):
        return box

    is_polygon = (
        isinstance(pts, list)
        and len(pts) >= 4
        and all(
            isinstance(p, list)
            and len(p) >= 2
            and isinstance(p[0], (int, float))
            and isinstance(p[1], (int, float))
            for p in pts
        )
    )
    if not is_polygon:
        return box

    xs = [p[0] for p in pts]
    ys = [p[1] for p in pts]
    return [min(xs), min(ys), max(xs), max(ys)]


def _box_center(box):
    """Return the centre of an [x1, y1, x2, y2] box, or (0, 0) if unusable."""
    if box is None:
        return (0, 0)
    try:
        if len(box) >= 4:
            x1, y1, x2, y2 = (float(v) for v in box[:4])
            return ((x1 + x2) / 2, (y1 + y2) / 2)
    except (TypeError, ValueError):
        pass
    return (0, 0)


def _as_list(value):
    """list(value), treating None as empty (values may be numpy arrays)."""
    return [] if value is None else list(value)


def _parse_ocr_result(raw_result):
    """Flatten an OCR result into parallel text / score / box lists.

    Two layouts are recognised, based on the type of raw_result[0]:
      * dict  -> keys 'rec_texts', 'rec_scores', 'rec_boxes'
      * list  -> [[box_polygon, (text, score)], ...]

    Returns:
        (texts, scores, boxes, fmt) where fmt is 'predict', 'legacy' or
        None when the layout is not recognised / the result is empty.
    """
    texts, scores, boxes = [], [], []
    if raw_result is None or len(raw_result) == 0:
        return texts, scores, boxes, None

    first = raw_result[0]
    if isinstance(first, dict):
        texts = _as_list(first.get('rec_texts'))
        scores = _as_list(first.get('rec_scores'))
        boxes = _as_list(first.get('rec_boxes'))
        return texts, scores, boxes, 'predict'

    if first is None:
        # A page with no detected text: report it as zero legacy lines.
        return texts, scores, boxes, 'legacy'

    if isinstance(first, (list, tuple)):
        for item in first:
            if not item or len(item) < 2:
                continue
            box, text_info = item[0], item[1]
            if not text_info or len(text_info) < 2:
                continue
            boxes.append(_box_to_xyxy(box))
            texts.append(text_info[0])
            scores.append(text_info[1])
        return texts, scores, boxes, 'legacy'

    return texts, scores, boxes, None


def _collect_candidates(texts, scores, boxes):
    """Build candidate dicts for every 12-digit (else 11-digit) run per line."""
    candidates = []
    for i, text in enumerate(texts):
        if not text:
            continue
        text = str(text)

        # Prefer 12-digit runs; only look for 11-digit runs if none exist.
        numbers = _RE_12_DIGITS.findall(text) or _RE_11_DIGITS.findall(text)
        if not numbers:
            continue
        logger.debug(f"[DEBUG] Found numbers in '{text}': {numbers}")

        box = boxes[i] if i < len(boxes) else None
        score = scores[i] if i < len(scores) else 0.5
        position = _box_center(box)

        for num in numbers:
            candidates.append({
                'code': num,
                'confidence': score,
                'text': text,
                'position': position,
                'box': box,
            })
    return candidates


def _select_best_candidate(candidates, w, h, *, corner_x_divisor, prefix_bonus):
    """Pick the highest-scoring candidate (first one wins on ties).

    Score = confidence * 100
            + 30 if the box centre is in the upper-right area
              (x > w / corner_x_divisor and y < h / 3)
            + 10 if the code has exactly 11 digits
            + prefix_bonus if the code starts with '2'
    """
    def score(candidate):
        x, y = candidate['position']
        code = candidate['code']
        total = candidate['confidence'] * 100
        if x > w / corner_x_divisor and y < h / 3:
            total += 30
        if len(code) == 11:
            total += 10
        if code.startswith('2'):
            total += prefix_bonus
        return total

    return max(candidates, key=score)


def _apply_candidate(result, best):
    """Copy the winning candidate into *result* and mark it successful."""
    result['code'] = best['code']
    result['confidence'] = best['confidence']
    result['raw_text'] = best['text']
    result['position'] = best['position']
    result['box'] = best['box']
    result['success'] = True


def _save_roi_visualization(roi_img, result, output_dir, debug_prefix):
    """Draw the winning box and code on a copy of the ROI and save it."""
    box = result.get('box')
    if box is None:
        return

    try:
        os.makedirs(output_dir, exist_ok=True)
    except OSError as e:
        # Debug output is best-effort; never fail the extraction over it.
        logger.warning(f"{debug_prefix}Cannot create output dir {output_dir}: {e}")
        return

    vis_roi = roi_img.copy()
    try:
        corners = [int(v) for v in box[:4]] if len(box) >= 4 else None
    except (TypeError, ValueError):
        corners = None

    if corners is not None:
        # box is in [x1, y1, x2, y2] format
        x1, y1, x2, y2 = corners
        cv2.rectangle(vis_roi, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Put the label just above the box, clamped to stay inside the image.
        text_pos = (x1, max(10, y1 - 10))
        cv2.putText(vis_roi, f"CMA: {result['code']}", text_pos,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)

    out_path = os.path.join(
        output_dir, f"{debug_prefix.strip()}cma_roi_extraction.png"
    )
    if _imwrite_unicode(out_path, vis_roi):
        logger.info(f"{debug_prefix}Saved ROI extraction visualization")
    else:
        logger.warning(f"{debug_prefix}Failed to save ROI visualization: {out_path}")


def extract_cma_from_roi(roi_img, ocr_engine, output_dir=None, debug_prefix=""):
    """
    Run OCR on a region of interest and extract a CMA code from it.

    Args:
        roi_img: ROI image (numpy array).
        ocr_engine: OCR engine exposing predict() or ocr().
        output_dir: If set, a debug visualisation is written here on success.
        debug_prefix: Prefix prepended to every log message.

    Returns:
        Result dict with keys 'code', 'confidence', 'raw_text', 'position',
        'box' and 'success'. 'position' and 'box' are in ROI coordinates.
    """
    result = _empty_result()

    if roi_img is None or roi_img.size == 0:
        logger.error(f"{debug_prefix}Invalid ROI image")
        return result

    h, w = roi_img.shape[:2]
    logger.info(f"{debug_prefix}ROI: (0, 0) -> ({w}, {h})")
    logger.info(f"{debug_prefix}ROI size: {w}x{h}")

    try:
        raw_result = _run_ocr(ocr_engine, roi_img)
    except Exception as e:
        logger.error(f"{debug_prefix}OCR failed: {e}")
        return result

    if raw_result is None or len(raw_result) == 0:
        logger.error(f"{debug_prefix}OCR returned empty result")
        return result

    rec_texts, rec_scores, rec_boxes, fmt = _parse_ocr_result(raw_result)
    if fmt == 'predict':
        logger.info(f"{debug_prefix}Using predict() API format, found {len(rec_texts)} lines")
    elif fmt == 'legacy':
        logger.info(f"{debug_prefix}Using legacy ocr() API format, found {len(rec_texts)} lines")
    else:
        logger.warning(f"{debug_prefix}Unknown OCR result format: {type(raw_result[0])}")
        return result

    if not rec_texts:
        logger.warning(f"{debug_prefix}No text recognized in ROI")
        return result

    logger.info(f"{debug_prefix}OCR found {len(rec_texts)} text lines")
    for i, (text, score) in enumerate(zip(rec_texts, rec_scores)):
        logger.info(f"{debug_prefix}Line {i}: '{text}' (score: {score:.2f})")

    cma_candidates = _collect_candidates(rec_texts, rec_scores, rec_boxes)
    if not cma_candidates:
        logger.warning(f"{debug_prefix}No CMA code candidates found in ROI text")
        return result

    # Codes starting with '2' are the primary pattern (see PATTERN_PRIMARY),
    # so they get a bonus here, matching the full-page fallback's preference.
    # The previous code subtracted this amount, which penalised valid codes.
    best = _select_best_candidate(
        cma_candidates, w, h, corner_x_divisor=3, prefix_bonus=20
    )
    _apply_candidate(result, best)
    logger.info(f"{debug_prefix}Best CMA candidate: {best['code']} (conf: {best['confidence']:.2f})")

    if output_dir:
        _save_roi_visualization(roi_img, result, output_dir, debug_prefix)

    return result


def extract_cma_code_fullpage(page_img, ocr_engine, template_path='template/CMA_Logo.png',
                              output_dir=None, use_template_matching=True):
    """
    Full CMA extraction pipeline: template matching first, full-page OCR second.

    Args:
        page_img: Page image (file path or numpy array).
        ocr_engine: OCR engine exposing predict() or ocr().
        template_path: Path of the CMA logo template.
        output_dir: Optional directory for debug visualisations.
        use_template_matching: If False, skip straight to full-page OCR.

    Returns:
        Result dict as returned by the extractors, plus:
          'method'     - 'template_matching', 'fullpage_fallback' or 'none'
          'roi_offset' - (x, y) of the OCR region's top-left corner in page
                         coordinates. For 'template_matching', 'position'
                         and 'box' are relative to this offset; for the
                         fallback it is (0, 0).
    """
    result = _empty_result()
    result['method'] = 'none'
    result['roi_offset'] = (0, 0)

    image = _load_page_image(page_img)
    if image is None:
        return result

    h, w = image.shape[:2]

    template = None
    if use_template_matching:
        template, _ = load_cma_template(template_path)
        if template is None:
            logger.warning("Cannot load template, falling back to full-page OCR")
            use_template_matching = False

    # Method 1: template matching + OCR on the region right of the logo.
    if use_template_matching:
        logger.info("[TM] Starting template matching extraction...")
        match_result = match_template(image, template)

        if match_result is None:
            logger.warning("[TM] Template matching failed")
        elif match_result['max_val'] < MATCH_CONFIDENCE_THRESHOLD:
            logger.warning(f"[TM] Match confidence too low: {match_result['max_val']:.3f}")
        else:
            center_x, center_y = match_result['center']
            _, template_h = match_result['template_size']

            # The CMA number is normally printed to the right of the logo, so
            # the ROI starts at the logo centre and extends rightwards rather
            # than being centred on the logo.
            roi_x1 = max(0, center_x)
            roi_y1 = max(0, center_y - template_h // 2)
            roi_x2 = min(w, center_x + ROI_MAX_WIDTH)
            # Extend one extra template height below the logo's bottom edge.
            roi_y2 = min(h, center_y + template_h // 2 + template_h)

            logger.info(f"[TM] ROI: ({roi_x1}, {roi_y1}) -> ({roi_x2}, {roi_y2})")
            roi_img = image[roi_y1:roi_y2, roi_x1:roi_x2]

            roi_result = extract_cma_from_roi(
                roi_img, ocr_engine, output_dir, debug_prefix="[TM] "
            )
            if roi_result['success']:
                roi_result['method'] = 'template_matching'
                roi_result['roi_offset'] = (roi_x1, roi_y1)
                logger.info(
                    f"[TM] Template matching SUCCESS: {roi_result['code']} "
                    f"(conf: {roi_result['confidence']:.2f})"
                )
                return roi_result
            logger.warning("[TM] Template matching found logo, but OCR failed to extract CMA code")

    # Method 2: full-page OCR fallback.
    logger.info("[FALLBACK] Template matching failed, trying full-page OCR...")
    result = extract_cma_fullpage_fallback(image, ocr_engine, output_dir)
    result['method'] = 'fullpage_fallback'
    result.setdefault('roi_offset', (0, 0))
    return result


def extract_cma_fullpage_fallback(page_img, ocr_engine, output_dir=None):
    """
    Full-page OCR fallback, used when template matching does not yield a code.

    Args:
        page_img: Page image (file path or numpy array).
        ocr_engine: OCR engine exposing predict() or ocr().
        output_dir: Accepted for signature symmetry; not used here.

    Returns:
        Result dict with keys 'code', 'confidence', 'raw_text', 'position',
        'box' and 'success'. 'position' and 'box' are in page coordinates.
    """
    result = _empty_result()

    image = _load_page_image(page_img)
    if image is None:
        return result

    h, w = image.shape[:2]

    logger.info("[FALLBACK] Running full-page OCR...")
    try:
        raw_result = _run_ocr(ocr_engine, image)
    except Exception as e:
        logger.error(f"[FALLBACK] OCR failed: {e}")
        return result

    rec_texts, rec_scores, rec_boxes, _ = _parse_ocr_result(raw_result)
    logger.info(f"[FALLBACK] Found {len(rec_texts)} text lines")

    cma_candidates = _collect_candidates(rec_texts, rec_scores, rec_boxes)
    if not cma_candidates:
        logger.warning("[FALLBACK] No CMA code candidates found")
        return result

    # Prefer codes starting with '2' and text in the upper-right of the page.
    best = _select_best_candidate(
        cma_candidates, w, h, corner_x_divisor=2, prefix_bonus=50
    )
    _apply_candidate(result, best)
    logger.info(f"[FALLBACK] CMA extracted: {best['code']} (conf: {best['confidence']:.2f})")

    return result


def _render_first_page_bgr(pdf_path, dpi=300):
    """Render page 0 of a PDF with PyMuPDF and return (bgr_image, width, height)."""
    import fitz

    doc = fitz.open(pdf_path)
    try:
        page = doc[0]
        pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
        # pix.n is the channel count; do not assume 3 (alpha adds a 4th).
        img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
            pix.height, pix.width, pix.n
        )
        if pix.n == 4:
            img_bgr = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
        elif pix.n == 1:
            img_bgr = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img_bgr, pix.width, pix.height
    finally:
        doc.close()


def _main():
    """Command-line entry point: extract the CMA code from a PDF's first page."""
    import argparse

    parser = argparse.ArgumentParser(description='CMA Logo 模板匹配提取')
    # --pdf is mandatory; without it there is nothing to process.
    parser.add_argument('--pdf', required=True, help='PDF 文件路径')
    parser.add_argument('--template', default='template/CMA_Logo.png', help='CMA logo 模板路径')
    parser.add_argument('--output', default='template_match_debug', help='输出目录')
    args = parser.parse_args()

    if not os.path.exists(args.pdf):
        print(f"错误: PDF 文件不存在: {args.pdf}")
        sys.exit(1)

    if not os.path.exists(args.template):
        print(f"错误: 模板文件不存在: {args.template}")
        sys.exit(1)

    # These must be set before paddleocr is imported.
    os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
    os.environ["PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK"] = "True"

    from paddleocr import PaddleOCR
    ocr_engine = PaddleOCR(use_angle_cls=True, lang='ch', use_gpu=False)

    img_bgr, pix_w, pix_h = _render_first_page_bgr(args.pdf)

    print(f"PDF 尺寸: {pix_w}x{pix_h}")
    print(f"图像尺寸: {img_bgr.shape}")

    result = extract_cma_code_fullpage(img_bgr, ocr_engine, args.template, args.output)

    print()
    print("=" * 80)
    print("CMA 提取结果:")
    print("-" * 80)
    print(f"  方法: {result.get('method', 'unknown')}")
    print(f"  CMA码: {result.get('code', 'N/A')}")
    print(f"  置信度: {result.get('confidence', 0.0):.2f}")
    print(f"  位置: {result.get('position', 'N/A')}")
    print("-" * 80)
    print(f"  提取成功: {result.get('success', False)}")
    print("=" * 80)


if __name__ == "__main__":
    _main()