feat(cma): add CMA extraction module fallback implementation
Add cma_extraction_final.py as backup CMA extraction module. This module provides fallback CMA code extraction when the primary template-based method (cma_extraction_template_primary.py) fails. Features: - Full-page OCR extraction as fallback - CMA pattern matching (11-12 digit codes) - Integration with main batch testing script - Supports both template matching and OCR-only approaches Usage: The main script (test_accuracy_batch_full.py) automatically falls back to this module if template matching fails: 1. Primary: cma_extraction_template_primary.py (template matching) 2. Fallback: cma_extraction_final.py (full-page OCR) Related files: - cma_extraction_template_primary.py (primary module) - test_accuracy_batch_full.py (main script that uses both) - TEST_ACCURACY_BATCH_DEPENDENCIES.md (dependency documentation) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
5f72e010cd
commit
9562cf1ac7
|
|
@ -0,0 +1,415 @@
|
|||
"""
|
||||
CMA Code Extraction Module using Full-Page OCR with Position Filtering
|
||||
|
||||
This module provides a robust method for extracting CMA certification codes
|
||||
from PDF report pages using PaddleOCR with position-based filtering.
|
||||
|
||||
Approach:
|
||||
1. Run full-page OCR to get all text with positions
|
||||
2. Compute each text box's centre; text in the top-right area (where the CMA logo/code is typically located) receives a scoring bonus rather than being hard-filtered
|
||||
3. Use a regex to collect digit runs of 10-12 digits as CMA code candidates (an 11-digit code starting with '2' is the preferred format)
|
||||
4. Score candidates by: position, confidence, code format
|
||||
|
||||
Author: Based on reference implementation from refer/认监-扫描件识别
|
||||
Date: 2025-02-05
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Module-level logger; handlers and levels are left to the importing application.
logger = logging.getLogger(__name__)

# CMA code patterns
# NOTE(review): neither pattern is referenced by the functions visible in this
# module (candidate digits are collected with an inline r'[0-9]+' regex) --
# confirm whether other modules import them before removing.
PATTERN_PRIMARY = r'2[0-9]{10,11}'  # 11-12 digits starting with 2
PATTERN_FALLBACK = r'[0-9]{11,12}'  # any 11-12 digits
|
||||
|
||||
|
||||
def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """
    Load an image from a path that may contain non-ASCII characters.

    The file is read into a byte buffer with NumPy and decoded in memory,
    so the path string itself is never handed to OpenCV.

    Args:
        path: Location of the image file (converted with ``str()``).
        flags: ``cv2.IMREAD_*`` decode flags forwarded to ``cv2.imdecode``.

    Returns:
        Whatever ``cv2.imdecode`` produces for the file's bytes, or None
        when reading or decoding raised an exception (the error is logged).
    """
    try:
        raw_bytes = np.fromfile(str(path), dtype=np.uint8)
        return cv2.imdecode(raw_bytes, flags)
    except Exception as e:
        logger.error(f"Failed to read image {path}: {e}")
    return None
|
||||
|
||||
|
||||
def score_cma_candidate(candidate, page_width, page_height):
    """
    Rank a CMA code candidate; larger values mean a more plausible code.

    The total is the sum of four independent parts:
    - Format: 100 for an 11-character code starting with '2', 50 for any
      other code starting with '2', otherwise 0.
    - Length: 20 for 11 or 12 characters, 10 for 10 or 13, otherwise 0.
    - Position: 30 when the centre lies in the right half and the top 35%
      of the page.
    - Confidence: the OCR confidence multiplied by 10 (a 0-10 bonus,
      assuming the confidence is in [0, 1]).

    Args:
        candidate: Dict providing 'code', 'confidence' and 'position'
            (an ``(x, y)`` pair).
        page_width: Page width in pixels.
        page_height: Page height in pixels.

    Returns:
        The combined score.
    """
    code = candidate['code']
    confidence = candidate['confidence']
    pos_x, pos_y = candidate['position']

    n_chars = len(code)

    # Format component: the leading '2' is required for any format credit.
    if code.startswith('2'):
        format_points = 100 if n_chars == 11 else 50
    else:
        format_points = 0

    # Length component: standard lengths first, near-misses second.
    if n_chars in (11, 12):
        length_points = 20
    elif n_chars in (10, 13):
        length_points = 10
    else:
        length_points = 0

    # Position component: strict comparisons, so boundary points get nothing.
    in_top_right = pos_x > page_width * 0.5 and pos_y < page_height * 0.35
    position_points = 30 if in_top_right else 0

    return format_points + length_points + position_points + confidence * 10
|
||||
|
||||
|
||||
def extract_cma_code_fullpage(page_img, ocr_engine, output_dir=None):
    """
    Extract the CMA code from a page image using full-page OCR.

    The whole page is sent to the OCR engine, every recognised text line is
    scanned for runs of 10-12 digits, and the candidates are ranked with
    ``score_cma_candidate`` (format, length, top-right position, confidence).

    The OCR result format is detected from the returned data rather than
    from which method the engine exposes: a list of ``[box, (text, score)]``
    entries is parsed as the legacy format, anything else is read through
    its ``rec_texts`` / ``rec_scores`` / ``rec_boxes`` fields. This also
    makes the ``.predict()`` fallback usable, because its output no longer
    has to look like legacy ``.ocr()`` output.

    Args:
        page_img: Page image as a numpy array, or a path (``str`` or
            ``os.PathLike``) to an image file.
        ocr_engine: Initialized OCR engine exposing ``.ocr()`` and/or
            ``.predict()`` (e.g. a PaddleOCR instance).
        output_dir: Optional directory for a debug overlay image. Failure to
            write the overlay is logged and does not affect the result.

    Returns:
        Dict with keys:
        - 'code': Extracted CMA code (str or None)
        - 'confidence': OCR confidence (float)
        - 'raw_text': Raw OCR text containing the code (str)
        - 'position': (x, y) centre of the text box, (0, 0) if unknown
        - 'box': Bounding box [x1, y1, x2, y2] or None
        - 'success': True when a candidate was found
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False,
    }

    # Load image if a path was provided
    if isinstance(page_img, (str, os.PathLike)):
        image = imread_unicode(page_img, cv2.IMREAD_COLOR)
    elif isinstance(page_img, np.ndarray):
        image = page_img
    else:
        logger.error(f"Invalid image type: {type(page_img)}")
        return result

    if image is None or image.size == 0:
        logger.error("Failed to load image or empty image")
        return result

    page_height, page_width = image.shape[:2]

    logger.info("Running full-page OCR for CMA extraction...")
    try:
        raw_output = _run_ocr_engine(ocr_engine, image)
        if raw_output is None:
            logger.error("OCR returned None")
            return result

        # Lazy %-formatting: the type is only rendered when DEBUG is enabled.
        logger.debug("OCR result type: %s", type(raw_output))

        # Engines return either a per-page list or a single result object.
        if isinstance(raw_output, (list, tuple)):
            if len(raw_output) == 0:
                logger.warning("OCR returned empty list")
                return result
            first_page = raw_output[0]
        else:
            first_page = raw_output

        if first_page is None:
            logger.warning("OCR result[0] is None")
            return result

        if isinstance(first_page, (list, tuple)):
            rec_texts, rec_scores, rec_boxes = _parse_legacy_page(first_page)
            logger.info(f"Found {len(rec_texts)} text lines (Legacy API)")
        else:
            rec_texts, rec_scores, rec_boxes = _parse_structured_page(first_page)
            logger.info(f"Found {len(rec_texts)} text lines")
    except Exception as e:
        # Boundary handler: this function reports failure through `result`
        # instead of raising, whatever the OCR backend does.
        logger.error(f"OCR failed: {e}")
        return result

    cma_candidates = _collect_cma_candidates(rec_texts, rec_scores, rec_boxes)
    if not cma_candidates:
        logger.warning("No CMA code candidates found")
        return result

    # max() returns the first of equally scored candidates, which matches a
    # stable descending sort followed by taking element 0.
    best = max(
        cma_candidates,
        key=lambda c: score_cma_candidate(c, page_width, page_height),
    )

    result['code'] = best['code']
    result['confidence'] = best['confidence']
    result['raw_text'] = best['text']
    result['position'] = best['position']
    result['box'] = best['box']
    result['success'] = True

    logger.info(f"Extracted CMA code: {best['code']} (confidence: {best['confidence']:.4f})")

    if output_dir and best['box'] is not None:
        _save_debug_overlay(image, best, output_dir)

    return result


