report-detect/archive/temp_scripts/test_improved_extraction.py

425 lines
16 KiB
Python
Raw Normal View History

chore(project): conservative cleanup - archive temp scripts and old docs Major cleanup to improve project organization and maintainability. Changes: - Moved 34 temp/debug/test scripts to archive/temp_scripts/ - Moved 9 auxiliary tools to archive/tools/ - Moved 3 CRT test scripts to archive/crt_tests/ - Moved 4 OCR test scripts to archive/ocr_tests/ - Moved 14 old documentation files to archive/docs/ - Deleted 4 useless files (duplicates, temp files) Root directory: - Before: 67 files (cluttered) - After: 10 core files (clean and organized) Core files retained: - test_accuracy_batch_full.py (main script) - cma_extraction_template_primary.py (CMA extraction) - cma_extraction_final.py (backup CMA extraction) - CLAUDE.md (project guide) - TEST_ACCURACY_BATCH_README.md (usage guide) - TEST_ACCURACY_BATCH_DEPENDENCIES.md (dependency docs) - CLEANUP_PLAN.md (cleanup plan) - CLEANUP_SUMMARY.md (this file) - IMPLEMENTATION_SUMMARY.md (implementation summary) - requirements.txt (dependencies) Archive structure: archive/ ├── temp_scripts/ (34 files: test_, debug_, analyze_, etc.) ├── tools/ (9 files: find_, show_, visualize_, etc.) ├── crt_tests/ (3 files: CRT extraction tests) ├── ocr_tests/ (4 files: OCR timeout tests) └── docs/ (14 files: old reports and guides) Benefits: ✓ Cleaner root directory - easier navigation ✓ Better organization - clear separation of concerns ✓ Preserved history - all files archived, not deleted ✓ Improved maintainability - easier to find active files ✓ Better git history - removed 198 deleted files from tracking No functional changes - all core functionality preserved. Related: - TEST_ACCURACY_BATCH_DEPENDENCIES.md - dependency analysis - CLEANUP_PLAN.md - detailed cleanup plan Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 14:35:06 +08:00
"""
改进的CMA码提取测试 - 结合方案2和方案3
方案2: 智能fallback机制 - 当模板匹配失效时自动使用全页OCR
方案3: 调整模板匹配参数 - 添加预处理多尺度多方法尝试
"""
import sys
import os
import cv2
import numpy as np
import fitz
import re
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
from paddleocr import PaddleOCR
# ============ Configuration ============
# PDF exercised by the single-file test run
TEST_PDF = "src/test/resources/data/pdfs/YDQ23_001838.pdf"
# Image file used as the template for template matching
TEMPLATE_PATH = "template/CMA_Logo.png"
# Output directory for the log file; created at import time if it is missing
OUTPUT_DIR = Path("test_improved_extraction")
OUTPUT_DIR.mkdir(exist_ok=True)

