# Listing metadata from the source viewer: 416 lines, 15 KiB, Python
"""
|
|
CMA Code Extraction Module using Full-Page OCR with Position Filtering

This module provides a robust method for extracting CMA certification codes
from PDF report pages using PaddleOCR with position-based filtering.

Approach:
1. Run full-page OCR to get all text with positions
2. Favor text in the top-right area (where the CMA logo/code is typically located)
3. Collect digit runs of 10-12 digits as candidates (a typical CMA code is
   11 digits starting with '2')
4. Score candidates by: position, confidence, code format

Author: Based on reference implementation from refer/认监-扫描件识别
Date: 2025-02-05
|
|
"""
|
|
import os
|
|
import re
|
|
import cv2
|
|
import numpy as np
|
|
import json
|
|
import logging
|
|
|
|
# Module-level logger named after this module; every function below logs through it.
logger = logging.getLogger(__name__)

# CMA code patterns
# NOTE(review): neither pattern is referenced by the code visible in this
# file (candidate extraction uses its own inline digit regex) -- confirm
# whether other modules import them before removing.
PATTERN_PRIMARY = r'2[0-9]{10,11}'  # a '2' followed by 10-11 digits (11-12 digits total)
PATTERN_FALLBACK = r'[0-9]{11,12}'  # any run of 11-12 digits
|
|
|
|
|
|
def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """
    Read an image from a path that may contain non-ASCII characters.

    The file is loaded as a raw byte buffer with NumPy and decoded in memory,
    instead of handing the path to cv2.imread.

    Args:
        path: Image file path (may contain Chinese characters)
        flags: cv2.IMREAD_* flags forwarded to cv2.imdecode

    Returns:
        Decoded image as a numpy array, or None if reading/decoding failed
    """
    decoded = None
    try:
        raw_bytes = np.fromfile(str(path), dtype=np.uint8)
        decoded = cv2.imdecode(raw_bytes, flags)
    except Exception as exc:
        # Any failure is logged and reported to the caller as None.
        logger.error(f"Failed to read image {path}: {exc}")
    return decoded
|
|
|
|
|
|
def score_cma_candidate(candidate, page_width, page_height):
    """
    Rank a CMA code candidate; a larger return value means a better match.

    The score is the sum of four independent parts:
    - Format: 100 for exactly 11 digits starting with '2', 50 for any other
      code starting with '2', otherwise 0
    - Length: 20 for 11 or 12 characters, 10 for 10 or 13, otherwise 0
    - Position: 30 when the centre lies in the right half and the top 35%
      of the page
    - Confidence: OCR confidence multiplied by 10

    Args:
        candidate: Dict with 'code', 'confidence', 'position', 'text'
        page_width: Page width in pixels
        page_height: Page height in pixels

    Returns:
        Score (higher is better)
    """
    digits = candidate['code']
    ocr_confidence = candidate['confidence']
    center_x, center_y = candidate['position']

    n_chars = len(digits)

    # Format component: a leading '2' is required for any format credit.
    if digits.startswith('2'):
        format_points = 100 if n_chars == 11 else 50
    else:
        format_points = 0

    # Length component: standard lengths first, near-misses second.
    if n_chars in (11, 12):
        length_points = 20
    elif n_chars in (10, 13):
        length_points = 10
    else:
        length_points = 0

    # Position component: top-right area (typical CMA logo location).
    in_top_right = center_x > page_width * 0.5 and center_y < page_height * 0.35
    position_points = 30 if in_top_right else 0

    return format_points + length_points + position_points + ocr_confidence * 10