def _run_ocr_engine(ocr_engine, image):
    """Run OCR via ``.ocr()``, falling back to ``.predict()`` only on failure or absence."""
    has_predict = hasattr(ocr_engine, 'predict')

    if hasattr(ocr_engine, 'ocr'):
        try:
            # Called without the `cls` parameter for better compatibility.
            return ocr_engine.ocr(image)
        except Exception as ocr_err:
            if not has_predict:
                logger.error(f".ocr() method failed: {ocr_err}")
                return None
            logger.warning(f".ocr() method failed: {ocr_err}, trying .predict()...")

    if not has_predict:
        logger.error("OCR engine exposes neither .ocr() nor .predict()")
        return None

    try:
        return ocr_engine.predict(image)
    except Exception as pred_err:
        logger.error(f".predict() also failed: {pred_err}")
        return None


def _parse_legacy_page(lines):
    """Split legacy ``[box, (text, score)]`` entries into parallel text/score/box lists."""
    texts, scores, boxes = [], [], []
    for line_idx, line in enumerate(lines):
        if not isinstance(line, (list, tuple)) or len(line) < 2:
            logger.debug(f"Skipping line {line_idx}: invalid format")
            continue

        box, recognition = line[0], line[1]
        if box is None:
            logger.debug(f"Skipping line {line_idx}: box is None")
            continue

        # recognition is (text, score), (text,), a bare string, or something else
        if isinstance(recognition, (list, tuple)) and len(recognition) >= 2:
            text, score = recognition[0], recognition[1]
        elif isinstance(recognition, (list, tuple)) and len(recognition) == 1:
            text, score = recognition[0], 0.99
        elif isinstance(recognition, str):
            text, score = recognition, 0.99
        else:
            text, score = str(recognition), 0.99

        # Appended together so the three lists stay index-aligned.
        texts.append(text)
        scores.append(score)
        boxes.append(box)
    return texts, scores, boxes


