# Listing metadata from the original source view: Python, 425 lines, 16 KiB.
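"""
Improved CMA code extraction test - combines Option 2 and Option 3.

Option 2: smart fallback - automatically use full-page OCR when template
          matching fails.
Option 3: tuned template matching - adds preprocessing, multiple scales and
          multiple matching methods.
"""
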
import logging
import os
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import cv2
import fitz
import numpy as np

# Set before paddleocr is imported, as in the original ordering
# (presumably read at import time - keep this line above the import).
os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"

from paddleocr import PaddleOCR

# ============ Configuration ============

# PDF used for the test run
TEST_PDF = "src/test/resources/data/pdfs/YDQ23_001838.pdf"
# Logo image used as the template for template matching
TEMPLATE_PATH = "template/CMA_Logo.png"
# Directory receiving the log file and the saved ROI image
OUTPUT_DIR = Path("test_improved_extraction")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Logging configuration: console output plus a UTF-8 log file in OUTPUT_DIR
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(OUTPUT_DIR / "test.log", encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)


# ============ Option 3: improved template matching ============


class ImprovedTemplateMatcher:
    """Template matcher that tries several preprocessings, scales and methods.

    The template is loaded once in grayscale. Each call to
    ``match_multi_method`` compares it against multiple preprocessed variants
    of the page and keeps at most one candidate per
    (preprocessing, scale, method) combination.
    """

    def __init__(self, template_path: str):
        """Load the template image in grayscale.

        Args:
            template_path: Path of the template image file.

        Raises:
            ValueError: If ``cv2.imread`` returns ``None`` for the path.
        """
        self.template = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
        if self.template is None:
            raise ValueError(f"Cannot load template from {template_path}")

        self.template_h, self.template_w = self.template.shape[:2]
        logger.info(f"Template loaded: {self.template_w}x{self.template_h}")

    def preprocess_page(self, page_img: np.ndarray) -> Dict[str, np.ndarray]:
        """Build the grayscale page variants used for matching.

        Args:
            page_img: Page image. A 3-dimensional array is converted with
                ``COLOR_BGR2GRAY``; any other array is used unchanged.

        Returns:
            Mapping of variant name to image, in this order: ``original``,
            ``blurred``, ``denoised``, ``equalized``, ``clahe``, ``edges``.
        """
        gray = cv2.cvtColor(page_img, cv2.COLOR_BGR2GRAY) if len(page_img.shape) == 3 else page_img

        processed = {
            'original': gray,
            'blurred': cv2.GaussianBlur(gray, (5, 5), 0),
            'denoised': cv2.fastNlMeansDenoising(gray, None, 10, 7, 21),
            'equalized': cv2.equalizeHist(gray),
            'clahe': cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray),
            # Edge-enhanced variant (the original author notes it helps with
            # the circular logo).
            'edges': cv2.Canny(gray, 50, 150),
        }

        logger.info(f"Generated {len(processed)} preprocessed versions")
        return processed

    def _build_scaled_templates(self, scales: Sequence[float]) -> List[Tuple[float, np.ndarray]]:
        """Resize the template once per scale; scales yielding <10 px are skipped."""
        scaled = []
        for scale in scales:
            if scale == 1.0:
                scaled.append((scale, self.template))
                continue
            new_w = int(self.template_w * scale)
            new_h = int(self.template_h * scale)
            if new_w < 10 or new_h < 10:
                continue
            resized = cv2.resize(self.template, (new_w, new_h), interpolation=cv2.INTER_AREA)
            scaled.append((scale, resized))
        return scaled

    @staticmethod
    def _best_candidate(method, min_val, max_val, min_loc, max_loc, template_pixels):
        """Turn ``minMaxLoc`` output into (confidence, top-left location).

        For the squared-difference methods the best match is the MINIMUM of
        the response map, so the minimum location is used and the value is
        mapped to a "higher is better" score.
        """
        if method == cv2.TM_SQDIFF_NORMED:
            return 1.0 - float(min_val), min_loc
        if method == cv2.TM_SQDIFF:
            # Raw sum of squared differences; for 8-bit images it is bounded
            # by template_pixels * 255**2 (the template is loaded as 8-bit
            # grayscale in __init__).
            return 1.0 - float(min_val) / (template_pixels * 255.0 ** 2), min_loc
        return float(max_val), max_loc

    def match_multi_method(
        self,
        page_img: np.ndarray,
        scales: Optional[Sequence[float]] = None,
        methods: Optional[Sequence[int]] = None,
    ) -> Dict:
        """Run template matching over all preprocessings, scales and methods.

        Args:
            page_img: Page image to search.
            scales: Template scale factors; defaults to
                ``(0.8, 0.9, 1.0, 1.1, 1.2)``.
            methods: OpenCV ``TM_*`` constants; defaults to
                ``TM_CCOEFF_NORMED``, ``TM_CCORR_NORMED`` and ``TM_SQDIFF``.

        Returns:
            On success::

                {
                    'success': True,
                    'best_match': {'confidence': float, 'location': tuple,
                                   'center': tuple, 'method': int,
                                   'scale': float, 'preprocessing': str,
                                   'template_size': (w, h)},
                    'all_matches': List[Dict],   # sorted by confidence, descending
                    'num_matches': int
                }

            When no candidate survives the position filter::

                {'success': False, 'reason': str, 'num_matches': 0}

        Note:
            Confidences from different methods are sorted together although
            their scales are not strictly comparable.
        """
        if scales is None:
            scales = (0.8, 0.9, 1.0, 1.1, 1.2)
        if methods is None:
            methods = (cv2.TM_CCOEFF_NORMED, cv2.TM_CCORR_NORMED, cv2.TM_SQDIFF)

        page_h = page_img.shape[0]
        # Only candidates whose centre lies in the upper 60% of the page are kept.
        max_y_threshold = int(page_h * 0.6)

        preprocessed = self.preprocess_page(page_img)
        # Scaled templates do not depend on the page variant: build them once.
        scaled_templates = self._build_scaled_templates(scales)

        all_matches = []
        num_total_checks = 0

        for prep_name, processed_img in preprocessed.items():
            for scale, scaled_template in scaled_templates:
                new_h, new_w = scaled_template.shape[:2]

                for method in methods:
                    num_total_checks += 1

                    try:
                        response = cv2.matchTemplate(processed_img, scaled_template, method)
                        min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(response)
                    except cv2.error as e:
                        # e.g. template larger than the page variant
                        logger.debug(f"Match failed: prep={prep_name}, scale={scale}, method={method}, error={e}")
                        continue

                    confidence, location = self._best_candidate(
                        method, min_val, max_val, min_loc, max_loc, scaled_template.size
                    )
                    center = (location[0] + new_w // 2, location[1] + new_h // 2)

                    # Position filter on the centre of the matched region.
                    if center[1] > max_y_threshold:
                        continue

                    all_matches.append({
                        'confidence': confidence,
                        'location': location,
                        'center': center,
                        'method': method,
                        'scale': scale,
                        'preprocessing': prep_name,
                        'template_size': (new_w, new_h),
                    })

        logger.info(f"Total match attempts: {num_total_checks}")
        logger.info(f"Valid matches (in upper 60% of page): {len(all_matches)}")

        if not all_matches:
            return {
                'success': False,
                'reason': 'No valid matches found',
                'num_matches': 0
            }

        # Highest confidence first.
        all_matches.sort(key=lambda m: m['confidence'], reverse=True)
        best_match = all_matches[0]

        # A very large number of candidates may mean matching has broken down.
        if len(all_matches) > 1000:
            logger.warning(f"Too many matches ({len(all_matches)}), template matching may have failed")

        return {
            'success': True,
            'best_match': best_match,
            'all_matches': all_matches,
            'num_matches': len(all_matches)
        }

    def is_matching_failed(self, match_result: Dict) -> bool:
        """Decide whether a ``match_multi_method`` result should be distrusted.

        Returns ``True`` when any of the following holds:

        1. ``match_result['success']`` is falsy or missing.
        2. More than 1000 candidates were collected.
        3. More than 100 candidates were collected and the best confidence
           exceeds 0.9.

        NOTE(review): ``match_multi_method`` keeps at most one candidate per
        (preprocessing, scale, method) combination, i.e. 6 * 5 * 3 = 90 with
        the defaults, so checks 2 and 3 only trigger when callers pass larger
        ``scales``/``methods`` lists. No check on the spatial spread of the
        candidates is implemented.

        Args:
            match_result: Dictionary returned by ``match_multi_method``.
        """
        if not match_result.get('success'):
            return True

        num_matches = match_result.get('num_matches', 0)
        best_confidence = match_result['best_match']['confidence']

        # Check: too many candidates.
        if num_matches > 1000:
            logger.warning(f"Template matching failed: {num_matches} matches (threshold: >1000)")
            return True

        # Check: suspiciously high confidence together with many candidates.
        if num_matches > 100 and best_confidence > 0.9:
            logger.warning(f"Template matching failed: high confidence ({best_confidence:.3f}) with many matches ({num_matches})")
            return True

        return False


# ============ Option 2: smart fallback extractor ============


class SmartCMAExtractor:
    """CMA code extractor combining template matching with full-page OCR."""

    # An 11- or 12-digit run, optionally broken up by spaces/hyphens, that is
    # not glued to further digits on either side. Matching on the raw text
    # (instead of stripping separators first) keeps neighbouring numbers on
    # the same OCR line from being fused into a bogus code.
    _CODE_PATTERN = re.compile(r'(?<!\d)\d(?:[ -]*\d){10,11}(?!\d)')

    def __init__(self, ocr_engine: "PaddleOCR"):
        """Store the OCR engine and build a matcher from ``TEMPLATE_PATH``.

        Args:
            ocr_engine: OCR engine; its ``predict`` method is called with an
                image array.
        """
        self.ocr = ocr_engine
        self.matcher = ImprovedTemplateMatcher(TEMPLATE_PATH)

    def extract(self, page_img: np.ndarray, pdf_name: str) -> Dict:
        """Extract the CMA code from a page image.

        Steps:
            1. Run the improved template matching.
            2. Check whether the match looks invalid.
            3. If it does (or nothing matched), fall back to full-page OCR;
               otherwise OCR the region next to the matched logo.

        Args:
            page_img: Page image.
            pdf_name: Name reported in the result and the log.

        Returns:
            Dict with keys ``pdf_name``, ``success``, ``code``,
            ``confidence``, ``method`` (``'template_matching'``,
            ``'fullpage_ocr'`` or ``'failed'``) and ``match_result``.
        """
        result = {
            'pdf_name': pdf_name,
            'success': False,
            'code': None,
            'confidence': 0.0,
            'method': None,
            'match_result': None
        }

        banner = "=" * 80
        logger.info("\n" + banner)
        logger.info(f"EXTRACTING FROM: {pdf_name}")
        logger.info(banner)

        # Step 1: improved template matching.
        logger.info("\n[Step 1] Attempting improved template matching...")
        match_result = self.matcher.match_multi_method(page_img)

        if not match_result['success']:
            logger.warning(f"⚠️ No template match found - reason: {match_result.get('reason')}")
            logger.info("→ Using full-page OCR fallback")
            # _extract_fullpage sets result['method'] itself.
            return self._extract_fullpage(page_img, result)

        best_match = match_result['best_match']

        logger.info("Template match found:")
        logger.info(f" Confidence: {best_match['confidence']:.3f}")
        logger.info(f" Location: {best_match['center']}")
        logger.info(f" Method: {best_match['method']}")
        logger.info(f" Scale: {best_match['scale']}")
        logger.info(f" Preprocessing: {best_match['preprocessing']}")
        logger.info(f" Total matches: {match_result['num_matches']}")

        result['match_result'] = match_result

        # Step 2: sanity-check the match before trusting its location.
        if self.matcher.is_matching_failed(match_result):
            logger.warning("⚠️ Template matching FAILED - using full-page OCR fallback")
            return self._extract_fullpage(page_img, result)

        logger.info("✓ Template matching appears valid, extracting from ROI...")
        return self._extract_from_roi(page_img, best_match, result)

    def _extract_from_roi(self, page_img: np.ndarray, match_info: Dict, result: Dict) -> Dict:
        """OCR the region to the right of the matched logo.

        The ROI starts at the logo centre, spans up to 600 px to the right,
        and runs from half a template height above the centre to four
        template heights below it, clipped to the page. The crop is saved as
        ``roi.png`` in ``OUTPUT_DIR``. Falls back to full-page OCR when the
        ROI is empty or yields no code.

        Args:
            page_img: Page image.
            match_info: Match dict with ``center`` and ``template_size``.
            result: Result dict to fill in (mutated and returned).
        """
        x, y = match_info['center']
        template_w, template_h = match_info['template_size']
        h, w = page_img.shape[:2]

        roi_x1 = max(0, x)
        roi_y1 = max(0, y - template_h // 2)
        roi_x2 = min(w, x + 600)
        roi_y2 = min(h, y + template_h * 4)

        logger.info(f"ROI: ({roi_x1}, {roi_y1}) -> ({roi_x2}, {roi_y2})")
        logger.info(f"ROI size: {roi_x2 - roi_x1}x{roi_y2 - roi_y1}")

        roi_img = page_img[roi_y1:roi_y2, roi_x1:roi_x2]
        if roi_img.size == 0:
            # Nothing to write or OCR; imwrite would raise on an empty image.
            logger.warning("ROI is empty, trying full-page OCR fallback...")
            return self._extract_fullpage(page_img, result)

        # Keep the crop for inspection.
        cv2.imwrite(str(OUTPUT_DIR / "roi.png"), roi_img)

        cma_code = self._extract_cma_from_ocr_result(roi_img)
        if not cma_code:
            logger.warning("ROI extraction failed, trying full-page OCR fallback...")
            return self._extract_fullpage(page_img, result)

        result['success'] = True
        result['code'] = cma_code['code']
        result['confidence'] = cma_code['confidence']
        result['method'] = 'template_matching'
        logger.info(f"✓ SUCCESS: Found CMA code: {cma_code['code']} (confidence: {cma_code['confidence']:.2f})")
        return result

    def _extract_fullpage(self, page_img: np.ndarray, result: Dict) -> Dict:
        """Full-page OCR fallback.

        Sets ``result['method']`` to ``'fullpage_ocr'`` on success and to
        ``'failed'`` otherwise, then returns ``result``.
        """
        logger.info("\n[Step 2] Running full-page OCR fallback...")

        cma_code = self._extract_cma_from_ocr_result(page_img)

        if cma_code:
            result['success'] = True
            result['code'] = cma_code['code']
            result['confidence'] = cma_code['confidence']
            result['method'] = 'fullpage_ocr'
            logger.info(f"✓ SUCCESS: Found CMA code: {cma_code['code']} (confidence: {cma_code['confidence']:.2f})")
        else:
            result['method'] = 'failed'
            logger.error("✗ FAILED: Full-page OCR also failed")

        return result

    @classmethod
    def _collect_candidates(cls, texts, scores) -> List[Dict]:
        """Return every 11-12 digit code found in the OCR lines, in line order."""
        candidates = []
        for line_no, (text, score) in enumerate(zip(texts, scores)):
            for raw in cls._CODE_PATTERN.findall(text):
                candidates.append({
                    'code': raw.replace(" ", "").replace("-", ""),
                    'confidence': float(score),
                    'text': text,
                    'line': line_no,
                })
        return candidates

    @staticmethod
    def _pick_best(candidates: List[Dict]) -> Optional[Dict]:
        """Pick the highest-confidence candidate, preferring codes starting with '2'."""
        if not candidates:
            return None
        preferred = [c for c in candidates if c['code'].startswith('2')]
        # max() returns the first maximal element, i.e. the earliest line on ties.
        return max(preferred or candidates, key=lambda c: c['confidence'])

    def _extract_cma_from_ocr_result(self, img: np.ndarray) -> Optional[Dict]:
        """Run OCR on ``img`` and return the best CMA code candidate.

        Reads ``rec_texts`` and ``rec_scores`` from the first OCR result.
        Candidates are 11-12 digit numbers; those starting with ``'2'`` are
        preferred (the original author describes this as the standard CMA
        code format).

        Returns:
            Dict with ``code``, ``confidence``, ``text`` and ``line``, or
            ``None`` when OCR fails or no candidate is found.
        """
        try:
            ocr_result = self.ocr.predict(img)

            if not ocr_result:
                logger.warning("OCR returned no results")
                return None

            res = ocr_result[0]
            texts = res.get('rec_texts', [])
            scores = res.get('rec_scores', [])

            logger.info(f"OCR found {len(texts)} text lines")

            best = self._pick_best(self._collect_candidates(texts, scores))
            if best is None:
                logger.warning("No 11-12 digit numbers found in OCR results")
                return None

            if best['code'].startswith('2'):
                logger.info(f"Best candidate (starts with '2'): {best['code']} (line {best['line']}, conf: {best['confidence']:.2f})")
            else:
                logger.info(f"Best candidate (no '2' prefix): {best['code']} (line {best['line']}, conf: {best['confidence']:.2f})")
            return best

        except Exception as e:
            # Best-effort: any OCR/parsing failure is treated as "no code",
            # but the traceback is kept in the log.
            logger.exception(f"OCR extraction failed: {e}")
            return None


# ============ Test function ============


def test_single_pdf(pdf_path: str, expected_cma: Optional[str] = None):
    """Run CMA code extraction on the first page of one PDF and log the outcome.

    The page is rendered at 300 DPI, passed through ``SmartCMAExtractor``
    with a freshly created ``PaddleOCR(lang='ch')`` engine, and the result is
    logged; when ``expected_cma`` is given it is compared with the extracted
    code.

    NOTE(review): the ``test_`` prefix plus a required positional parameter
    means pytest would try to collect this as a test and look for a
    ``pdf_path`` fixture - confirm the file is only run as a script.

    Args:
        pdf_path: Path of the PDF file.
        expected_cma: Expected code, or ``None`` to skip the comparison.

    Returns:
        The result dict produced by ``SmartCMAExtractor.extract``.

    Raises:
        ValueError: If the rendered page cannot be decoded into an image.
    """
    pdf_name = Path(pdf_path).name

    logger.info("\n" + "#" * 80)
    logger.info(f"TESTING: {pdf_name}")
    logger.info(f"Expected CMA: {expected_cma or 'Unknown'}")
    logger.info("#" * 80 + "\n")

    # Render the first page.
    logger.info("Extracting PDF page...")
    doc = fitz.open(pdf_path)
    try:
        page = doc[0]
        # 300 DPI rendering (PDF user space is 72 units per inch).
        mat = fitz.Matrix(300 / 72, 300 / 72)
        pix = page.get_pixmap(matrix=mat)
        img_data = pix.tobytes("png")
    finally:
        # Close the document even if rendering raises.
        doc.close()

    img_array = np.frombuffer(img_data, dtype=np.uint8)
    page_img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if page_img is None:
        raise ValueError(f"Cannot decode rendered page of {pdf_path}")

    logger.info(f"Page size: {page_img.shape}")

    # Initialise OCR.
    logger.info("Initializing PaddleOCR...")
    ocr = PaddleOCR(lang='ch')

    # Extract the CMA code.
    extractor = SmartCMAExtractor(ocr)
    result = extractor.extract(page_img, pdf_name)

    # Report.
    logger.info("\n" + "=" * 80)
    logger.info("FINAL RESULT")
    logger.info("=" * 80)
    logger.info(f"PDF: {result['pdf_name']}")
    logger.info(f"Success: {result['success']}")
    logger.info(f"Method: {result['method']}")
    # The 'code' key always exists (None on failure), so fall back explicitly.
    logger.info(f"CMA Code: {result.get('code') or 'N/A'}")
    logger.info(f"Confidence: {result.get('confidence', 0):.2f}")

    if expected_cma:
        if result['code'] == expected_cma:
            logger.info(f"✓✓✓ CORRECT! Expected: {expected_cma}, Got: {result['code']}")
        else:
            logger.info(f"✗✗✗ WRONG! Expected: {expected_cma}, Got: {result['code']}")

    logger.info("=" * 80 + "\n")

    return result


# ============ Main program ============


if __name__ == "__main__":
    # Run the extraction test on YDQ23_001838.pdf with its known CMA code.
    test_single_pdf(TEST_PDF, expected_cma="210020349096")

    print("\n" + "=" * 80)
    print("TEST COMPLETED")
    print("=" * 80)
    print(f"Results saved to: {OUTPUT_DIR}")
    print(" - test.log: Detailed log")
    print(" - roi.png: ROI image (if template matching succeeded)")