# Logging setup: every record goes to the console and to OUTPUT_DIR/test.log
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [
    logging.StreamHandler(),
    logging.FileHandler(OUTPUT_DIR / "test.log", encoding='utf-8'),
]
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
    handlers=_LOG_HANDLERS,
)
logger = logging.getLogger(__name__)
# ============ 方案3: 改进的模板匹配 ============
class ImprovedTemplateMatcher:
    """Template matcher that combines several preprocessings, scales and methods.

    The template is loaded once as a grayscale image. Each call to
    ``match_multi_method`` compares it against several preprocessed versions
    of the page, at several template scales and with several OpenCV
    comparison methods, and records the best location of every attempt.
    """

    # Template scale factors tried when the caller does not supply ``scales``.
    DEFAULT_SCALES = (0.8, 0.9, 1.0, 1.1, 1.2)

    def __init__(self, template_path: str):
        """Load the template image in grayscale.

        Args:
            template_path: Path of the template image file.

        Raises:
            ValueError: If OpenCV cannot read an image from ``template_path``.
        """
        self.template = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
        if self.template is None:
            raise ValueError(f"Cannot load template from {template_path}")
        self.template_h, self.template_w = self.template.shape[:2]
        logger.info(f"Template loaded: {self.template_w}x{self.template_h}")

    def preprocess_page(self, page_img: np.ndarray) -> Dict[str, np.ndarray]:
        """Build the page variants that are matched against the template.

        Args:
            page_img: Page image. A 3-dimensional array is converted from BGR
                to grayscale; any other array is used unchanged.

        Returns:
            Mapping from variant name (``original``, ``blurred``, ``denoised``,
            ``equalized``, ``clahe``, ``edges``) to the processed image.
        """
        gray = cv2.cvtColor(page_img, cv2.COLOR_BGR2GRAY) if page_img.ndim == 3 else page_img
        processed = {
            'original': gray,
            'blurred': cv2.GaussianBlur(gray, (5, 5), 0),
            'denoised': cv2.fastNlMeansDenoising(gray, None, 10, 7, 21),
            'equalized': cv2.equalizeHist(gray),
            'clahe': cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray),
        }
        # Edge map variant (intended to help with the round logo).
        processed['edges'] = cv2.Canny(gray, 50, 150)
        logger.info(f"Generated {len(processed)} preprocessed versions")
        return processed

    @staticmethod
    def _sqdiff_confidence(min_val: float, normalized: bool, num_pixels: int) -> float:
        """Turn a squared-difference minimum into a score in [0, 1], higher = better.

        Args:
            min_val: Minimum of the squared-difference result map.
            normalized: True for the normalised variant, whose value is
                simply inverted; False for the raw sum of squared differences.
            num_pixels: Number of template pixels, used to scale the raw sum.
        """
        if normalized:
            score = 1.0 - float(min_val)
        else:
            # Raw sum of squared differences: divide by its largest possible
            # value. Assumes 8-bit images (max difference 255 per pixel) - TODO confirm
            # that callers only pass uint8 pages.
            score = 1.0 - float(min_val) / (num_pixels * 255.0 ** 2)
        return min(1.0, max(0.0, score))

    def _scaled_templates(self, scales: List[float]) -> List[Tuple[float, np.ndarray, int, int]]:
        """Return ``(scale, template, width, height)`` for every usable scale.

        Scales whose resized template would be narrower or lower than
        10 pixels are dropped.
        """
        variants = []
        for scale in scales:
            if scale == 1.0:
                variants.append((scale, self.template, self.template_w, self.template_h))
                continue
            new_w = int(self.template_w * scale)
            new_h = int(self.template_h * scale)
            if new_w < 10 or new_h < 10:
                continue
            # INTER_AREA is meant for shrinking; use bilinear when enlarging.
            interpolation = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
            resized = cv2.resize(self.template, (new_w, new_h), interpolation=interpolation)
            variants.append((scale, resized, new_w, new_h))
        return variants

    def match_multi_method(
        self,
        page_img: np.ndarray,
        scales: Optional[List[float]] = None,
        methods: Optional[List[int]] = None
    ) -> Dict:
        """Match the template with several preprocessings, scales and methods.

        For every (preprocessing, scale, method) combination the single best
        location is taken. It is kept only if its centre lies in the upper
        60% of the page.

        Args:
            page_img: Page image to search.
            scales: Template scale factors; defaults to ``DEFAULT_SCALES``.
            methods: OpenCV template matching methods; defaults to
                ``TM_CCOEFF_NORMED``, ``TM_CCORR_NORMED`` and ``TM_SQDIFF``.

        Returns:
            On success::

                {
                    'success': True,
                    'best_match': {'confidence': float, 'location': tuple,
                                   'center': tuple, 'method': int,
                                   'scale': float, 'preprocessing': str,
                                   'template_size': (w, h)},
                    'all_matches': List[Dict],  # sorted by confidence, best first
                    'num_matches': int
                }

            When nothing passes the position filter::

                {'success': False, 'reason': str, 'num_matches': 0}

            For the squared-difference methods the best location is the
            minimum of the result map and ``confidence`` is rescaled so that
            higher is better. For all other methods ``confidence`` is the raw
            maximum of the result map.
        """
        if scales is None:
            scales = list(self.DEFAULT_SCALES)
        if methods is None:
            methods = [cv2.TM_CCOEFF_NORMED, cv2.TM_CCORR_NORMED, cv2.TM_SQDIFF]

        page_h = page_img.shape[0]
        max_y_threshold = int(page_h * 0.6)  # only accept matches in the upper 60% of the page

        preprocessed = self.preprocess_page(page_img)
        # The resized templates do not depend on the page variant: build them once.
        scaled_templates = self._scaled_templates(scales)

        all_matches = []
        num_total_checks = 0
        for prep_name, processed_img in preprocessed.items():
            for scale, scaled_template, new_w, new_h in scaled_templates:
                for method in methods:
                    num_total_checks += 1
                    try:
                        result = cv2.matchTemplate(processed_img, scaled_template, method)
                    except cv2.error as e:
                        # e.g. the template is larger than the page image
                        logger.debug(f"Match failed: prep={prep_name}, scale={scale}, method={method}, error={e}")
                        continue
                    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)

                    if method in (cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED):
                        # Squared difference: the best match is the MINIMUM.
                        top_left = min_loc
                        confidence = self._sqdiff_confidence(
                            min_val, method == cv2.TM_SQDIFF_NORMED, new_w * new_h
                        )
                    else:
                        top_left = max_loc
                        confidence = float(max_val)

                    center = (top_left[0] + new_w // 2, top_left[1] + new_h // 2)
                    # Position filter: drop matches in the lower part of the page.
                    if center[1] > max_y_threshold:
                        continue

                    all_matches.append({
                        'confidence': confidence,
                        'location': top_left,
                        'center': center,
                        'method': method,
                        'scale': scale,
                        'preprocessing': prep_name,
                        'template_size': (new_w, new_h)
                    })

        logger.info(f"Total match attempts: {num_total_checks}")
        logger.info(f"Valid matches (in upper 60% of page): {len(all_matches)}")
        if not all_matches:
            return {
                'success': False,
                'reason': 'No valid matches found',
                'num_matches': 0
            }

        # Best first; the sort is stable, so ties keep their discovery order.
        all_matches.sort(key=lambda x: x['confidence'], reverse=True)
        best_match = all_matches[0]

        # An excessive number of matches hints that matching degenerated.
        if len(all_matches) > 1000:
            logger.warning(f"Too many matches ({len(all_matches)}), template matching may have failed")
        return {
            'success': True,
            'best_match': best_match,
            'all_matches': all_matches,
            'num_matches': len(all_matches)
        }

    def is_matching_failed(self, match_result: Dict) -> bool:
        """Decide whether a ``match_multi_method`` result should be distrusted.

        The result is treated as failed when:

        1. it reports no success, or
        2. it holds more than 1000 matches, or
        3. it holds more than 100 matches and the best confidence is above 0.9.

        NOTE(review): ``match_multi_method`` records at most one match per
        (preprocessing, scale, method) attempt, i.e. 6 * 5 * 3 = 90 with the
        default arguments, so checks 2 and 3 can only trigger for larger
        custom search grids. Scatter of the match positions is not examined.

        Args:
            match_result: Dict returned by ``match_multi_method``.

        Returns:
            True if the match result should not be trusted.
        """
        if not match_result.get('success'):
            return True
        num_matches = match_result.get('num_matches', 0)
        best_confidence = match_result['best_match']['confidence']
        # Check: too many matches
        if num_matches > 1000:
            logger.warning(f"Template matching failed: {num_matches} matches (threshold: >1000)")
            return True
        # Check: suspiciously high confidence together with many matches
        if num_matches > 100 and best_confidence > 0.9:
            logger.warning(f"Template matching failed: high confidence ({best_confidence:.3f}) with many matches ({num_matches})")
            return True
        return False
# ============ 方案2: 智能Fallback提取器 ============
class SmartCMAExtractor:
    """CMA code extractor that pairs template matching with a full-page OCR fallback."""

    def __init__(self, ocr_engine: PaddleOCR):
        """Store the OCR engine and build the template matcher from TEMPLATE_PATH."""
        self.ocr = ocr_engine
        self.matcher = ImprovedTemplateMatcher(TEMPLATE_PATH)

    def extract(self, page_img: np.ndarray, pdf_name: str) -> Dict:
        """Extract the CMA code from a page image.

        Steps:
        1. Run the improved template matching.
        2. Check whether the match result looks invalid.
        3. Read the code from the region next to the match, or fall back to
           OCR on the whole page when matching found nothing or is invalid.

        Args:
            page_img: Page image to analyse.
            pdf_name: Name stored in the result and written to the log.

        Returns:
            Dict with keys ``pdf_name``, ``success``, ``code``, ``confidence``,
            ``method`` and ``match_result``.
        """
        banner = '=' * 80
        result = dict(
            pdf_name=pdf_name,
            success=False,
            code=None,
            confidence=0.0,
            method=None,
            match_result=None,
        )
        logger.info(f"\n{banner}")
        logger.info(f"EXTRACTING FROM: {pdf_name}")
        logger.info(f"{banner}")

        # Step 1: improved template matching
        logger.info("\n[Step 1] Attempting improved template matching...")
        match_result = self.matcher.match_multi_method(page_img)

        # No usable match at all: go straight to the full-page fallback.
        if not match_result['success']:
            logger.warning(f"⚠️ No template match found - reason: {match_result.get('reason')}")
            logger.info("→ Using full-page OCR fallback")
            result['method'] = 'fullpage_fallback'
            return self._extract_fullpage(page_img, result)

        top = match_result['best_match']
        logger.info("Template match found:")
        logger.info(f" Confidence: {top['confidence']:.3f}")
        logger.info(f" Location: {top['center']}")
        logger.info(f" Method: {top['method']}")
        logger.info(f" Scale: {top['scale']}")
        logger.info(f" Preprocessing: {top['preprocessing']}")
        logger.info(f" Total matches: {match_result['num_matches']}")
        result['match_result'] = match_result

        # A match exists, but it may still be judged unreliable.
        if self.matcher.is_matching_failed(match_result):
            logger.warning("⚠️ Template matching FAILED - using full-page OCR fallback")
            result['method'] = 'fullpage_fallback'
            return self._extract_fullpage(page_img, result)

        logger.info("✓ Template matching appears valid, extracting from ROI...")
        return self._extract_from_roi(page_img, top, result)

    def _extract_from_roi(self, page_img: np.ndarray, match_info: Dict, result: Dict) -> Dict:
        """Read the CMA code from a region derived from the template match.

        The region starts at the match centre, extends up to 600 px to the
        right and from half a template height above the centre to four
        template heights below it, clipped to the page. The crop is saved as
        ``roi.png`` in OUTPUT_DIR. If OCR finds no code there, the full-page
        fallback is used instead.
        """
        center_x, center_y = match_info['center']
        tmpl_w, tmpl_h = match_info['template_size']
        page_h, page_w = page_img.shape[:2]

        # Region of interest: right of the logo centre, stretching downwards.
        left = max(0, center_x)
        upper = max(0, center_y - tmpl_h // 2)
        right = min(page_w, center_x + min(600, page_w - center_x))
        lower = min(page_h, center_y + tmpl_h * 4)
        logger.info(f"ROI: ({left}, {upper}) -> ({right}, {lower})")
        logger.info(f"ROI size: {right - left}x{lower - upper}")

        roi_img = page_img[upper:lower, left:right]
        # Keep the crop on disk for inspection.
        cv2.imwrite(str(OUTPUT_DIR / "roi.png"), roi_img)

        found = self._extract_cma_from_ocr_result(roi_img)
        if not found:
            logger.warning("ROI extraction failed, trying full-page OCR fallback...")
            return self._extract_fullpage(page_img, result)

        result['success'] = True
        result['code'] = found['code']
        result['confidence'] = found['confidence']
        result['method'] = 'template_matching'
        logger.info(f"✓ SUCCESS: Found CMA code: {found['code']} (confidence: {found['confidence']:.2f})")
        return result

    def _extract_fullpage(self, page_img: np.ndarray, result: Dict) -> Dict:
        """Run OCR on the whole page and record the outcome in ``result``."""
        logger.info("\n[Step 2] Running full-page OCR fallback...")
        found = self._extract_cma_from_ocr_result(page_img)
        if not found:
            result['method'] = 'failed'
            logger.error("✗ FAILED: Full-page OCR also failed")
            return result

        result['success'] = True
        result['code'] = found['code']
        result['confidence'] = found['confidence']
        result['method'] = 'fullpage_ocr'
        logger.info(f"✓ SUCCESS: Found CMA code: {found['code']} (confidence: {found['confidence']:.2f})")
        return result

    def _extract_cma_from_ocr_result(self, img: np.ndarray) -> Optional[Dict]:
        """OCR an image and pick the most likely CMA code.

        Every run of 11-12 digits (after removing spaces and hyphens from a
        text line) is a candidate. Candidates starting with "2" are preferred;
        within the chosen group the one with the highest OCR score wins.

        Returns:
            Dict with ``code``, ``confidence``, ``text`` and ``line``, or None
            when OCR yields nothing usable or raises.
        """
        try:
            ocr_output = self.ocr.predict(img)
            if not ocr_output or not len(ocr_output):
                logger.warning("OCR returned no results")
                return None

            first = ocr_output[0]
            texts = first.get('rec_texts', [])
            scores = first.get('rec_scores', [])
            logger.info(f"OCR found {len(texts)} text lines")

            # Collect every 11-12 digit run, line by line.
            digit_run = re.compile(r'\d{11,12}')
            candidates = [
                {
                    'code': digits,
                    'confidence': float(score),
                    'text': text,
                    'line': line_no
                }
                for line_no, (text, score) in enumerate(zip(texts, scores))
                for digits in digit_run.findall(text.replace(" ", "").replace("-", ""))
            ]
            if not candidates:
                logger.warning("No 11-12 digit numbers found in OCR results")
                return None

            # Prefer candidates beginning with "2" (the standard CMA code format).
            preferred = [c for c in candidates if c['code'].startswith('2')]
            if preferred:
                pool, label = preferred, "starts with '2'"
            else:
                pool, label = candidates, "no '2' prefix"
            best = sorted(pool, key=lambda c: c['confidence'], reverse=True)[0]
            logger.info(f"Best candidate ({label}): {best['code']} (line {best['line']}, conf: {best['confidence']:.2f})")
            return best
        except Exception as e:
            logger.error(f"OCR extraction failed: {e}")
            return None
# ============ 测试函数 ============
def test_single_pdf(pdf_path: str, expected_cma: Optional[str] = None):
    """Run the CMA code extraction on the first page of one PDF and log the outcome.

    Despite the ``test_`` prefix this is a manual driver that needs explicit
    arguments; it is invoked from the ``__main__`` section of this script.

    Args:
        pdf_path: Path of the PDF file to process.
        expected_cma: Expected CMA code. When given, the log states whether
            the extracted code equals it.

    Returns:
        The result dict produced by ``SmartCMAExtractor.extract``.

    Raises:
        ValueError: If the rendered page cannot be decoded into an image.
    """
    pdf_name = Path(pdf_path).name
    logger.info(f"\n{'#'*80}")
    logger.info(f"TESTING: {pdf_name}")
    logger.info(f"Expected CMA: {expected_cma or 'Unknown'}")
    logger.info(f"{'#'*80}\n")

    # Render the first page at 300 DPI. The context manager closes the
    # document even when rendering raises.
    logger.info("Extracting PDF page...")
    with fitz.open(pdf_path) as doc:
        page = doc[0]
        mat = fitz.Matrix(300 / 72, 300 / 72)
        pix = page.get_pixmap(matrix=mat)
        img_data = pix.tobytes("png")

    img_array = np.frombuffer(img_data, dtype=np.uint8)
    page_img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if page_img is None:
        # imdecode signals failure with None; fail here with a clear message
        # instead of an AttributeError on ``.shape`` below.
        raise ValueError(f"Cannot decode rendered page of {pdf_path}")
    logger.info(f"Page size: {page_img.shape}")

    # Initialise OCR
    logger.info("Initializing PaddleOCR...")
    ocr = PaddleOCR(lang='ch')

    # Extract the CMA code
    extractor = SmartCMAExtractor(ocr)
    result = extractor.extract(page_img, pdf_name)

    # Report
    logger.info("\n" + "="*80)
    logger.info("FINAL RESULT")
    logger.info("="*80)
    logger.info(f"PDF: {result['pdf_name']}")
    logger.info(f"Success: {result['success']}")
    logger.info(f"Method: {result['method']}")
    # The 'code' key always exists (None when nothing was found), so a
    # ``.get`` default would never apply; fall back explicitly.
    logger.info(f"CMA Code: {result.get('code') or 'N/A'}")
    logger.info(f"Confidence: {result.get('confidence', 0):.2f}")
    if expected_cma:
        if result['code'] == expected_cma:
            logger.info(f"✓✓✓ CORRECT! Expected: {expected_cma}, Got: {result['code']}")
        else:
            logger.info(f"✗✗✗ WRONG! Expected: {expected_cma}, Got: {result['code']}")
    logger.info("="*80 + "\n")
    return result
# ============ 主程序 ============
if __name__ == "__main__":
    # Manual run against the configured sample PDF (YDQ23_001838.pdf)
    test_single_pdf(TEST_PDF, expected_cma="210020349096")

    # Closing summary printed to stdout
    separator = "=" * 80
    summary_lines = [
        "\n" + separator,
        "TEST COMPLETED",
        separator,
        f"Results saved to: {OUTPUT_DIR}",
        " - test.log: Detailed log",
        " - roi.png: ROI image (if template matching succeeded)",
    ]
    for line in summary_lines:
        print(line)