report-detect/archive/temp_scripts/test_improved_extraction.py

425 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
改进的CMA码提取测试 - 结合方案2和方案3
方案2: 智能fallback机制 - 当模板匹配失效时自动使用全页OCR
方案3: 调整模板匹配参数 - 添加预处理、多尺度、多方法尝试
"""
import sys
import os
import cv2
import numpy as np
import fitz
import re
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
from paddleocr import PaddleOCR
# ============ Configuration ============
# Input PDF and CMA logo template; both paths are resolved against the current working directory.
TEST_PDF = "src/test/resources/data/pdfs/YDQ23_001838.pdf"
TEMPLATE_PATH = "template/CMA_Logo.png"

# Every artifact (log file, ROI crop) lands here; the directory is created at import time.
OUTPUT_DIR = Path("test_improved_extraction")
OUTPUT_DIR.mkdir(exist_ok=True)

# ============ Logging ============
# Mirror log records to the console and to OUTPUT_DIR/test.log (UTF-8, messages may be non-ASCII).
logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(OUTPUT_DIR / "test.log", encoding='utf-8'),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ============ 方案3: 改进的模板匹配 ============
class ImprovedTemplateMatcher:
    """Locate the CMA logo by template matching over several preprocessings, scales and methods.

    For every (page preprocessing, template scale, matching method) combination
    the single best-scoring location is recorded as a candidate; candidates whose
    center lies below the upper 60% of the page are discarded.
    """

    # Default search space for match_multi_method. Tuples, so the shared
    # defaults cannot be mutated between calls.
    DEFAULT_SCALES = (0.8, 0.9, 1.0, 1.1, 1.2)
    DEFAULT_METHODS = (cv2.TM_CCOEFF_NORMED, cv2.TM_CCORR_NORMED, cv2.TM_SQDIFF)

    # Unnormalized scores are unbounded (TM_SQDIFF sums squared pixel differences
    # over the window; TM_CCORR / TM_CCOEFF sum products), so they cannot be
    # ranked against the bounded scores of the *_NORMED methods. Such methods are
    # therefore run as their normalized counterpart.
    _NORMALIZED_METHOD = {
        cv2.TM_SQDIFF: cv2.TM_SQDIFF_NORMED,
        cv2.TM_CCORR: cv2.TM_CCORR_NORMED,
        cv2.TM_CCOEFF: cv2.TM_CCOEFF_NORMED,
    }

    def __init__(self, template_path: str):
        """Load the logo template as a grayscale image.

        Args:
            template_path: Path to the template image file.

        Raises:
            ValueError: If the image cannot be read.
        """
        self.template = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
        if self.template is None:
            raise ValueError(f"Cannot load template from {template_path}")
        self.template_h, self.template_w = self.template.shape[:2]
        logger.info(f"Template loaded: {self.template_w}x{self.template_h}")

    def preprocess_page(self, page_img: np.ndarray) -> Dict[str, np.ndarray]:
        """Return grayscale variants of the page, keyed by preprocessing name.

        Keys are 'original', 'blurred', 'denoised', 'equalized', 'clahe'
        (intensity images) and 'edges' (a Canny edge map).

        Args:
            page_img: 3-channel BGR image or an already single-channel image.
        """
        gray = cv2.cvtColor(page_img, cv2.COLOR_BGR2GRAY) if len(page_img.shape) == 3 else page_img
        processed = {
            'original': gray,
            'blurred': cv2.GaussianBlur(gray, (5, 5), 0),
            'denoised': cv2.fastNlMeansDenoising(gray, None, 10, 7, 21),
            'equalized': cv2.equalizeHist(gray),
            'clahe': cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray),
        }
        # Edge-enhanced variant (the original author's intent: helps with the circular logo).
        processed['edges'] = cv2.Canny(gray, 50, 150)
        logger.info(f"Generated {len(processed)} preprocessed versions")
        return processed

    def match_multi_method(
        self,
        page_img: np.ndarray,
        scales: Optional[List[float]] = None,
        methods: Optional[List[int]] = None,
    ) -> Dict:
        """Match the template against the page for every preprocessing, scale and method.

        Args:
            page_img: Page image (BGR or grayscale).
            scales: Template scale factors to try; defaults to DEFAULT_SCALES.
                Scales that shrink the template below 10 px are skipped.
            methods: OpenCV ``cv2.TM_*`` method flags; defaults to DEFAULT_METHODS.
                Unnormalized methods are run as their ``*_NORMED`` counterpart so
                that all confidences are comparable.

        Returns:
            On success::

                {
                    'success': True,
                    'best_match': {'confidence': float, 'location': (x, y),
                                   'center': (x, y), 'method': int,
                                   'scale': float, 'preprocessing': str,
                                   'template_size': (w, h)},
                    'all_matches': List[Dict],  # sorted by confidence, best first
                    'num_matches': int,
                }

            Otherwise ``{'success': False, 'reason': str, 'num_matches': 0}``.
            ``confidence`` is higher-is-better for every method; for
            TM_SQDIFF_NORMED it is ``1 - min_val``.
        """
        if scales is None:
            scales = self.DEFAULT_SCALES
        if methods is None:
            methods = self.DEFAULT_METHODS
        # Normalize the method list and drop duplicates while keeping order.
        effective_methods = list(dict.fromkeys(self._NORMALIZED_METHOD.get(m, m) for m in methods))

        h, w = page_img.shape[:2]
        max_y_threshold = int(h * 0.6)  # only accept matches centered in the upper 60% of the page

        preprocessed = self.preprocess_page(page_img)
        all_matches = []
        num_total_checks = 0
        for prep_name, processed_img in preprocessed.items():
            for scale in scales:
                if scale != 1.0:
                    new_w = int(self.template_w * scale)
                    new_h = int(self.template_h * scale)
                    if new_w < 10 or new_h < 10:
                        continue
                    scaled_template = cv2.resize(self.template, (new_w, new_h), interpolation=cv2.INTER_AREA)
                else:
                    scaled_template = self.template
                    new_h, new_w = self.template_h, self.template_w

                # matchTemplate requires the template to fit inside the searched image.
                if new_w > w or new_h > h:
                    logger.debug(f"Template {new_w}x{new_h} larger than page {w}x{h}, skipping scale={scale}")
                    continue

                if prep_name == 'edges':
                    # The page was reduced to an edge map, so compare it against the
                    # template's edge map rather than its raw intensities.
                    scaled_template = cv2.Canny(scaled_template, 50, 150)

                for method in effective_methods:
                    num_total_checks += 1
                    try:
                        result = cv2.matchTemplate(processed_img, scaled_template, method)
                    except cv2.error as e:
                        logger.debug(f"Match failed: prep={prep_name}, scale={scale}, method={method}, error={e}")
                        continue
                    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
                    if method == cv2.TM_SQDIFF_NORMED:
                        # For squared-difference matching the best match is the MINIMUM;
                        # flip it so that higher confidence is better for every method.
                        confidence, loc = 1.0 - min_val, min_loc
                    else:
                        confidence, loc = max_val, max_loc

                    center = (loc[0] + new_w // 2, loc[1] + new_h // 2)
                    if center[1] > max_y_threshold:
                        continue
                    all_matches.append({
                        'confidence': float(confidence),
                        'location': loc,
                        'center': center,
                        'method': method,
                        'scale': scale,
                        'preprocessing': prep_name,
                        'template_size': (new_w, new_h),
                    })

        logger.info(f"Total match attempts: {num_total_checks}")
        logger.info(f"Valid matches (in upper 60% of page): {len(all_matches)}")
        if not all_matches:
            return {
                'success': False,
                'reason': 'No valid matches found',
                'num_matches': 0
            }

        all_matches.sort(key=lambda m: m['confidence'], reverse=True)
        # An excessive number of candidates suggests the matching is unreliable.
        if len(all_matches) > 1000:
            logger.warning(f"Too many matches ({len(all_matches)}), template matching may have failed")
        return {
            'success': True,
            'best_match': all_matches[0],
            'all_matches': all_matches,
            'num_matches': len(all_matches)
        }

    def is_matching_failed(self, match_result: Dict) -> bool:
        """Decide whether a ``match_multi_method`` result should be distrusted.

        Returns True when:
            1. the matcher reported no success;
            2. more than 1000 candidates were recorded; or
            3. more than 100 candidates were recorded and the best confidence
               exceeds 0.9 (uniformly high scores suggest noise).

        NOTE(review): match_multi_method records at most one candidate per
        (preprocessing, scale, method) combination, i.e. at most 6 * 5 * 3 = 90
        with the defaults, so checks 2 and 3 can only fire when callers pass
        larger scale/method lists. A "matches scattered across the page" check
        was planned originally but is not implemented.
        """
        if not match_result.get('success'):
            return True
        num_matches = match_result.get('num_matches', 0)
        best_confidence = match_result['best_match']['confidence']
        if num_matches > 1000:
            logger.warning(f"Template matching failed: {num_matches} matches (threshold: >1000)")
            return True
        if num_matches > 100 and best_confidence > 0.9:
            logger.warning(f"Template matching failed: high confidence ({best_confidence:.3f}) with many matches ({num_matches})")
            return True
        return False
# ============ 方案2: 智能Fallback提取器 ============
class SmartCMAExtractor:
    """Extract the CMA code: OCR near the template-matched logo, with full-page OCR fallback."""

    # Standalone 11-12 digit runs. The lookarounds stop longer digit runs
    # (e.g. a timestamp like 20230815123045) from yielding a truncated
    # false candidate such as 202308151230.
    _CODE_PATTERN = re.compile(r'(?<![0-9])[0-9]{11,12}(?![0-9])')

    def __init__(self, ocr_engine: PaddleOCR):
        """
        Args:
            ocr_engine: OCR engine; its ``predict(img)`` method is used for recognition.
        """
        self.ocr = ocr_engine
        self.matcher = ImprovedTemplateMatcher(TEMPLATE_PATH)

    def extract(self, page_img: np.ndarray, pdf_name: str) -> Dict:
        """Extract the CMA code from a rendered page.

        Steps:
            1. Run the multi-method template matcher to locate the logo.
            2. Check whether the match looks unreliable.
            3. OCR the region next to the logo, or fall back to full-page OCR
               when matching failed or the region yields no code.

        Args:
            page_img: Page image (BGR or grayscale).
            pdf_name: Name used for logging and stored in the result.

        Returns:
            Dict with ``pdf_name``, ``success``, ``code`` (str or None),
            ``confidence`` (OCR score), ``method`` ('template_matching',
            'fullpage_ocr' or 'failed') and ``match_result`` (matcher output
            when a match was found, else None).
        """
        result = {
            'pdf_name': pdf_name,
            'success': False,
            'code': None,
            'confidence': 0.0,
            'method': None,
            'match_result': None
        }
        logger.info(f"\n{'='*80}")
        logger.info(f"EXTRACTING FROM: {pdf_name}")
        logger.info(f"{'='*80}")

        # Step 1: improved template matching.
        logger.info("\n[Step 1] Attempting improved template matching...")
        match_result = self.matcher.match_multi_method(page_img)

        if not match_result['success']:
            logger.warning(f"⚠️ No template match found - reason: {match_result.get('reason')}")
            logger.info("→ Using full-page OCR fallback")
            result['method'] = 'fullpage_fallback'
            return self._extract_fullpage(page_img, result)

        best_match = match_result['best_match']
        logger.info("Template match found:")
        logger.info(f" Confidence: {best_match['confidence']:.3f}")
        logger.info(f" Location: {best_match['center']}")
        logger.info(f" Method: {best_match['method']}")
        logger.info(f" Scale: {best_match['scale']}")
        logger.info(f" Preprocessing: {best_match['preprocessing']}")
        logger.info(f" Total matches: {match_result['num_matches']}")
        result['match_result'] = match_result

        # Step 2: distrust implausible matches.
        if self.matcher.is_matching_failed(match_result):
            logger.warning("⚠️ Template matching FAILED - using full-page OCR fallback")
            result['method'] = 'fullpage_fallback'
            return self._extract_fullpage(page_img, result)

        logger.info("✓ Template matching appears valid, extracting from ROI...")
        return self._extract_from_roi(page_img, best_match, result)

    def _extract_from_roi(self, page_img: np.ndarray, match_info: Dict, result: Dict) -> Dict:
        """OCR a region anchored at the matched logo; fall back to full-page OCR if no code is found.

        The region starts at the logo's horizontal center, extends up to 600 px
        to the right, and spans vertically from the logo's top edge to four
        template heights below its center, all clipped to the page. The crop is
        written to ``OUTPUT_DIR / "roi.png"`` for inspection.
        """
        x, y = match_info['center']
        _, template_h = match_info['template_size']
        h, w = page_img.shape[:2]

        roi_x1 = max(0, x)
        roi_y1 = max(0, y - template_h // 2)
        roi_x2 = min(w, x + 600)
        roi_y2 = min(h, y + template_h * 4)
        logger.info(f"ROI: ({roi_x1}, {roi_y1}) -> ({roi_x2}, {roi_y2})")
        logger.info(f"ROI size: {roi_x2 - roi_x1}x{roi_y2 - roi_y1}")

        roi_img = page_img[roi_y1:roi_y2, roi_x1:roi_x2]
        if roi_img.size == 0:
            # cv2.imwrite rejects empty images; nothing to OCR here anyway.
            logger.warning("Empty ROI, trying full-page OCR fallback...")
            return self._extract_fullpage(page_img, result)
        cv2.imwrite(str(OUTPUT_DIR / "roi.png"), roi_img)

        cma_code = self._extract_cma_from_ocr_result(roi_img)
        if not cma_code:
            logger.warning("ROI extraction failed, trying full-page OCR fallback...")
            return self._extract_fullpage(page_img, result)

        result['success'] = True
        result['code'] = cma_code['code']
        result['confidence'] = cma_code['confidence']
        result['method'] = 'template_matching'
        logger.info(f"✓ SUCCESS: Found CMA code: {cma_code['code']} (confidence: {cma_code['confidence']:.2f})")
        return result

    def _extract_fullpage(self, page_img: np.ndarray, result: Dict) -> Dict:
        """Fallback: OCR the whole page and fill ``result`` (method 'fullpage_ocr' or 'failed')."""
        logger.info("\n[Step 2] Running full-page OCR fallback...")
        cma_code = self._extract_cma_from_ocr_result(page_img)
        if cma_code:
            result['success'] = True
            result['code'] = cma_code['code']
            result['confidence'] = cma_code['confidence']
            result['method'] = 'fullpage_ocr'
            logger.info(f"✓ SUCCESS: Found CMA code: {cma_code['code']} (confidence: {cma_code['confidence']:.2f})")
        else:
            result['method'] = 'failed'
            logger.error("✗ FAILED: Full-page OCR also failed")
        return result

    def _extract_cma_from_ocr_result(self, img: np.ndarray) -> Optional[Dict]:
        """Run OCR on ``img`` and pick the most plausible CMA code.

        Each recognized line is NFKC-normalized (full-width digits and hyphens
        become ASCII), stripped of spaces and hyphens, and scanned for
        standalone 11-12 digit runs. Codes starting with '2' (the standard CMA
        format) are preferred; ties are broken by OCR confidence, and on equal
        keys the earliest candidate wins.

        Returns:
            ``{'code', 'confidence', 'text', 'line'}`` for the best candidate,
            or None if OCR produced no usable candidate or raised. Errors are
            logged with traceback; returning None is deliberate so callers can
            fall back.
        """
        import unicodedata  # only needed here

        try:
            ocr_result = self.ocr.predict(img)
            if not ocr_result:
                logger.warning("OCR returned no results")
                return None
            res = ocr_result[0]
            texts = res.get('rec_texts', [])
            scores = res.get('rec_scores', [])
            logger.info(f"OCR found {len(texts)} text lines")

            candidates = []
            for i, (text, score) in enumerate(zip(texts, scores)):
                cleaned = unicodedata.normalize("NFKC", text).replace(" ", "").replace("-", "")
                for num in self._CODE_PATTERN.findall(cleaned):
                    candidates.append({
                        'code': num,
                        'confidence': float(score),
                        'text': text,
                        'line': i
                    })

            if not candidates:
                logger.warning("No 11-12 digit numbers found in OCR results")
                return None

            best = max(candidates, key=lambda c: (c['code'].startswith('2'), c['confidence']))
            if best['code'].startswith('2'):
                logger.info(f"Best candidate (starts with '2'): {best['code']} (line {best['line']}, conf: {best['confidence']:.2f})")
            else:
                logger.info(f"Best candidate (no '2' prefix): {best['code']} (line {best['line']}, conf: {best['confidence']:.2f})")
            return best
        except Exception as e:
            # Best-effort: any OCR/parsing failure degrades to "no code found".
            logger.error(f"OCR extraction failed: {e}", exc_info=True)
            return None
# ============ 测试函数 ============
def test_single_pdf(pdf_path: str, expected_cma: Optional[str] = None) -> Dict:
    """Extract the CMA code from the first page of one PDF and log the outcome.

    Args:
        pdf_path: Path to the PDF; only page 0 is rendered (at 300 DPI).
        expected_cma: Optional ground-truth code; when given, the result is
            logged as CORRECT or WRONG.

    Returns:
        The result dict produced by ``SmartCMAExtractor.extract``.

    Raises:
        ValueError: If the rendered page cannot be decoded into an image.
    """
    pdf_name = Path(pdf_path).name
    logger.info(f"\n{'#'*80}")
    logger.info(f"TESTING: {pdf_name}")
    logger.info(f"Expected CMA: {expected_cma or 'Unknown'}")
    logger.info(f"{'#'*80}\n")

    logger.info("Extracting PDF page...")
    # Context manager closes the document even if rendering raises.
    with fitz.open(pdf_path) as doc:
        page = doc[0]
        mat = fitz.Matrix(300 / 72, 300 / 72)  # 300 DPI; PDF user space is 72 units per inch
        pix = page.get_pixmap(matrix=mat)
        img_data = pix.tobytes("png")
    page_img = cv2.imdecode(np.frombuffer(img_data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if page_img is None:
        raise ValueError(f"Failed to decode rendered first page of {pdf_path}")
    logger.info(f"Page size: {page_img.shape}")

    logger.info("Initializing PaddleOCR...")
    ocr = PaddleOCR(lang='ch')

    extractor = SmartCMAExtractor(ocr)
    result = extractor.extract(page_img, pdf_name)

    logger.info("\n" + "="*80)
    logger.info("FINAL RESULT")
    logger.info("="*80)
    logger.info(f"PDF: {result['pdf_name']}")
    logger.info(f"Success: {result['success']}")
    logger.info(f"Method: {result['method']}")
    # 'code' is always present (None on failure), so .get(..., 'N/A') never showed N/A.
    logger.info(f"CMA Code: {result['code'] or 'N/A'}")
    logger.info(f"Confidence: {result.get('confidence', 0):.2f}")
    if expected_cma:
        if result['code'] == expected_cma:
            logger.info(f"✓✓✓ CORRECT! Expected: {expected_cma}, Got: {result['code']}")
        else:
            logger.info(f"✗✗✗ WRONG! Expected: {expected_cma}, Got: {result['code']}")
    logger.info("="*80 + "\n")
    return result


# This file's name and this function's name both match pytest's default
# collection patterns (test_*.py / test_*); without this flag pytest would try
# to run it as a test and error on the missing `pdf_path` fixture.
test_single_pdf.__test__ = False
# ============ 主程序 ============
if __name__ == "__main__":
    # Regression case: YDQ23_001838.pdf, whose known CMA code is 210020349096.
    test_single_pdf(TEST_PDF, expected_cma="210020349096")

    separator = "=" * 80
    print("\n" + separator)
    print("TEST COMPLETED")
    print(separator)
    print(f"Results saved to: {OUTPUT_DIR}")
    for artifact_line in (
        " - test.log: Detailed log",
        " - roi.png: ROI image (if template matching succeeded)",
    ):
        print(artifact_line)