# report-detect/cma_extraction_final.py
#
# 416 lines
# 15 KiB
# Python

"""
CMA Code Extraction Module using Full-Page OCR with Position Filtering
This module provides a robust method for extracting CMA certification codes
from PDF report pages using PaddleOCR with position-based filtering.
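Approach:
1. Run full-page OCR to get all text lines with their positions.
2. Collect every run of 10-12 digits (after removing spaces and hyphens) as a
   CMA code candidate.
3. Score each candidate by code format (11 digits starting with '2' scores
   highest), position (bonus for the top-right area, where the CMA logo/code
   is typically located), and OCR confidence.
4. Return the highest-scoring candidate.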
Author: Based on reference implementation from refer/认监-扫描件识别
Date: 2025-02-05
"""
import os
import re
import cv2
import numpy as np
import json
import logging
# Module-level logger named after this module. Logging is only configured
# (via logging.basicConfig) when the file is run as a script.
logger = logging.getLogger(__name__)

# CMA code patterns (regular-expression source strings, not compiled).
# NOTE(review): neither pattern is referenced by the functions visible in this
# file, which look for digit runs with their own inline regex -- confirm
# whether external callers import these before changing or removing them.
PATTERN_PRIMARY = r'2[0-9]{10,11}' # a '2' followed by 10-11 digits (11-12 digits in total)
PATTERN_FALLBACK = r'[0-9]{11,12}' # any run of 11-12 digits
def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """
    Read an image from a path that may contain non-ASCII characters.

    Stand-in for cv2.imread: the file is loaded as a raw byte buffer with
    numpy and then decoded in memory with cv2.imdecode.

    Args:
        path: Image file path (may contain Chinese characters); it is passed
            through str() before being opened.
        flags: cv2.IMREAD_* decoding flags.

    Returns:
        Whatever cv2.imdecode returns for the file's bytes, or None when
        reading or decoding raised an exception (the error is logged).
    """
    try:
        raw_bytes = np.fromfile(str(path), dtype=np.uint8)
        return cv2.imdecode(raw_bytes, flags)
    except Exception as e:
        logger.error(f"Failed to read image {path}: {e}")
    return None
def score_cma_candidate(candidate, page_width, page_height):
    """
    Rank a CMA code candidate; a larger value means a more plausible code.

    The score is the sum of four independent parts:
    - Format: 100 for exactly 11 digits starting with '2', 50 for any other
      code starting with '2', otherwise 0.
    - Length: 20 for 11-12 digits, 10 for 10 or 13 digits, otherwise 0.
    - Position: 30 when the candidate's centre lies in the right half and
      the top 35% of the page.
    - Confidence: the OCR confidence multiplied by 10.

    Args:
        candidate: Dict with at least 'code' (digit string), 'confidence'
            (number) and 'position' ((x, y) pair in pixels).
        page_width: Page width in pixels.
        page_height: Page height in pixels.

    Returns:
        The combined score (higher is better).
    """
    digits = candidate['code']
    ocr_confidence = candidate['confidence']
    center_x, center_y = candidate['position']

    digit_count = len(digits)
    leads_with_two = digits.startswith('2')

    # Format part: the canonical shape is 11 digits with a leading '2'.
    if leads_with_two:
        format_points = 100 if digit_count == 11 else 50
    else:
        format_points = 0

    # Length part: full credit for 11-12 digits, partial for one digit off.
    if 11 <= digit_count <= 12:
        length_points = 20
    elif 10 <= digit_count <= 13:
        length_points = 10
    else:
        length_points = 0

    # Position part: the top-right region is where the CMA logo usually sits.
    in_top_right = (center_x > page_width * 0.5
                    and center_y < page_height * 0.35)
    position_points = 30 if in_top_right else 0

    return format_points + length_points + position_points + ocr_confidence * 10
# Matches every maximal run of ASCII digits in a cleaned-up OCR text line.
_DIGIT_RUN_RE = re.compile(r'[0-9]+')


def _coerce_confidence(score, default=0.0):
    """Return *score* as a plain float, or *default* when it is not numeric."""
    try:
        return float(score)
    except (TypeError, ValueError):
        return default


def _normalize_box(box):
    """Convert one OCR box to ``[x1, y1, x2, y2]`` floats, or None if unusable.

    Accepts either four corner points (``[[x, y], ...]``) or a flat sequence
    whose first four values are taken as ``x1, y1, x2, y2``.
    """
    if box is None:
        return None
    try:
        if len(box) < 4:
            return None
        corners = [box[i] for i in range(4)]
        if all(hasattr(pt, '__len__') and len(pt) >= 2 for pt in corners):
            xs = [float(pt[0]) for pt in corners]
            ys = [float(pt[1]) for pt in corners]
            return [min(xs), min(ys), max(xs), max(ys)]
        return [float(v) for v in corners]
    except (TypeError, ValueError, IndexError, KeyError):
        return None


def _parse_legacy_page(lines):
    """Parse a ``[[box, (text, score)], ...]`` page into parallel lists."""
    texts, scores, boxes = [], [], []
    for line_idx, line in enumerate(lines):
        if not isinstance(line, (list, tuple)) or len(line) < 2:
            logger.debug(f"Skipping line {line_idx}: invalid format")
            continue
        raw_box, rec = line[0], line[1]
        if raw_box is None:
            logger.debug(f"Skipping line {line_idx}: box is None")
            continue

        # rec is (text, score), (text,), a bare string, or something else.
        # Index instead of unpacking so extra trailing items do not raise.
        if isinstance(rec, (list, tuple)) and len(rec) >= 1:
            text = rec[0]
            score = rec[1] if len(rec) >= 2 else 0.99
        elif isinstance(rec, str):
            text, score = rec, 0.99
        else:
            text, score = str(rec), 0.99

        box = _normalize_box(raw_box)
        if box is None:
            logger.warning(f"Line {line_idx}: Invalid box format: {type(raw_box)}")

        # Append all three together so the lists always stay index-aligned.
        texts.append(text)
        scores.append(score)
        boxes.append(box)
    return texts, scores, boxes


def _parse_result_mapping(page):
    """Read rec_texts / rec_scores / rec_boxes from a dict-like or attribute-style result."""

    def field(name):
        if hasattr(page, 'keys'):
            value = page.get(name)
        elif hasattr(page, '__getitem__'):
            value = page[name] if name in page else None
        else:
            value = getattr(page, name, None)
        # list() gives the later code one uniform container type regardless
        # of what the backend stored (list, tuple, array, ...).
        return [] if value is None else list(value)

    try:
        return field('rec_texts'), field('rec_scores'), field('rec_boxes')
    except Exception as parse_error:
        logger.error(f"Failed to extract OCR data: {parse_error}")
        logger.error(f"OCRResult type: {type(page)}")
        logger.error(f"OCRResult dir: {[a for a in dir(page) if not a.startswith('_')]}")
        return None


def _run_ocr(ocr_engine, image):
    """Call ``ocr_engine.ocr(image)``, falling back to ``.predict(image)``."""
    if hasattr(ocr_engine, 'ocr'):
        try:
            return ocr_engine.ocr(image)
        except Exception as ocr_err:
            logger.warning(f".ocr() method failed: {ocr_err}, trying .predict()...")
        if not hasattr(ocr_engine, 'predict'):
            return None
    return ocr_engine.predict(image)


def _collect_ocr_lines(raw_result):
    """Turn a raw OCR result into ``(texts, scores, boxes)``, or None on failure.

    The parser is chosen from the shape of the result rather than from which
    methods the engine exposes, because an engine that has both ``.ocr()``
    and ``.predict()`` may return dict-like page results from ``.ocr()``.
    """
    if raw_result is None:
        logger.error("OCR returned None")
        return None

    page = raw_result
    if isinstance(raw_result, (list, tuple)):
        if len(raw_result) == 0:
            logger.warning("OCR returned empty list")
            return None
        page = raw_result[0]
        if page is None:
            logger.warning("OCR result[0] is None")
            return None

    # Lazy %-formatting: nothing is stringified unless debug logging is on.
    logger.debug("OCR page result type: %s", type(page))

    if isinstance(page, (list, tuple)):
        parsed = _parse_legacy_page(page)
        logger.info(f"Found {len(parsed[0])} text lines (Legacy API)")
        return parsed

    parsed = _parse_result_mapping(page)
    if parsed is not None:
        logger.info(f"Found {len(parsed[0])} text lines")
    return parsed


def _find_cma_candidates(texts, scores, boxes):
    """Collect every 10-12 digit run found in *texts* as a candidate dict."""
    candidates = []
    for i, text in enumerate(texts):
        if not isinstance(text, str):
            logger.warning(f"Skip non-string text at index {i}: type={type(text)}, value={text}")
            continue
        if not text.strip():
            logger.debug(f"Skip empty text at index {i}")
            continue

        confidence = _coerce_confidence(scores[i]) if i < len(scores) else 0.5
        box = _normalize_box(boxes[i]) if i < len(boxes) else None
        if box is not None:
            x1, y1, x2, y2 = box
            center = ((x1 + x2) / 2, (y1 + y2) / 2)
        else:
            # No usable box: (0, 0) never earns the top-right position bonus.
            center = (0, 0)

        # Codes are sometimes printed or recognised with separators.
        cleaned_text = text.replace(" ", "").replace("-", "")
        for num in _DIGIT_RUN_RE.findall(cleaned_text):
            if 10 <= len(num) <= 12:  # CMA codes are typically 11 digits
                candidates.append({
                    'code': num,
                    'confidence': confidence,
                    'text': text,
                    'position': center,
                    'box': box,
                    'index': i,
                })
    return candidates


def _save_debug_visualization(image, box, code, output_dir):
    """Best-effort: draw *box* and *code* on a copy of *image* and save it as PNG."""
    try:
        os.makedirs(output_dir, exist_ok=True)
        viz = image.copy()
        x1, y1, x2, y2 = [int(v) for v in box]
        cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3)
        cv2.putText(viz, f"CMA: {code}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
        viz_path = os.path.join(output_dir, "cma_detection_fullpage.png")
        # Encode in memory and let numpy write the file, mirroring how
        # imread_unicode avoids cv2's own file I/O for non-ASCII paths.
        encoded_ok, encoded = cv2.imencode(".png", viz)
        if not encoded_ok:
            logger.warning(f"Failed to encode visualization: {viz_path}")
            return
        encoded.tofile(viz_path)
        logger.info(f"Saved visualization: {viz_path}")
    except Exception as viz_err:
        # A debug image must never discard an extraction that succeeded.
        logger.warning(f"Failed to save visualization: {viz_err}")


def extract_cma_code_fullpage(page_img, ocr_engine, output_dir=None):
    """
    Extract the CMA code from a page image using full-page OCR.

    Every OCR text line is scanned for runs of 10-12 digits (spaces and
    hyphens removed first). The candidates are ranked with
    score_cma_candidate and the best one is returned.

    Args:
        page_img: Page image as a numpy array, or a path (str or
            os.PathLike) to an image file.
        ocr_engine: OCR engine exposing ``.ocr(image)`` and/or
            ``.predict(image)``. Both the list-of-lines result layout
            (``[[box, (text, score)], ...]`` per page) and the dict-like
            layout with 'rec_texts' / 'rec_scores' / 'rec_boxes' are
            understood, whichever method produced them.
        output_dir: Optional directory; when given, an annotated debug image
            ("cma_detection_fullpage.png") is written there on success.

    Returns:
        Dict with keys:
        - 'code': Extracted CMA code (str or None)
        - 'confidence': OCR confidence of the line containing the code (float)
        - 'raw_text': Raw OCR text containing the code (str)
        - 'position': (x, y) centre of the code's box, (0, 0) if unknown
        - 'box': Bounding box [x1, y1, x2, y2] as floats, or None
        - 'success': True when a code was extracted

        On any failure (bad input, OCR error, no candidate) the dict is
        returned with its defaults and 'success' False; the reason is logged.
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False
    }

    # Load image if a path was provided
    if isinstance(page_img, (str, os.PathLike)):
        image = imread_unicode(page_img, cv2.IMREAD_COLOR)
    elif isinstance(page_img, np.ndarray):
        image = page_img
    else:
        logger.error(f"Invalid image type: {type(page_img)}")
        return result

    if image is None or image.size == 0:
        logger.error("Failed to load image or empty image")
        return result

    h, w = image.shape[:2]

    logger.info("Running full-page OCR for CMA extraction...")
    try:
        parsed = _collect_ocr_lines(_run_ocr(ocr_engine, image))
    except Exception as e:
        logger.error(f"OCR failed: {e}")
        return result
    if parsed is None:
        return result
    rec_texts, rec_scores, rec_boxes = parsed

    cma_candidates = _find_cma_candidates(rec_texts, rec_scores, rec_boxes)
    if not cma_candidates:
        logger.warning("No CMA code candidates found")
        return result

    # max() keeps the earliest candidate among equal scores, the same winner
    # a stable descending sort would put first.
    best = max(cma_candidates, key=lambda c: score_cma_candidate(c, w, h))

    result['code'] = best['code']
    result['confidence'] = best['confidence']
    result['raw_text'] = best['text']
    result['position'] = best['position']
    result['box'] = best['box']
    result['success'] = True
    logger.info(f"Extracted CMA code: {best['code']} (confidence: {best['confidence']:.4f})")

    if output_dir and best['box'] is not None:
        _save_debug_visualization(image, best['box'], best['code'], output_dir)

    return result
def _run_cli():
    """Command-line smoke test: extract the CMA code from one image and print it."""
    import sys

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python cma_extraction_final.py <image_path> [output_dir]")
        sys.exit(1)

    image_path = cli_args[0]
    debug_dir = cli_args[1] if len(cli_args) > 1 else "cma_test_output"

    # Set before paddleocr is imported, as in the original statement order.
    os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
    from paddleocr import PaddleOCR

    print("Initializing PaddleOCR...")
    engine = PaddleOCR(use_angle_cls=True, lang='ch')
    outcome = extract_cma_code_fullpage(image_path, engine, debug_dir)

    rule = "=" * 60
    print("\n" + rule)
    print("CMA EXTRACTION RESULT")
    print(rule)
    print(f"Success: {outcome['success']}")
    if outcome['success']:
        print(f"CMA Code: {outcome['code']}")
        print(f"Confidence: {outcome['confidence']:.4f}")
        print(f"Raw Text: '{outcome['raw_text']}'")
        print(f"Position: {outcome['position']}")
    print(rule)


if __name__ == "__main__":
    # Test the CMA extraction
    _run_cli()