def _parse_structured_page(page):
    """Read ``rec_texts`` / ``rec_scores`` / ``rec_boxes`` from a dict-like or attribute-style result."""
    def read_field(key):
        if hasattr(page, 'keys'):
            value = page.get(key)
        elif hasattr(page, '__getitem__'):
            value = page[key] if key in page else None
        else:
            value = getattr(page, key, None)
        # A present-but-None field is treated like a missing one.
        return [] if value is None else value

    return read_field('rec_texts'), read_field('rec_scores'), read_field('rec_boxes')


def _box_to_xyxy(box):
    """Normalise a 4-point quadrilateral or flat box to ``[x1, y1, x2, y2]``; None if unusable."""
    if isinstance(box, np.ndarray):
        box = box.tolist()
    if not isinstance(box, (list, tuple)) or len(box) < 4:
        return None

    corners = [pt.tolist() if isinstance(pt, np.ndarray) else pt for pt in box[:4]]
    try:
        if all(isinstance(pt, (list, tuple)) and len(pt) >= 2 for pt in corners):
            # [[x, y], [x, y], [x, y], [x, y]] -> axis-aligned bounds
            xs = [pt[0] for pt in corners]
            ys = [pt[1] for pt in corners]
            xyxy = [min(xs), min(ys), max(xs), max(ys)]
        else:
            # Already flat: [x1, y1, x2, y2, ...]
            xyxy = list(corners)
        for value in xyxy:
            float(value)  # raises when a coordinate is not numeric
    except (TypeError, ValueError):
        return None
    return xyxy


