report-detect/scripts/layout_viz.py

96 lines
3.1 KiB
Python
Raw Permalink Normal View History

2026-02-05 13:57:22 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
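Layout Detection with Visualization

Runs PaddleOCR layout detection (PP-DocLayout_plus-L) on an image, saves a
copy annotated with a labelled bounding box for every detected region, and
prints the detections to stdout as JSON.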
"""
import sys
import json
import os
from PIL import Image, ImageDraw, ImageFont
# Outline / label-background colour for each layout class; labels not listed
# here are drawn with DEFAULT_COLOR.
COLORS = {
    "seal": (255, 0, 255),  # Magenta for seal
    "text": (0, 128, 0),  # Green for text
    "paragraph_title": (255, 165, 0),  # Orange for paragraph title
    "doc_title": (255, 0, 0),  # Red for doc title
    "table": (0, 0, 255),  # Blue for table
    "image": (128, 0, 128),  # Purple for image
    "formula": (0, 255, 255),  # Cyan for formula
    "chart": (255, 255, 0),  # Yellow for chart
    "header": (128, 128, 128),  # Gray for header
    "footer": (128, 128, 128),  # Gray for footer
}
DEFAULT_COLOR = (100, 100, 100)

LABEL_OFFSET = 25  # pixels above the box's top edge where the label tag starts
BOX_WIDTH = 3  # bounding-box outline width in pixels
FONT_SIZE = 20


def _load_font(size=FONT_SIZE):
    """Return Arial at *size* points, or Pillow's built-in font if Arial can't be loaded."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # truetype() raises OSError when the font file is missing/unreadable
        # (common on Linux); anything else should propagate, not be swallowed.
        return ImageFont.load_default()


def _label_text_color(background):
    """Return black or white, whichever reads better on the RGB *background*."""
    red, green, blue = background[:3]
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue
    # Light backgrounds (yellow, cyan, orange) need dark text to stay legible.
    return (0, 0, 0) if luminance > 150 else (255, 255, 255)


def _draw_detection(draw, font, label, score, box):
    """Draw one bounding box *box* = (x1, y1, x2, y2) plus a filled "label: score" tag."""
    x1, y1, x2, y2 = box
    color = COLORS.get(label, DEFAULT_COLOR)
    draw.rectangle([x1, y1, x2, y2], outline=color, width=BOX_WIDTH)
    # The tag normally sits above the box; for boxes touching the top edge
    # y1 - LABEL_OFFSET is negative and the tag would be clipped, so clamp it.
    text_origin = (x1, max(y1 - LABEL_OFFSET, 0))
    label_text = f"{label}: {score:.2f}"
    draw.rectangle(draw.textbbox(text_origin, label_text, font=font), fill=color)
    draw.text(text_origin, label_text, fill=_label_text_color(color), font=font)


def main():
    """Detect layout regions in an image and write an annotated copy.

    Usage: ``python layout_viz.py <image_path> <output_path>``

    Prints exactly one JSON object to stdout:

    * success: ``{"success": true, "output_path": ..., "detections": [...]}``,
      each detection having ``label``, ``score``, ``x1``, ``y1``, ``x2``, ``y2``;
    * failure: ``{"error": ...}`` (plus ``"traceback"`` for runtime errors),
      after which the process exits with status 1.
    """
    if len(sys.argv) < 3:
        print(json.dumps({"error": "Usage: python layout_viz.py <image_path> <output_path>"}))
        sys.exit(1)
    image_path = sys.argv[1]
    output_path = sys.argv[2]

    # Fail fast on a bad path before paying for the model import/initialisation.
    if not os.path.isfile(image_path):
        print(json.dumps({"error": f"Image not found: {image_path}"}))
        sys.exit(1)

    try:
        from paddleocr import LayoutDetection

        model = LayoutDetection(model_name="PP-DocLayout_plus-L")
        output = model.predict(image_path, batch_size=1, layout_nms=True)

        font = _load_font()
        results = []
        with Image.open(image_path) as src:
            # Grayscale / palette / CMYK images can't take RGB colour tuples,
            # so draw on an RGB copy; RGB and RGBA images are used as-is.
            canvas = src if src.mode in ("RGB", "RGBA") else src.convert("RGB")
            draw = ImageDraw.Draw(canvas)
            for res in output:
                for det in res["boxes"]:
                    # Cast once so drawing and json.dumps both get plain
                    # floats (model outputs may be non-builtin numeric types).
                    x1, y1, x2, y2 = (float(v) for v in det["coordinate"])
                    label = det["label"]
                    score = float(det["score"])
                    results.append({
                        "label": label,
                        "score": score,
                        "x1": x1,
                        "y1": y1,
                        "x2": x2,
                        "y2": y2,
                    })
                    _draw_detection(draw, font, label, score, (x1, y1, x2, y2))
            canvas.save(output_path)

        print(json.dumps({
            "success": True,
            "output_path": output_path,
            "detections": results,
        }))
    except Exception as e:
        # Top-level CLI boundary: report every failure as JSON for the caller.
        import traceback
        print(json.dumps({"error": str(e), "traceback": traceback.format_exc()}))
        sys.exit(1)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()