|
|
|
|
|
|
def _legacy_box_to_xyxy(box, line_idx):
    """Convert one legacy-format OCR box into ``[x1, y1, x2, y2]``.

    Accepts either four corner points ``[[x, y], ...]`` or a flat sequence
    whose first four items are taken as ``x1, y1, x2, y2``. Returns None
    (after logging a warning) when the box cannot be interpreted, so callers
    can keep their text/score/box lists aligned.
    """
    if isinstance(box, np.ndarray):
        # Some engines hand back arrays instead of nested lists.
        box = box.tolist()
    try:
        if not isinstance(box, (list, tuple)) or len(box) < 4:
            box_len = len(box) if hasattr(box, '__len__') else 'N/A'
            logger.warning("Line %s: Invalid box format: %s, len=%s",
                           line_idx, type(box), box_len)
            return None
        corners = box[:4]
        if all(isinstance(pt, (list, tuple)) and len(pt) >= 2 for pt in corners):
            xs = [pt[0] for pt in corners]
            ys = [pt[1] for pt in corners]
            return [min(xs), min(ys), max(xs), max(ys)]
        # Box might already be in [x1, y1, x2, y2, ...] format.
        return list(corners)
    except (TypeError, ValueError) as box_err:
        logger.warning("Line %s: Failed to parse box: %s", line_idx, box_err)
        return None


def _parse_legacy_ocr_lines(lines):
    """Flatten legacy ``[box, (text, score)]`` lines into parallel lists.

    Lines that are not a sequence of at least two items, or whose box is
    None, are skipped. A missing score defaults to 0.99.

    Returns:
        Tuple ``(texts, scores, boxes)`` of equal length; a box entry is
        None when the box could not be converted.
    """
    texts, scores, boxes = [], [], []
    for line_idx, line in enumerate(lines):
        if not isinstance(line, (list, tuple)) or len(line) < 2:
            logger.debug("Skipping line %s: invalid format", line_idx)
            continue

        box, payload = line[0], line[1]
        if box is None:
            logger.debug("Skipping line %s: box is None", line_idx)
            continue

        if isinstance(payload, (list, tuple)) and len(payload) >= 2:
            text, score = payload[0], payload[1]
        elif isinstance(payload, (list, tuple)) and len(payload) == 1:
            text, score = payload[0], 0.99
        elif isinstance(payload, str):
            text, score = payload, 0.99
        else:
            text, score = str(payload), 0.99

        texts.append(text)
        scores.append(score)
        boxes.append(_legacy_box_to_xyxy(box, line_idx))
    return texts, scores, boxes


def _read_result_field(page, key):
    """Read ``key`` from a structured OCR result, returning [] when absent.

    Tries mapping-style access first, then item access, then attribute
    access. A field that is present but None is also reported as [].
    """
    if hasattr(page, 'keys'):
        value = page.get(key)
    elif hasattr(page, '__getitem__'):
        value = page[key] if key in page else None
    else:
        value = getattr(page, key, None)
    return [] if value is None else value


def _collect_ocr_lines(raw_result):
    """Normalize an OCR return value into ``(texts, scores, boxes)``.

    The parser is chosen from the shape of the data rather than from which
    method produced it: a page that is a list/tuple is parsed as legacy
    ``[box, (text, score)]`` lines, anything else is read as a structured
    result exposing 'rec_texts', 'rec_scores' and 'rec_boxes'.

    Returns:
        The three parallel sequences, or None (after logging) when the
        result is None, an empty list, or a list whose first page is None.
    """
    if raw_result is None:
        logger.error("OCR returned None")
        return None

    page = raw_result
    if isinstance(raw_result, (list, tuple)):
        if len(raw_result) == 0:
            logger.warning("OCR returned empty list")
            return None
        # Only the first page is used; a single image is passed in.
        page = raw_result[0]
        if page is None:
            logger.warning("OCR result[0] is None")
            return None

    if isinstance(page, (list, tuple)):
        texts, scores, boxes = _parse_legacy_ocr_lines(page)
        logger.info("Found %s text lines (Legacy API)", len(texts))
        return texts, scores, boxes

    try:
        texts = _read_result_field(page, 'rec_texts')
        scores = _read_result_field(page, 'rec_scores')
        boxes = _read_result_field(page, 'rec_boxes')
    except Exception as parse_error:
        logger.error("Failed to extract OCR data: %s", parse_error)
        logger.error("OCRResult type: %s", type(page))
        logger.error("OCRResult dir: %s",
                     [a for a in dir(page) if not a.startswith('_')])
        return None

    logger.debug("Extracted rec_texts: %s", texts)
    logger.debug("Extracted rec_scores: %s", scores)
    logger.debug("Extracted rec_boxes type: %s", type(boxes))
    logger.info("Found %s text lines", len(texts))
    return texts, scores, boxes


