diff --git a/cma_extraction_final.py b/cma_extraction_final.py new file mode 100644 index 0000000..0500163 --- /dev/null +++ b/cma_extraction_final.py @@ -0,0 +1,415 @@ +""" +CMA Code Extraction Module using Full-Page OCR with Position Filtering + +This module provides a robust method for extracting CMA certification codes +from PDF report pages using PaddleOCR with position-based filtering. + +Approach: +1. Run full-page OCR to get all text with positions +2. Filter for text in top-right area (where CMA logo/code is typically located) +3. Use regex patterns to find CMA code (11 digits starting with '2') +4. Score candidates by: position, confidence, code format + +Author: Based on reference implementation from refer/认监-扫描件识别 +Date: 2025-02-05 +""" +import os +import re +import cv2 +import numpy as np +import json +import logging + +logger = logging.getLogger(__name__) + +# CMA code patterns +PATTERN_PRIMARY = r'2[0-9]{10,11}' # 11-12 digits starting with 2 +PATTERN_FALLBACK = r'[0-9]{11,12}' # any 11-12 digits + + +def imread_unicode(path, flags=cv2.IMREAD_COLOR): + """ + cv2.imread replacement that supports paths with non-ASCII characters. + + Args: + path: Image file path (may contain Chinese characters) + flags: cv2.IMREAD_* flags + + Returns: + Image as numpy array or None if failed + """ + try: + data = np.fromfile(str(path), dtype=np.uint8) + img = cv2.imdecode(data, flags) + return img + except Exception as e: + logger.error(f"Failed to read image {path}: {e}") + return None + + +def score_cma_candidate(candidate, page_width, page_height): + """ + Score a CMA code candidate based on multiple factors. + + Scoring factors: + - Code format: 11 digits starting with '2' gets highest score + - Position: Top-right area gets bonus + - Confidence: Higher OCR confidence increases score + - Length: Exact 11 digits gets bonus + + Args: + candidate: Dict with 'code', 'confidence', 'position', 'text' + page_width: Page width in pixels + page_height: Page height in pixels + + Returns: + Score (higher is better) + """ + score = 0 + code = candidate['code'] + confidence = candidate['confidence'] + pos_x, pos_y = candidate['position'] + + # Format score: 11 digits starting with '2' is perfect + if len(code) == 11 and code.startswith('2'): + score += 100 + elif code.startswith('2'): + score += 50 + + # Length bonus (11-12 digits are standard) + if len(code) in (11, 12): + score += 20 + elif 10 <= len(code) <= 13: + score += 10 + + # Position bonus: prefer top-right area (typical CMA logo location) + if pos_x > page_width * 0.5 and pos_y < page_height * 0.35: + score += 30 + + # Confidence bonus (0-10) + score += confidence * 10 + + return score + + +def extract_cma_code_fullpage(page_img, ocr_engine, output_dir=None): + """ + Extract CMA code from a PDF page image using full-page OCR with position filtering. + + This is the recommended method for CMA extraction as it's more robust than + template matching and works even when the CMA logo is degraded. + + Args: + page_img: Page image (numpy array or path to image) + ocr_engine: Initialized PaddleOCR instance + output_dir: Optional directory to save debug visualizations + + Returns: + Dict with keys: + - 'code': Extracted CMA code (str or None) + - 'confidence': OCR confidence (float) + - 'raw_text': Raw OCR text containing the code (str) + - 'position': (x, y) tuple of code position + - 'box': Bounding box [x1, y1, x2, y2] + - 'success': Boolean indicating successful extraction + """ + result = { + 'code': None, + 'confidence': 0.0, + 'raw_text': '', + 'position': (0, 0), + 'box': None, + 'success': False + } + + # Load image if path provided + if isinstance(page_img, str): + image = imread_unicode(page_img, cv2.IMREAD_COLOR) + elif isinstance(page_img, np.ndarray): + image = page_img + else: + logger.error(f"Invalid image type: {type(page_img)}") + return result + + if image is None or image.size == 0: + logger.error("Failed to load image or empty image") + return result + + h, w = image.shape[:2] + + # Run OCR + logger.info("Running full-page OCR for CMA extraction...") + try: + # Check for legacy PaddleOCR + if hasattr(ocr_engine, 'ocr'): + # Legacy PaddleOCR.ocr returns [ [ [box, (text, score)], ... ] ] + # Try simple .ocr() call first (without cls parameter for better compatibility) + try: + raw_result = ocr_engine.ocr(image) + except Exception as ocr_err: + logger.warning(f".ocr() method failed: {ocr_err}, trying .predict()...") + raw_result = None + + # Fallback to .predict() if .ocr() failed + if hasattr(ocr_engine, 'predict'): + try: + raw_result = ocr_engine.predict(image) + except Exception as pred_err: + logger.error(f".predict() also failed: {pred_err}") + return result + + if raw_result is None: + logger.error("OCR returned None") + return result + + # Initialize lists + rec_texts = [] + rec_scores = [] + rec_boxes = [] + + # Validate raw_result structure + if not isinstance(raw_result, list): + logger.error(f"OCR returned unexpected type: {type(raw_result)}") + return result + + if len(raw_result) == 0: + logger.warning("OCR returned empty list") + return result + + if raw_result[0] is None: + logger.warning("OCR result[0] is None") + return result + + if not isinstance(raw_result[0], list): + logger.error(f"OCR result[0] is not a list: {type(raw_result[0])}") + return result + + for line_idx, line in enumerate(raw_result[0]): + # line: [box, (text, score)] or [box, text] + try: + if not isinstance(line, (list, tuple)) or len(line) < 2: + logger.debug(f"Skipping line {line_idx}: invalid format") + continue + + box = line[0] + + # Validate box before processing + if box is None: + logger.debug(f"Skipping line {line_idx}: box is None") + continue + + # Extract text and score + if isinstance(line[1], (list, tuple)) and len(line[1]) >= 2: + text, score = line[1] + elif isinstance(line[1], (list, tuple)) and len(line[1]) == 1: + text = line[1][0] + score = 0.99 + elif isinstance(line[1], str): + text = line[1] + score = 0.99 + else: + text = str(line[1]) + score = 0.99 + + rec_texts.append(text) + rec_scores.append(score) + + # Convert box to [x1, y1, x2, y2] with proper validation + # box is [[x,y], [x,y], [x,y], [x,y]] + try: + if isinstance(box, (list, tuple)) and len(box) >= 4: + # Check if box is in [[x,y], [x,y], ...] format + if all(isinstance(pt, (list, tuple)) and len(pt) >= 2 for pt in box[:4]): + xs = [pt[0] for pt in box[:4]] + ys = [pt[1] for pt in box[:4]] + rec_boxes.append([min(xs), min(ys), max(xs), max(ys)]) + else: + # Box might be in [x1, y1, x2, y2, ...] format + rec_boxes.append([box[0], box[1], box[2], box[3]]) + else: + logger.warning(f"Line {line_idx}: Invalid box format: {type(box)}, len={len(box) if hasattr(box, '__len__') else 'N/A'}") + rec_boxes.append(None) # Placeholder to maintain alignment + except Exception as box_err: + logger.warning(f"Line {line_idx}: Failed to parse box: {box_err}") + rec_boxes.append(None) # Placeholder to maintain alignment + + except (IndexError, TypeError, ValueError) as e: + logger.warning(f"Skipped malformed OCR line {line_idx}: {e}") + continue + + logger.info(f"Found {len(rec_texts)} text lines (Legacy API)") + + else: + # Assume Paddlex or similar API (PaddleOCR 2.7+) + ocr_result = ocr_engine.predict(image) + + # Debug: Log result type + logger.debug(f"OCR result type: {type(ocr_result)}") + logger.debug(f"OCR result: {ocr_result}") + + # Handle different return types + if ocr_result is None: + logger.error("OCR returned None") + return result + + # Check if result is a list (old API) or single object (new API) + if isinstance(ocr_result, list): + if len(ocr_result) == 0: + logger.error("OCR returned empty list") + return result + ocr_data = ocr_result[0] + else: + # New API: result is already an OCRResult object + ocr_data = ocr_result + + # Extract data from OCRResult object + # PaddleOCR 3.4+ uses nested OCRResult structure + # The object may behave differently on subsequent calls, so use dict-style access + + # Try to get data as dictionary (most reliable method) + try: + if hasattr(ocr_data, 'keys'): + # Can use dict-like access + rec_texts = ocr_data.get('rec_texts', []) + rec_scores = ocr_data.get('rec_scores', []) + rec_boxes = ocr_data.get('rec_boxes', []) + elif hasattr(ocr_data, '__getitem__'): + # Try to access as dict + rec_texts = ocr_data['rec_texts'] if 'rec_texts' in ocr_data else [] + rec_scores = ocr_data['rec_scores'] if 'rec_scores' in ocr_data else [] + rec_boxes = ocr_data['rec_boxes'] if 'rec_boxes' in ocr_data else [] + else: + # Fallback: try attribute access + rec_texts = getattr(ocr_data, 'rec_texts', []) + rec_scores = getattr(ocr_data, 'rec_scores', []) + rec_boxes = getattr(ocr_data, 'rec_boxes', []) + except Exception as parse_error: + logger.error(f"Failed to extract OCR data: {parse_error}") + logger.error(f"OCRResult type: {type(ocr_data)}") + logger.error(f"OCRResult dir: {[a for a in dir(ocr_data) if not a.startswith('_')]}") + return result + + logger.debug(f"Extracted rec_texts: {rec_texts}") + logger.debug(f"Extracted rec_scores: {rec_scores}") + logger.debug(f"Extracted rec_boxes type: {type(rec_boxes)}") + + logger.info(f"Found {len(rec_texts)} text lines") + + except Exception as e: + logger.error(f"OCR failed: {e}") + return result + + # Find CMA code candidates + cma_candidates = [] + + # Debug: Log the data we got + logger.debug(f"rec_texts type: {type(rec_texts)}, length: {len(rec_texts) if hasattr(rec_texts, '__len__') else 'N/A'}") + logger.debug(f"rec_scores type: {type(rec_scores)}, length: {len(rec_scores) if hasattr(rec_scores, '__len__') else 'N/A'}") + logger.debug(f"rec_boxes type: {type(rec_boxes)}, length: {len(rec_boxes) if hasattr(rec_boxes, '__len__') else 'N/A'}") + + for i, text in enumerate(rec_texts): + # Ensure text is a string + if not isinstance(text, str): + logger.warning(f"Skip non-string text at index {i}: type={type(text)}, value={text}") + continue + + if not text or len(text.strip()) == 0: + logger.debug(f"Skip empty text at index {i}") + continue + + confidence = rec_scores[i] if i < len(rec_scores) else 0.5 + box = rec_boxes[i] if i < len(rec_boxes) else None + + # Calculate center position (box format: [x1, y1, x2, y2]) + if box is not None and len(box) == 4: + x1, y1, x2, y2 = box + center_x = (x1 + x2) / 2 + center_y = (y1 + y2) / 2 + else: + center_x = center_y = 0 + + # Extract numbers from text (with error handling) + try: + cleaned_text = text.replace(" ", "").replace("-", "") + numbers = re.findall(r'[0-9]+', cleaned_text) + except Exception as e: + logger.error(f"Error processing text at index {i}: {e}, text='{text}', type={type(text)}") + continue + + for num in numbers: + if 10 <= len(num) <= 12: # CMA codes are typically 11 digits + cma_candidates.append({ + 'code': num, + 'confidence': confidence, + 'text': text, + 'position': (center_x, center_y), + 'box': box, + 'index': i + }) + + if not cma_candidates: + logger.warning("No CMA code candidates found") + return result + + # Score and sort candidates + cma_candidates.sort(key=lambda c: score_cma_candidate(c, w, h), reverse=True) + best = cma_candidates[0] + + # Update result + result['code'] = best['code'] + result['confidence'] = best['confidence'] + result['raw_text'] = best['text'] + result['position'] = best['position'] + result['box'] = best['box'] + result['success'] = True + + logger.info(f"Extracted CMA code: {best['code']} (confidence: {best['confidence']:.4f})") + + # Visualize if output_dir provided + if output_dir and best['box'] is not None: + os.makedirs(output_dir, exist_ok=True) + viz = image.copy() + + x1, y1, x2, y2 = [int(v) for v in best['box']] + cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) + cv2.putText(viz, f"CMA: {best['code']}", (x1, y1 - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2) + + viz_path = os.path.join(output_dir, "cma_detection_fullpage.png") + cv2.imwrite(viz_path, viz) + logger.info(f"Saved visualization: {viz_path}") + + return result + + +if __name__ == "__main__": + # Test the CMA extraction + import sys + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') + + if len(sys.argv) < 2: + print("Usage: python cma_extraction_final.py [output_dir]") + sys.exit(1) + + img_path = sys.argv[1] + out_dir = sys.argv[2] if len(sys.argv) > 2 else "cma_test_output" + + os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True" + from paddleocr import PaddleOCR + + print("Initializing PaddleOCR...") + ocr = PaddleOCR(use_angle_cls=True, lang='ch') + + result = extract_cma_code_fullpage(img_path, ocr, out_dir) + + print("\n" + "=" * 60) + print("CMA EXTRACTION RESULT") + print("=" * 60) + print(f"Success: {result['success']}") + if result['success']: + print(f"CMA Code: {result['code']}") + print(f"Confidence: {result['confidence']:.4f}") + print(f"Raw Text: '{result['raw_text']}'") + print(f"Position: {result['position']}") + print("=" * 60)