279 lines
9.4 KiB
Python
279 lines
9.4 KiB
Python
"""
Unit tests for CMA template matching improvements.

This module validates incremental improvements to the template matching
algorithm against known failure cases.
"""
|
|
import unittest
|
|
import cv2
|
|
import numpy as np
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
# Logging setup: INFO and above, rendered as "<LEVEL>: <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(message)s',
)
logger = logging.getLogger(__name__)

# Fixture locations.
# NOTE(review): these are relative paths, so they resolve against the current
# working directory — presumably the repository root; confirm with the runner.
CMA_LOGO_PATH = Path("template/CMA_Logo.png")
PDF_DIR = Path("src/test/resources/data/pdfs")
RESULTS_FILE = Path("src/test/resources/data/results.json")

# PDF file name -> expected CMA code for the cases the tests treat as
# known failures of the original matcher.
TEST_CASES = {
    "WTS2025-21283.pdf": "220020349627",
    "YDQ23_001838.pdf": "210020349096",
    "YDQ23_001850.pdf": "210020349096",
    "YDQ25_001875.pdf": "240020349096",
    "YDQ25_002294.pdf": "240020349096",
}

# PDF file name -> expected CMA code for the cases that should already match
# with high confidence.
SUCCESS_CASES = {
    "1.pdf": "181122170342",
    "YDQ25_001845.pdf": "240020349096",
}
|
|
|
|
|
|
def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """Load an image via NumPy so paths with non-ASCII characters work.

    The file is read as raw bytes with ``np.fromfile`` and decoded with
    ``cv2.imdecode`` instead of handing the path to ``cv2.imread``.

    Args:
        path: Location of the image file (anything ``str()`` can render).
        flags: Decode flags forwarded to ``cv2.imdecode``.

    Returns:
        Whatever ``cv2.imdecode`` returns, or ``None`` if reading or decoding
        raised; the exception is logged rather than propagated.
    """
    decoded = None
    try:
        raw_bytes = np.fromfile(str(path), dtype=np.uint8)
        decoded = cv2.imdecode(raw_bytes, flags)
    except Exception as exc:
        logger.error(f"Failed to read image {path}: {exc}")
    return decoded
|
|
|
|
|
|
def extract_pdf_page(pdf_path, page_num=0):
    """Render one page of a PDF as an image array.

    The page is rasterised at 300 DPI, serialised to PNG bytes and decoded
    with ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Zero-based index of the page to render (default: first page).

    Returns:
        The decoded image, or ``None`` when ``page_num`` is past the last page
        or when opening/rendering/decoding raised (the error is logged).

    Raises:
        ImportError: If PyMuPDF (``fitz``) is not installed; the import is
            deliberately outside the ``try`` so a missing dependency is loud.
    """
    import fitz

    try:
        doc = fitz.open(str(pdf_path))
        try:
            if page_num >= doc.page_count:
                return None

            # Scale from 72 units per inch to 300 DPI for better quality.
            zoom = 300 / 72
            pix = doc[page_num].get_pixmap(matrix=fitz.Matrix(zoom, zoom))
            png_bytes = pix.tobytes("png")
            img_array = np.frombuffer(png_bytes, dtype=np.uint8)
            return cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        finally:
            # Close on every path: previously a failure while rendering or
            # decoding skipped close() and leaked the open document.
            doc.close()
    except Exception as e:
        logger.error(f"Failed to extract page from {pdf_path}: {e}")
        return None
|
|
|
|
|
|
def _template_method_name(method):
    """Return the ``cv2.TM_*`` constant name for *method*, or ``str(method)``."""
    known_names = (
        "TM_SQDIFF",
        "TM_SQDIFF_NORMED",
        "TM_CCORR",
        "TM_CCORR_NORMED",
        "TM_CCOEFF",
        "TM_CCOEFF_NORMED",
    )
    for name in known_names:
        if getattr(cv2, name, None) == method:
            return name
    return str(method)


def _to_gray(img):
    """Convert a 3-dimensional array with ``COLOR_BGR2GRAY``; pass others through."""
    if len(img.shape) == 3:
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return img


