"""
Improved CMA-code extraction test, combining approaches 2 and 3.

Approach 2 -- smart fallback: when template matching is judged unreliable,
automatically fall back to OCR over the whole page.
Approach 3 -- tuned template matching: page preprocessing, multiple template
scales and multiple matching methods.
"""
import sys
import os
import cv2
import numpy as np
import fitz
import re
import logging
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

# Set before importing paddleocr so the value is visible at import time.
# Presumably disables PaddleOCR's model-source connectivity check --
# NOTE(review): confirm against the PaddleOCR/PaddleX documentation.
os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"

from paddleocr import PaddleOCR  # noqa: E402 -- must follow the env var above

# ============ Configuration ============
# PDF under test and the CMA logo template.
TEST_PDF = "src/test/resources/data/pdfs/YDQ23_001838.pdf"
TEMPLATE_PATH = "template/CMA_Logo.png"
OUTPUT_DIR = Path("test_improved_extraction")
OUTPUT_DIR.mkdir(exist_ok=True)

# A CMA-code candidate is a run of exactly 11-12 digits. The look-arounds stop
# a longer digit run (serial numbers, phone numbers, ...) from yielding a
# spurious 12-digit slice, which a bare r'\d{11,12}' with findall would do.
CMA_CODE_PATTERN = re.compile(r"(?<!\d)\d{11,12}(?!\d)")

# Logging: console plus a UTF-8 log file inside OUTPUT_DIR.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(OUTPUT_DIR / "test.log", encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)