def _as_confidence(value, default=0.5):
    """Coerce an OCR score to ``float``, using *default* when it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _collect_cma_candidates(rec_texts, rec_scores, rec_boxes):
    """Build one candidate dict per 10-12 digit run found in the recognised text lines."""
    candidates = []
    for i, text in enumerate(rec_texts):
        if not isinstance(text, str):
            logger.warning(f"Skip non-string text at index {i}: type={type(text)}, value={text}")
            continue
        if not text.strip():
            logger.debug(f"Skip empty text at index {i}")
            continue

        confidence = _as_confidence(rec_scores[i]) if i < len(rec_scores) else 0.5

        raw_box = rec_boxes[i] if i < len(rec_boxes) else None
        box = _box_to_xyxy(raw_box) if raw_box is not None else None
        if raw_box is not None and box is None:
            logger.warning(f"Line {i}: Invalid box format: {type(raw_box)}")

        if box is not None:
            x1, y1, x2, y2 = box
            position = ((x1 + x2) / 2, (y1 + y2) / 2)
        else:
            position = (0, 0)

        # Spaces and hyphens are dropped so codes printed in groups are rejoined.
        cleaned_text = text.replace(" ", "").replace("-", "")
        for num in re.findall(r'[0-9]+', cleaned_text):
            if 10 <= len(num) <= 12:  # CMA codes are typically 11 digits
                candidates.append({
                    'code': num,
                    'confidence': confidence,
                    'text': text,
                    'position': position,
                    'box': box,
                    'index': i,
                })
    return candidates


def _save_debug_overlay(image, candidate, output_dir):
    """Write an annotated copy of the page to *output_dir*; failures are logged, never raised."""
    try:
        os.makedirs(output_dir, exist_ok=True)
        viz = image.copy()

        x1, y1, x2, y2 = [int(v) for v in candidate['box']]
        cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3)
        # Keep the label on the page when the box touches the top edge.
        label_y = max(y1 - 10, 20)
        cv2.putText(viz, f"CMA: {candidate['code']}", (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

        viz_path = os.path.join(output_dir, "cma_detection_fullpage.png")
        # Encode in memory and write with NumPy, mirroring imread_unicode, so
        # the output path is never passed to OpenCV.
        encoded_ok, encoded = cv2.imencode(".png", viz)
        if not encoded_ok:
            logger.warning(f"Failed to encode visualization: {viz_path}")
            return
        encoded.tofile(viz_path)
        logger.info(f"Saved visualization: {viz_path}")
    except Exception as viz_err:
        # The overlay is a debugging aid; it must not discard a good result.
        logger.warning(f"Failed to save visualization: {viz_err}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line smoke test: run the extractor on one page image and
    # print the outcome.
    import sys

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python cma_extraction_final.py <image_path> [output_dir]")
        sys.exit(1)

    img_path = cli_args[0]
    out_dir = cli_args[1] if len(cli_args) > 1 else "cma_test_output"

    # Set before paddleocr is imported -- presumably it is read at import
    # time to skip a model-source check; confirm against the paddleocr docs.
    os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
    from paddleocr import PaddleOCR

    print("Initializing PaddleOCR...")
    ocr = PaddleOCR(use_angle_cls=True, lang='ch')

    result = extract_cma_code_fullpage(img_path, ocr, out_dir)

    banner = "=" * 60
    print("\n" + banner)
    print("CMA EXTRACTION RESULT")
    print(banner)
    print(f"Success: {result['success']}")
    if result['success']:
        print(f"CMA Code: {result['code']}")
        print(f"Confidence: {result['confidence']:.4f}")
        print(f"Raw Text: '{result['raw_text']}'")
        print(f"Position: {result['position']}")
    print(banner)
|
||||
Loading…
Reference in New Issue