def _match_template(page_img, template, method):
    """Shared implementation behind ``match_template_old`` / ``match_template_new``."""
    page_gray = _to_gray(page_img)
    template_gray = _to_gray(template)

    tpl_h, tpl_w = template_gray.shape[:2]
    page_h, page_w = page_gray.shape[:2]
    if tpl_h > page_h or tpl_w > page_w:
        # cv2.matchTemplate raises when the template does not fit inside the
        # image; report "no match" through the existing None path instead.
        logger.warning(
            f"Template ({tpl_w}x{tpl_h}) is larger than page ({page_w}x{page_h}); cannot match"
        )
        return None

    result = cv2.matchTemplate(page_gray, template_gray, method=method)
    if result is None:
        return None

    _, max_val, _, max_loc = cv2.minMaxLoc(result)
    # max_loc is the top-left corner of the best window; shift by half the
    # template size to get its centre.
    match_center = (
        max_loc[0] + tpl_w // 2,
        max_loc[1] + tpl_h // 2,
    )

    return {
        'max_val': float(max_val),
        'match_center': match_center,
        'match_loc': max_loc,
        # Derived from the method actually used, so the label stays correct
        # when a caller overrides the default.
        'method': _template_method_name(method),
    }


def match_template_old(page_img, template, method=cv2.TM_CCOEFF_NORMED):
    """Original matching method: TM_CCOEFF_NORMED.

    Args:
        page_img: Page image; converted to grayscale when it has 3 dimensions.
        template: Template image; converted the same way.
        method: ``cv2.matchTemplate`` comparison method.

    Returns:
        ``None`` if the template is larger than the page or no response map
        was produced; otherwise a dict with ``'max_val'`` (float peak
        response), ``'match_center'`` and ``'match_loc'`` ((x, y) of the
        peak window's centre / top-left corner) and ``'method'`` (name of the
        method used).

    NOTE(review): the *maximum* of the response map is always reported. For
    the TM_SQDIFF variants OpenCV documents the minimum as the best match, so
    passing those methods here is presumably not meaningful — confirm before
    relying on it.
    """
    return _match_template(page_img, template, method)


def match_template_new(page_img, template, method=cv2.TM_CCORR_NORMED):
    """Improved matching method: TM_CCORR_NORMED.

    Args:
        page_img: Page image; converted to grayscale when it has 3 dimensions.
        template: Template image; converted the same way.
        method: ``cv2.matchTemplate`` comparison method.

    Returns:
        ``None`` if the template is larger than the page or no response map
        was produced; otherwise a dict with ``'max_val'`` (float peak
        response), ``'match_center'`` and ``'match_loc'`` ((x, y) of the
        peak window's centre / top-left corner) and ``'method'`` (name of the
        method used).

    NOTE(review): the *maximum* of the response map is always reported. For
    the TM_SQDIFF variants OpenCV documents the minimum as the best match, so
    passing those methods here is presumably not meaningful — confirm before
    relying on it.
    """
    return _match_template(page_img, template, method)