def _box_center(box):
    """Return the (x, y) centre of an ``[x1, y1, x2, y2]`` box.

    Falls back to (0, 0) when the box is None, does not have exactly four
    entries, or holds values that cannot be averaged.
    """
    try:
        if box is None or len(box) != 4:
            return 0, 0
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2, (y1 + y2) / 2
    except (TypeError, ValueError):
        return 0, 0


def _find_cma_candidates(rec_texts, rec_scores, rec_boxes):
    """Build candidate dicts for every 10-12 digit run found in the OCR text.

    Spaces and hyphens are removed before digit runs are extracted, so a
    code printed in groups is still matched as one number. Non-string and
    blank texts are skipped. A missing score defaults to 0.5 and a missing
    box to None.
    """
    candidates = []
    for i, text in enumerate(rec_texts):
        if not isinstance(text, str):
            logger.warning("Skip non-string text at index %s: type=%s, value=%s",
                           i, type(text), text)
            continue
        if not text.strip():
            logger.debug("Skip empty text at index %s", i)
            continue

        confidence = rec_scores[i] if i < len(rec_scores) else 0.5
        box = rec_boxes[i] if i < len(rec_boxes) else None
        position = _box_center(box)

        cleaned_text = text.replace(" ", "").replace("-", "")
        for num in re.findall(r'[0-9]+', cleaned_text):
            if 10 <= len(num) <= 12:  # CMA codes are typically 11 digits
                candidates.append({
                    'code': num,
                    'confidence': confidence,
                    'text': text,
                    'position': position,
                    'box': box,
                    'index': i,
                })
    return candidates