# ============ Approach 3: improved template matching ============
class ImprovedTemplateMatcher:
    """Template matcher that tries several page preprocessings, scales and methods."""

    # Methods whose best match is the MINIMUM of the result map (0 = perfect).
    SQDIFF_METHODS = (cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED)

    # Unnormalised methods return unbounded scores (TM_SQDIFF easily reaches
    # millions) that would always outrank the bounded scores of the normalised
    # methods when all matches are sorted together, so they are evaluated
    # through their normalised counterparts instead.
    NORMALISED_METHOD = {
        cv2.TM_SQDIFF: cv2.TM_SQDIFF_NORMED,
        cv2.TM_CCORR: cv2.TM_CCORR_NORMED,
        cv2.TM_CCOEFF: cv2.TM_CCOEFF_NORMED,
    }

    def __init__(self, template_path: str):
        """Load the logo template as a grayscale image.

        Raises:
            ValueError: if the template image cannot be read.
        """
        self.template = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
        if self.template is None:
            raise ValueError(f"Cannot load template from {template_path}")
        self.template_h, self.template_w = self.template.shape[:2]
        logger.info(f"Template loaded: {self.template_w}x{self.template_h}")

    def preprocess_page(self, page_img: np.ndarray) -> Dict[str, np.ndarray]:
        """Return several grayscale variants of the page, keyed by variant name.

        A 3-channel input is treated as BGR and converted to grayscale; any
        other input is used unchanged.
        """
        gray = cv2.cvtColor(page_img, cv2.COLOR_BGR2GRAY) if len(page_img.shape) == 3 else page_img

        processed = {
            'original': gray,
            'blurred': cv2.GaussianBlur(gray, (5, 5), 0),
            'denoised': cv2.fastNlMeansDenoising(gray, None, 10, 7, 21),
            'equalized': cv2.equalizeHist(gray),
            'clahe': cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray),
            # Edge-enhanced version, intended to help with the circular logo.
            # The template receives the same Canny treatment in _prepare_template.
            'edges': cv2.Canny(gray, 50, 150),
        }

        logger.info(f"Generated {len(processed)} preprocessed versions")
        return processed

    def _prepare_template(self, scale: float, prep_name: str) -> Optional[np.ndarray]:
        """Return the template scaled by *scale* and preprocessed to suit *prep_name*.

        Returns None when the scaled template would be under 10 px on a side.
        """
        if scale == 1.0:
            template = self.template
        else:
            new_w = int(self.template_w * scale)
            new_h = int(self.template_h * scale)
            if new_w < 10 or new_h < 10:
                return None
            template = cv2.resize(self.template, (new_w, new_h), interpolation=cv2.INTER_AREA)

        # Correlating a grayscale template with a binary edge map is not
        # meaningful; compare edges with edges instead.
        if prep_name == 'edges':
            template = cv2.Canny(template, 50, 150)
        return template

    def match_multi_method(
        self,
        page_img: np.ndarray,
        scales: Sequence[float] = (0.8, 0.9, 1.0, 1.1, 1.2),
        methods: Sequence[int] = (cv2.TM_CCOEFF_NORMED, cv2.TM_CCORR_NORMED, cv2.TM_SQDIFF),
        max_y_ratio: float = 0.6,
    ) -> Dict:
        """Match the template over every (preprocessing, scale, method) combination.

        Only the top ``max_y_ratio`` of the page is searched: a match is valid
        when its centre row is at most ``int(page_height * max_y_ratio)``. For
        each combination the single best location is kept. Unnormalised methods
        are scored with their ``*_NORMED`` variants. For squared-difference
        methods the confidence is ``1 - min_val``, so a higher confidence is
        always better.

        Returns:
            On success: ``{'success': True, 'best_match': dict,
            'all_matches': list (sorted by confidence, descending),
            'num_matches': int}``. Each match dict has the keys 'confidence',
            'location' (top-left), 'center', 'method' (the method actually
            evaluated), 'scale', 'preprocessing' and 'template_size' (w, h).
            Otherwise: ``{'success': False, 'reason': str, 'num_matches': 0}``.
        """
        h = page_img.shape[0]
        max_y_threshold = int(h * max_y_ratio)  # only accept matches in the upper part of the page

        preprocessed = self.preprocess_page(page_img)

        all_matches = []
        num_total_checks = 0

        for prep_name, processed_img in preprocessed.items():
            for scale in scales:
                template = self._prepare_template(scale, prep_name)
                if template is None:
                    continue
                new_h, new_w = template.shape[:2]

                # Crop rows so that every possible match centre lies at or above
                # max_y_threshold. Filtering only the global best location after
                # matching would throw away the whole combination whenever its
                # best hit happened to be lower on the page.
                search_h = min(h, max_y_threshold - new_h // 2 + new_h)
                if search_h < new_h or processed_img.shape[1] < new_w:
                    continue
                search_img = processed_img[:search_h]

                for method in methods:
                    num_total_checks += 1
                    score_method = self.NORMALISED_METHOD.get(method, method)
                    try:
                        result = cv2.matchTemplate(search_img, template, score_method)
                    except cv2.error as e:
                        logger.debug(f"Match failed: prep={prep_name}, scale={scale}, method={method}, error={e}")
                        continue
                    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)

                    if score_method in self.SQDIFF_METHODS:
                        # Squared difference: the smallest value is the best match.
                        confidence, loc = 1.0 - min_val, min_loc
                    else:
                        confidence, loc = max_val, max_loc
                    if not np.isfinite(confidence):
                        continue

                    # Cropping only removes rows at the bottom, so locations in
                    # search_img are also page coordinates.
                    all_matches.append({
                        'confidence': float(confidence),
                        'location': loc,
                        'center': (loc[0] + new_w // 2, loc[1] + new_h // 2),
                        'method': score_method,
                        'scale': scale,
                        'preprocessing': prep_name,
                        'template_size': (new_w, new_h),
                    })

        logger.info(f"Total match attempts: {num_total_checks}")
        logger.info(f"Valid matches (in upper {max_y_ratio:.0%} of page): {len(all_matches)}")

        if not all_matches:
            return {
                'success': False,
                'reason': 'No valid matches found',
                'num_matches': 0,
            }

        all_matches.sort(key=lambda m: m['confidence'], reverse=True)

        # Too many matches may mean template matching has failed.
        if len(all_matches) > 1000:
            logger.warning(f"Too many matches ({len(all_matches)}), template matching may have failed")

        return {
            'success': True,
            'best_match': all_matches[0],
            'all_matches': all_matches,
            'num_matches': len(all_matches),
        }

    def is_matching_failed(
        self,
        match_result: Dict,
        max_matches: int = 1000,
        consensus_top_k: int = 10,
        min_agreeing: int = 2,
    ) -> bool:
        """Decide whether a match_multi_method result should be distrusted.

        Signs of failure:
        1. More than ``max_matches`` matches: the template matched too many places.
        2. More than 100 matches with a best confidence above 0.9: likely noise.
        3. The strongest matches are scattered over the page. This is the case
           when fewer than ``min_agreeing`` of the top ``consensus_top_k``
           matches (the best one included) have their centre within one
           template size of the best match's centre.

        NOTE: match_multi_method keeps one match per (preprocessing, scale,
        method) combination. With its defaults that is at most 6 * 5 * 3 = 90
        matches, so checks 1 and 2 only fire for much larger search grids.
        Check 3 is the one that applies in the default configuration.
        """
        if not match_result.get('success'):
            return True

        num_matches = match_result.get('num_matches', 0)
        best_match = match_result['best_match']
        best_confidence = best_match['confidence']

        # Check 1: far too many matches.
        if num_matches > max_matches:
            logger.warning(f"Template matching failed: {num_matches} matches (threshold: >{max_matches})")
            return True

        # Check 2: suspiciously high confidence together with many matches.
        if num_matches > 100 and best_confidence > 0.9:
            logger.warning(f"Template matching failed: high confidence ({best_confidence:.3f}) with many matches ({num_matches})")
            return True

        # Check 3: the top matches should broadly agree on the logo position.
        top_matches = match_result.get('all_matches', [best_match])[:consensus_top_k]
        if len(top_matches) >= min_agreeing:
            best_x, best_y = best_match['center']
            tol_w, tol_h = best_match['template_size']
            agreeing = sum(
                1
                for m in top_matches
                if abs(m['center'][0] - best_x) <= tol_w and abs(m['center'][1] - best_y) <= tol_h
            )
            if agreeing < min_agreeing:
                logger.warning(
                    f"Template matching failed: only {agreeing} of the top {len(top_matches)} "
                    f"matches agree on the logo position"
                )
                return True

        return False


# ============ Approach 2: smart-fallback extractor ============
class SmartCMAExtractor:
    """CMA-code extractor: template matching + ROI OCR, with full-page OCR fallback."""

    def __init__(self, ocr_engine: PaddleOCR, template_path: str = TEMPLATE_PATH):
        self.ocr = ocr_engine
        self.matcher = ImprovedTemplateMatcher(template_path)

    def extract(self, page_img: np.ndarray, pdf_name: str) -> Dict:
        """Extract the CMA code from a rendered page.

        Steps:
        1. Try the improved template matching.
        2. Check whether the match looks unreliable.
        3. If it does (or nothing matched), OCR the whole page instead of the ROI.

        Returns a dict with keys 'pdf_name', 'success', 'code', 'confidence',
        'method' and 'match_result'.
        """
        result = {
            'pdf_name': pdf_name,
            'success': False,
            'code': None,
            'confidence': 0.0,
            'method': None,
            'match_result': None,
        }

        logger.info(f"\n{'='*80}")
        logger.info(f"EXTRACTING FROM: {pdf_name}")
        logger.info(f"{'='*80}")

        # Step 1: improved template matching.
        logger.info("\n[Step 1] Attempting improved template matching...")
        match_result = self.matcher.match_multi_method(page_img)

        if not match_result['success']:
            logger.warning(f"⚠️ No template match found - reason: {match_result.get('reason')}")
            logger.info("→ Using full-page OCR fallback")
            result['method'] = 'fullpage_fallback'
            return self._extract_fullpage(page_img, result)

        best_match = match_result['best_match']
        logger.info(f"Template match found:")
        logger.info(f"  Confidence: {best_match['confidence']:.3f}")
        logger.info(f"  Location: {best_match['center']}")
        logger.info(f"  Method: {best_match['method']}")
        logger.info(f"  Scale: {best_match['scale']}")
        logger.info(f"  Preprocessing: {best_match['preprocessing']}")
        logger.info(f"  Total matches: {match_result['num_matches']}")

        result['match_result'] = match_result

        # Step 2: decide whether the match can be trusted.
        if self.matcher.is_matching_failed(match_result):
            logger.warning("⚠️ Template matching FAILED - using full-page OCR fallback")
            result['method'] = 'fullpage_fallback'
            return self._extract_fullpage(page_img, result)

        logger.info("✓ Template matching appears valid, extracting from ROI...")
        return self._extract_from_roi(page_img, best_match, result)

    def _extract_from_roi(self, page_img: np.ndarray, match_info: Dict, result: Dict) -> Dict:
        """OCR the region to the right of / below the matched logo.

        Falls back to full-page OCR if the ROI is empty or yields no code.
        Writes the ROI to OUTPUT_DIR/roi.png.
        """
        x, y = match_info['center']
        _, template_h = match_info['template_size']
        h, w = page_img.shape[:2]

        # ROI: from the logo centre rightwards (at most 600 px), from the logo's
        # top edge down to four template heights below its centre.
        roi_x1 = max(0, x)
        roi_y1 = max(0, y - template_h // 2)
        roi_x2 = min(w, x + 600)
        roi_y2 = min(h, y + template_h * 4)

        logger.info(f"ROI: ({roi_x1}, {roi_y1}) -> ({roi_x2}, {roi_y2})")
        logger.info(f"ROI size: {roi_x2 - roi_x1}x{roi_y2 - roi_y1}")

        roi_img = page_img[roi_y1:roi_y2, roi_x1:roi_x2]
        if roi_img.size == 0:
            # cv2.imwrite and OCR cannot handle an empty image.
            logger.warning("ROI is empty, trying full-page OCR fallback...")
            return self._extract_fullpage(page_img, result)

        cv2.imwrite(str(OUTPUT_DIR / "roi.png"), roi_img)

        cma_code = self._extract_cma_from_ocr_result(roi_img)
        if not cma_code:
            logger.warning("ROI extraction failed, trying full-page OCR fallback...")
            return self._extract_fullpage(page_img, result)

        result['success'] = True
        result['code'] = cma_code['code']
        result['confidence'] = cma_code['confidence']
        result['method'] = 'template_matching'
        logger.info(f"✓ SUCCESS: Found CMA code: {cma_code['code']} (confidence: {cma_code['confidence']:.2f})")
        return result

    def _extract_fullpage(self, page_img: np.ndarray, result: Dict) -> Dict:
        """Full-page OCR fallback; sets 'method' to 'fullpage_ocr' or 'failed'."""
        logger.info("\n[Step 2] Running full-page OCR fallback...")

        cma_code = self._extract_cma_from_ocr_result(page_img)

        if cma_code:
            result['success'] = True
            result['code'] = cma_code['code']
            result['confidence'] = cma_code['confidence']
            result['method'] = 'fullpage_ocr'
            logger.info(f"✓ SUCCESS: Found CMA code: {cma_code['code']} (confidence: {cma_code['confidence']:.2f})")
        else:
            result['method'] = 'failed'
            logger.error("✗ FAILED: Full-page OCR also failed")

        return result

    @staticmethod
    def _find_code_candidates(text: str) -> List[str]:
        """Return the distinct 11-12 digit runs in *text*, in order of appearance.

        The raw text is searched first, so a code followed by other digits after
        a space is still found. Then the text is searched again with spaces and
        hyphens removed, so a code the OCR split into groups (e.g.
        "2100 2034 9096") is found too.
        """
        found = CMA_CODE_PATTERN.findall(text)
        found += CMA_CODE_PATTERN.findall(text.replace(" ", "").replace("-", ""))
        return list(dict.fromkeys(found))

    def _extract_cma_from_ocr_result(self, img: np.ndarray) -> Optional[Dict]:
        """OCR *img* and pick the most likely CMA code.

        Candidates starting with '2' are preferred (assumed standard CMA format).
        Within the chosen pool the highest OCR confidence wins; ties go to the
        earliest line. Returns a dict with 'code', 'confidence', 'text' and
        'line', or None when OCR fails or finds no candidate.
        """
        try:
            ocr_result = self.ocr.predict(img)
            if not ocr_result:
                logger.warning("OCR returned no results")
                return None

            res = ocr_result[0]
            texts = res.get('rec_texts', [])
            scores = res.get('rec_scores', [])

            logger.info(f"OCR found {len(texts)} text lines")

            candidates = [
                {'code': num, 'confidence': float(score), 'text': text, 'line': i}
                for i, (text, score) in enumerate(zip(texts, scores))
                for num in self._find_code_candidates(text)
            ]

            if not candidates:
                logger.warning("No 11-12 digit numbers found in OCR results")
                return None

            preferred = [c for c in candidates if c['code'].startswith('2')]
            if preferred:
                best = max(preferred, key=lambda c: c['confidence'])
                logger.info(f"Best candidate (starts with '2'): {best['code']} (line {best['line']}, conf: {best['confidence']:.2f})")
            else:
                best = max(candidates, key=lambda c: c['confidence'])
                logger.info(f"Best candidate (no '2' prefix): {best['code']} (line {best['line']}, conf: {best['confidence']:.2f})")
            return best

        except Exception as e:
            # Best-effort: an OCR failure must not abort the run, but keep the traceback.
            logger.exception(f"OCR extraction failed: {e}")
            return None


# ============ Test functions ============
def test_single_pdf(pdf_path: str, expected_cma: Optional[str] = None) -> Dict:
    """Render page 1 of *pdf_path* at 300 DPI, extract the CMA code and log the outcome.

    Returns the result dict produced by SmartCMAExtractor.extract.

    Raises:
        RuntimeError: if the rendered page cannot be decoded.
    """
    logger.info(f"\n{'#'*80}")
    logger.info(f"TESTING: {Path(pdf_path).name}")
    logger.info(f"Expected CMA: {expected_cma or 'Unknown'}")
    logger.info(f"{'#'*80}\n")

    logger.info("Extracting PDF page...")
    with fitz.open(pdf_path) as doc:
        page = doc[0]
        # Render at 300 DPI (PDF user space is 72 units per inch).
        mat = fitz.Matrix(300 / 72, 300 / 72)
        pix = page.get_pixmap(matrix=mat)
        img_data = pix.tobytes("png")

    page_img = cv2.imdecode(np.frombuffer(img_data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if page_img is None:
        raise RuntimeError(f"Failed to decode rendered page of {pdf_path}")

    logger.info(f"Page size: {page_img.shape}")

    logger.info("Initializing PaddleOCR...")
    ocr = PaddleOCR(lang='ch')

    extractor = SmartCMAExtractor(ocr)
    result = extractor.extract(page_img, Path(pdf_path).name)

    logger.info("\n" + "="*80)
    logger.info("FINAL RESULT")
    logger.info("="*80)
    logger.info(f"PDF: {result['pdf_name']}")
    logger.info(f"Success: {result['success']}")
    logger.info(f"Method: {result['method']}")
    # 'code' is always present (None on failure), so .get() would never fall back.
    logger.info(f"CMA Code: {result['code'] or 'N/A'}")
    logger.info(f"Confidence: {result.get('confidence', 0):.2f}")

    if expected_cma:
        if result['code'] == expected_cma:
            logger.info(f"✓✓✓ CORRECT! Expected: {expected_cma}, Got: {result['code']}")
        else:
            logger.info(f"✗✗✗ WRONG! Expected: {expected_cma}, Got: {result['code']}")

    logger.info("="*80 + "\n")

    return result


# ============ Main ============
if __name__ == "__main__":
    # Test YDQ23_001838.pdf
    test_single_pdf(TEST_PDF, expected_cma="210020349096")

    print("\n" + "="*80)
    print("TEST COMPLETED")
    print("="*80)
    print(f"Results saved to: {OUTPUT_DIR}")
    print(f"  - test.log: Detailed log")
    print(f"  - roi.png: ROI image (if template matching succeeded)")