|
|
|
|
|
|
class TestTemplateMatching(unittest.TestCase):
    """Test cases for template matching improvements.

    Each test renders the first page of the fixture PDFs and runs the old
    and/or new matcher against the CMA logo template loaded in ``setUpClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Load template once for all tests; skip the class if it is missing."""
        cls.template = imread_unicode(CMA_LOGO_PATH, cv2.IMREAD_COLOR)
        if cls.template is None:
            raise unittest.SkipTest(f"Could not load template from {CMA_LOGO_PATH}")
        logger.info(f"Loaded template: {cls.template.shape}")

    def _match_both(self, pdf_name):
        """Run both matchers on the first page of *pdf_name*.

        Must be called inside ``self.subTest`` so that a missing PDF skips
        only that sub-case instead of aborting the whole test method.

        Returns:
            ``(result_old, result_new)``, both guaranteed non-``None``.
        """
        pdf_path = PDF_DIR / pdf_name
        if not pdf_path.exists():
            self.skipTest(f"PDF not found: {pdf_path}")

        img = extract_pdf_page(pdf_path)
        self.assertIsNotNone(img, f"Failed to extract page from {pdf_name}")

        result_old = match_template_old(img, self.template)
        self.assertIsNotNone(result_old, f"Old method returned None for {pdf_name}")

        result_new = match_template_new(img, self.template)
        self.assertIsNotNone(result_new, f"New method returned None for {pdf_name}")

        return result_old, result_new

    @staticmethod
    def _log_summary(title, results, with_improvement=False):
        """Log a banner followed by the per-PDF confidences in *results*."""
        logger.info("\n" + "=" * 60)
        logger.info(title)
        logger.info("=" * 60)
        for pdf_name, data in results.items():
            logger.info(f"{pdf_name}:")
            logger.info(f" Expected CMA: {data['expected_cma']}")
            logger.info(f" Old: {data['old_confidence']:.3f}")
            logger.info(f" New: {data['new_confidence']:.3f}")
            if with_improvement:
                logger.info(f" Improvement: {data['new_confidence'] - data['old_confidence']:+.3f}")

    def test_specific_failures(self):
        """Test known failure cases (confidence 0.32-0.39)."""
        results = {}

        for pdf_name, expected_cma in TEST_CASES.items():
            with self.subTest(pdf=pdf_name):
                result_old, result_new = self._match_both(pdf_name)

                # Log results
                logger.info(f"{pdf_name}:")
                logger.info(f" Old ({result_old['method']}): {result_old['max_val']:.3f}")
                logger.info(f" New ({result_new['method']}): {result_new['max_val']:.3f}")

                # Store before asserting so a failing case still shows up in
                # the summary.
                results[pdf_name] = {
                    'expected_cma': expected_cma,
                    'old_confidence': result_old['max_val'],
                    'new_confidence': result_new['max_val'],
                }

                # Verify new method doesn't decrease confidence significantly.
                # Allow small decrease (0.02) but overall should improve.
                self.assertGreaterEqual(
                    result_new['max_val'],
                    result_old['max_val'] - 0.02,
                    f"{pdf_name}: New method should not significantly decrease confidence"
                )

        self._log_summary("FAILURE CASES SUMMARY", results, with_improvement=True)

    def test_success_cases(self):
        """Test known success cases (should match with high confidence)."""
        results = {}

        for pdf_name, expected_cma in SUCCESS_CASES.items():
            with self.subTest(pdf=pdf_name):
                result_old, result_new = self._match_both(pdf_name)

                # Log results
                logger.info(f"{pdf_name}:")
                logger.info(f" Old: {result_old['max_val']:.3f}")
                logger.info(f" New: {result_new['max_val']:.3f}")

                results[pdf_name] = {
                    'expected_cma': expected_cma,
                    'old_confidence': result_old['max_val'],
                    'new_confidence': result_new['max_val'],
                }

                # Both methods should find the template with high confidence
                self.assertGreater(
                    result_old['max_val'],
                    0.30,
                    f"{pdf_name}: Old method should find template with confidence > 0.30"
                )
                self.assertGreater(
                    result_new['max_val'],
                    0.30,
                    f"{pdf_name}: New method should find template with confidence > 0.30"
                )

        self._log_summary("SUCCESS CASES SUMMARY", results)

    def test_threshold_comparison(self):
        """Test how changing threshold affects match detection.

        Purely informational: logs the detection rate of the new matcher at
        several thresholds and makes no assertions.
        """
        thresholds = [0.25, 0.30, 0.35, 0.40]

        # Render and match every available PDF once; the confidences do not
        # depend on the threshold, so there is no need to redo this per
        # threshold. None marks a page for which the matcher gave no result.
        confidences = []
        for pdf_name in list(TEST_CASES) + list(SUCCESS_CASES):
            pdf_path = PDF_DIR / pdf_name
            if not pdf_path.exists():
                continue

            img = extract_pdf_page(pdf_path)
            if img is None:
                continue

            result_new = match_template_new(img, self.template)
            confidences.append(result_new['max_val'] if result_new else None)

        total = len(confidences)
        if total == 0:
            # Without this guard the percentage below divides by zero.
            self.skipTest(f"No readable PDFs found in {PDF_DIR}")

        for threshold in thresholds:
            detected = sum(
                1 for conf in confidences if conf is not None and conf >= threshold
            )
            logger.info(f"Threshold {threshold:.2f}: {detected}/{total} detected ({detected/total*100:.1f}%)")
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: run the suite with per-test (verbose) reporting.
    unittest.main(verbosity=2)
|