def _save_cma_visualization(image, candidate, output_dir):
    """Save a copy of the page with the chosen candidate's box drawn on it.

    The PNG is encoded in memory and written with ``ndarray.tofile`` so that
    output directories with non-ASCII characters work, mirroring how
    imread_unicode reads images. Failures are logged as warnings and never
    propagate: the visualization is a debug aid only.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)
        viz = image.copy()

        x1, y1, x2, y2 = [int(v) for v in candidate['box']]
        cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3)
        cv2.putText(viz, f"CMA: {candidate['code']}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

        viz_path = os.path.join(output_dir, "cma_detection_fullpage.png")
        encoded_ok, encoded = cv2.imencode(".png", viz)
        if not encoded_ok:
            logger.warning("Failed to encode visualization: %s", viz_path)
            return
        encoded.tofile(viz_path)
    except (OSError, TypeError, ValueError, cv2.error) as viz_err:
        logger.warning("Failed to save visualization: %s", viz_err)
        return
    logger.info("Saved visualization: %s", viz_path)


def extract_cma_code_fullpage(page_img, ocr_engine, output_dir=None):
    """
    Extract CMA code from a PDF page image using full-page OCR with position filtering.

    The page is OCR'd once, every 10-12 digit run in the recognized text
    becomes a candidate, and the candidate with the highest
    score_cma_candidate() value wins (the earliest one on ties).

    Both OCR result layouts are supported and are told apart by the shape
    of the returned data, not by which method exists on the engine:
    legacy ``[[ [box, (text, score)], ... ]]`` lists, and structured results
    exposing 'rec_texts' / 'rec_scores' / 'rec_boxes'. ``.ocr()`` is tried
    first when present; ``.predict()`` is used only if ``.ocr()`` raises or
    does not exist.

    Args:
        page_img: Page image as a numpy array, or a path (str or
            os.PathLike) to an image file
        ocr_engine: Initialized PaddleOCR instance
        output_dir: Optional directory to save debug visualizations

    Returns:
        Dict with keys:
        - 'code': Extracted CMA code (str or None)
        - 'confidence': OCR confidence (float)
        - 'raw_text': Raw OCR text containing the code (str)
        - 'position': (x, y) tuple of code position
        - 'box': Bounding box [x1, y1, x2, y2]
        - 'success': Boolean indicating successful extraction

        On any failure the dict is returned with its defaults
        ('code' None, 'success' False); errors are logged, not raised.
    """
    result = {
        'code': None,
        'confidence': 0.0,
        'raw_text': '',
        'position': (0, 0),
        'box': None,
        'success': False
    }

    # Load image if path provided
    if isinstance(page_img, (str, os.PathLike)):
        image = imread_unicode(page_img, cv2.IMREAD_COLOR)
    elif isinstance(page_img, np.ndarray):
        image = page_img
    else:
        logger.error("Invalid image type: %s", type(page_img))
        return result

    # ndim < 2 would make the (h, w) unpack below fail.
    if image is None or image.size == 0 or image.ndim < 2:
        logger.error("Failed to load image or empty image")
        return result

    h, w = image.shape[:2]

    logger.info("Running full-page OCR for CMA extraction...")
    try:
        if hasattr(ocr_engine, 'ocr'):
            # Call .ocr() without the cls parameter for better compatibility.
            try:
                raw_result = ocr_engine.ocr(image)
            except Exception as ocr_err:
                logger.warning(".ocr() method failed: %s, trying .predict()...", ocr_err)
                raw_result = None
                if hasattr(ocr_engine, 'predict'):
                    try:
                        raw_result = ocr_engine.predict(image)
                    except Exception as pred_err:
                        logger.error(".predict() also failed: %s", pred_err)
                        return result
        else:
            raw_result = ocr_engine.predict(image)

        logger.debug("OCR result type: %s", type(raw_result))
        logger.debug("OCR result: %s", raw_result)
        parsed = _collect_ocr_lines(raw_result)
    except Exception as e:
        # Boundary catch: OCR backends raise many unrelated exception types.
        logger.error("OCR failed: %s", e)
        return result

    if parsed is None:
        return result
    rec_texts, rec_scores, rec_boxes = parsed

    cma_candidates = _find_cma_candidates(rec_texts, rec_scores, rec_boxes)
    if not cma_candidates:
        logger.warning("No CMA code candidates found")
        return result

    # max() returns the first of equally scored candidates, which matches a
    # stable descending sort followed by taking element 0.
    best = max(cma_candidates, key=lambda c: score_cma_candidate(c, w, h))

    result['code'] = best['code']
    result['confidence'] = best['confidence']
    result['raw_text'] = best['text']
    result['position'] = best['position']
    result['box'] = best['box']
    result['success'] = True

    logger.info("Extracted CMA code: %s (confidence: %.4f)",
                best['code'], best['confidence'])

    # Visualize if output_dir provided
    if output_dir and best['box'] is not None:
        _save_cma_visualization(image, best, output_dir)

    return result
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line smoke test: run the extraction on one image and print the outcome.
    import sys

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python cma_extraction_final.py <image_path> [output_dir]")
        sys.exit(1)

    img_path = cli_args[0]
    out_dir = cli_args[1] if len(cli_args) > 1 else "cma_test_output"

    # Kept ahead of the paddleocr import, as in the original ordering
    # (presumably the variable is read at import time -- confirm).
    os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
    from paddleocr import PaddleOCR

    print("Initializing PaddleOCR...")
    ocr = PaddleOCR(use_angle_cls=True, lang='ch')

    result = extract_cma_code_fullpage(img_path, ocr, out_dir)

    rule = "=" * 60
    print("\n" + rule)
    print("CMA EXTRACTION RESULT")
    print(rule)
    succeeded = result['success']
    print(f"Success: {succeeded}")
    if succeeded:
        print(f"CMA Code: {result['code']}")
        print(f"Confidence: {result['confidence']:.4f}")
        print(f"Raw Text: '{result['raw_text']}'")
        print(f"Position: {result['position']}")
    print(rule)
|