From 17ef9cc83665b19d5e1d187c27dd109a0b8e91fc Mon Sep 17 00:00:00 2001
From: lroyia
Date: Wed, 25 Feb 2026 15:55:50 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20GLM-OCR=20?=
 =?UTF-8?q?=E8=AF=86=E5=88=AB=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- 新增 OCR 服务模块支持文字/表格/公式识别
- 添加 OCR API 路由(文件上传和 base64 方式)
- 更新配置以支持 GLM-OCR 模型
- 添加必要的依赖项(torch, transformers, accelerate)
---
 .env.example                |   8 +-
 requirements.txt            |   8 +-
 src/api/__init__.py         |   5 +-
 src/api/ocr.py              | 216 ++++++++++++++++++++++++++
 src/core/config.py          |   6 +-
 src/main.py                 |   9 +-
 src/services/__init__.py    |   5 +-
 src/services/ocr_service.py | 273 ++++++++++++++++++++++++++++++++
 8 files changed, 523 insertions(+), 7 deletions(-)
 create mode 100644 src/api/ocr.py
 create mode 100644 src/services/ocr_service.py

diff --git a/.env.example b/.env.example
index 3bbe1b3..c165d3d 100644
--- a/.env.example
+++ b/.env.example
@@ -11,4 +11,10 @@ PORT=8000
 DATABASE_URL=sqlite:///./app.db
 
 # Security
-SECRET_KEY=your-secret-key-here-change-in-production
\ No newline at end of file
+SECRET_KEY=your-secret-key-here-change-in-production
+
+# GLM-OCR Model
+# 模型路径: 可以是 Hugging Face 模型 ID (如 zai-org/GLM-OCR)
+# 或本地路径 (如 /path/to/local/model)
+MODEL_PATH=zai-org/GLM-OCR
+MODEL_DEVICE=auto
diff --git a/requirements.txt b/requirements.txt
index 5487565..48e4dd0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,4 +20,10 @@ python-multipart==0.0.12
 httpx==0.27.2
 
 # Utils
-orjson==3.10.12
\ No newline at end of file
+orjson==3.10.12
+
+# GLM-OCR Dependencies
+torch>=2.0.0
+transformers>=4.40.0
+Pillow>=10.0.0
+accelerate>=0.30.0
diff --git a/src/api/__init__.py b/src/api/__init__.py
index 699e93e..6464232 100644
--- a/src/api/__init__.py
+++ b/src/api/__init__.py
@@ -1 +1,4 @@
-# api package
\ No newline at end of file
+# api package
+from src.api.ocr import router as ocr_router
+
+__all__ = ["ocr_router"]
diff --git a/src/api/ocr.py b/src/api/ocr.py
new file mode 100644
--- /dev/null
+++ b/src/api/ocr.py
@@ -0,0 +1,216 @@
+"""OCR API endpoints."""
+
+import base64
+import binascii
+import logging
+from enum import Enum
+from typing import Optional
+
+from fastapi import APIRouter, File, Form, HTTPException, UploadFile
+from fastapi.concurrency import run_in_threadpool
+from pydantic import BaseModel, Field
+
+from src.services.ocr_service import get_ocr_service
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/ocr", tags=["OCR"])
+
+# NOTE: class and endpoint docstrings in this module are published through
+# the OpenAPI schema, so their original wording is intentionally kept.
+
+
+class OCRType(str, Enum):
+    """OCR 识别类型枚举。"""
+    text = "text"
+    table = "table"
+    formula = "formula"
+
+
+# Backward-compatible alias for the originally misspelled enum name.
+OCRTType = OCRType
+
+# Name of the OCRService method that handles each recognition type.
+_SERVICE_METHODS = {
+    OCRType.text: "recognize",
+    OCRType.table: "recognize_table",
+    OCRType.formula: "recognize_formula",
+}
+
+
+class OCRRequest(BaseModel):
+    """OCR 请求模型(用于 base64 图像)。"""
+    image_base64: str
+    ocr_type: OCRType = OCRType.text
+    # ge=1: generation needs at least one new token.
+    max_new_tokens: int = Field(default=4096, ge=1)
+
+
+class OCRResponse(BaseModel):
+    """OCR 响应模型。"""
+    success: bool
+    text: Optional[str] = None
+    error: Optional[str] = None
+
+
+def _decode_base64_image(payload: str) -> bytes:
+    """Decode a raw or ``data:...;base64,`` string into image bytes.
+
+    Decoding happens at the API boundary so that client-supplied strings
+    are never interpreted as server-side file paths by the service.
+
+    Raises:
+        HTTPException: 400 if the payload is empty or not valid base64.
+    """
+    if payload.startswith("data:"):
+        # Drop the "data:image/xxx;base64," header of a data URI.
+        payload = payload.partition(",")[2]
+    # Tolerate line-wrapped base64 before strict validation.
+    payload = "".join(payload.split())
+    try:
+        image_bytes = base64.b64decode(payload, validate=True)
+    except (binascii.Error, ValueError) as exc:
+        raise HTTPException(status_code=400, detail="无效的 base64 图像数据") from exc
+    if not image_bytes:
+        raise HTTPException(status_code=400, detail="图像数据为空")
+    return image_bytes
+
+
+async def _read_upload(file: UploadFile) -> bytes:
+    """Read an uploaded file, rejecting empty uploads with HTTP 400."""
+    image_bytes = await file.read()
+    if not image_bytes:
+        raise HTTPException(status_code=400, detail="上传的图像文件为空")
+    return image_bytes
+
+
+async def _run_ocr(
+    ocr_type: OCRType,
+    image_bytes: bytes,
+    max_new_tokens: int,
+    failure_label: str
+) -> OCRResponse:
+    """Run one recognition in a worker thread and wrap the result.
+
+    Model loading and generation are blocking calls, so they are kept off
+    the event loop.
+
+    Raises:
+        HTTPException: 500 if the service call itself raises.
+    """
+    def _call() -> dict:
+        service = get_ocr_service()
+        method = getattr(service, _SERVICE_METHODS[ocr_type])
+        return method(image_bytes, max_new_tokens=max_new_tokens)
+
+    try:
+        result = await run_in_threadpool(_call)
+    except Exception as e:
+        logger.exception("%s: %s", failure_label, e)
+        raise HTTPException(status_code=500, detail=str(e)) from e
+    return OCRResponse(**result)
+
+
+@router.post("/recognize", response_model=OCRResponse)
+async def recognize_text(
+    file: UploadFile = File(..., description="要识别的图像文件"),
+    max_new_tokens: int = Form(default=4096, ge=1, description="最大生成 token 数")
+):
+    """
+    识别图像中的文字内容。
+
+    支持的图片格式: PNG, JPG, JPEG, WEBP
+
+    Args:
+        file: 上传的图像文件
+        max_new_tokens: 最大生成 token 数
+
+    Returns:
+        OCRResponse: 包含识别结果
+    """
+    image_bytes = await _read_upload(file)
+    return await _run_ocr(OCRType.text, image_bytes, max_new_tokens, "文字识别失败")
+
+
+@router.post("/recognize/table", response_model=OCRResponse)
+async def recognize_table(
+    file: UploadFile = File(..., description="包含表格的图像文件"),
+    max_new_tokens: int = Form(default=4096, ge=1, description="最大生成 token 数")
+):
+    """
+    识别图像中的表格内容。
+
+    以 Markdown 表格格式输出识别结果。
+
+    Args:
+        file: 上传的图像文件
+        max_new_tokens: 最大生成 token 数
+
+    Returns:
+        OCRResponse: 包含表格识别结果
+    """
+    image_bytes = await _read_upload(file)
+    return await _run_ocr(OCRType.table, image_bytes, max_new_tokens, "表格识别失败")
+
+
+@router.post("/recognize/formula", response_model=OCRResponse)
+async def recognize_formula(
+    file: UploadFile = File(..., description="包含数学公式的图像文件"),
+    max_new_tokens: int = Form(default=2048, ge=1, description="最大生成 token 数")
+):
+    """
+    识别图像中的数学公式。
+
+    以 LaTeX 格式输出识别结果。
+
+    Args:
+        file: 上传的图像文件
+        max_new_tokens: 最大生成 token 数
+
+    Returns:
+        OCRResponse: 包含公式识别结果
+    """
+    image_bytes = await _read_upload(file)
+    return await _run_ocr(OCRType.formula, image_bytes, max_new_tokens, "公式识别失败")
+
+
+@router.post("/recognize/base64", response_model=OCRResponse)
+async def recognize_base64(request: OCRRequest):
+    """
+    通过 base64 编码的图像进行 OCR 识别。
+
+    支持纯 base64 字符串或 data:image/xxx;base64,... 格式。
+
+    Args:
+        request: OCRRequest 包含 base64 图像和识别类型
+
+    Returns:
+        OCRResponse: 包含识别结果
+    """
+    image_bytes = _decode_base64_image(request.image_base64)
+    return await _run_ocr(
+        request.ocr_type, image_bytes, request.max_new_tokens, "OCR 识别失败"
+    )
+
+
+@router.get("/status")
+async def get_status():
+    """
+    获取 OCR 服务状态。
+
+    Returns:
+        服务状态信息
+    """
+    try:
+        # The first call loads the model, so keep it off the event loop.
+        ocr_service = await run_in_threadpool(get_ocr_service)
+        return {
+            "status": "ready",
+            "model_loaded": ocr_service._model is not None,
+            "tokenizer_loaded": ocr_service._tokenizer is not None
+        }
+    except Exception as e:
+        return {
+            "status": "error",
+            "error": str(e)
+        }
diff --git a/src/core/config.py b/src/core/config.py
index 372d24d..3351254 100644
--- a/src/core/config.py
+++ b/src/core/config.py
@@ -13,9 +13,13 @@ class Settings(BaseSettings):
 
     secret_key: str = "your-secret-key-here"
 
+    # GLM-OCR Model Settings
+    model_path: str = "zai-org/GLM-OCR"  # Hugging Face 模型 ID 或本地路径
+    model_device: str = "auto"  # auto, cuda, cpu
+
     class Config:
         env_file = ".env"
         case_sensitive = False
 
 
-settings = Settings()
\ No newline at end of file
+settings = Settings()
diff --git a/src/main.py b/src/main.py
index 50af5f2..543ae1f 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,6 +1,7 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
+from src.api.ocr import router as ocr_router
 from src.core.config import settings
 
 app = FastAPI(
@@ -18,13 +19,17 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# 注册 OCR 路由
+app.include_router(ocr_router)
+
 
 @app.get("/")
 async def root():
     return {
         "app": settings.app_name,
         "version": settings.app_version,
-        "status": "running"
+        "status": "running",
+        "model": settings.model_path
     }
 
 
@@ -41,4 +46,4 @@ if __name__ == "__main__":
         host=settings.host,
         port=settings.port,
         reload=settings.debug
-    )
\ No newline at end of file
+    )
diff --git a/src/services/__init__.py b/src/services/__init__.py
index 7e29b49..2632b67 100644
--- a/src/services/__init__.py
+++ b/src/services/__init__.py
@@ -1 +1,4 @@
-# services package
\ No newline at end of file
+# services package
+from src.services.ocr_service import OCRService, get_ocr_service
+
+__all__ = ["OCRService", "get_ocr_service"]
diff --git a/src/services/ocr_service.py b/src/services/ocr_service.py
new file mode 100644
--- /dev/null
+++ b/src/services/ocr_service.py
@@ -0,0 +1,273 @@
+"""OCR service module for GLM-OCR model integration."""
+
+import base64
+import binascii
+import io
+import logging
+import os
+import threading
+from typing import Optional, Union
+
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from src.core.config import settings
+
+logger = logging.getLogger(__name__)
+
+# Instruction sent to the model for each recognition mode.
+_TEXT_PROMPT = "请识别图片中的所有文字内容,并按照原文排版格式输出。"
+_TABLE_PROMPT = "请识别图片中的表格内容,并以 Markdown 表格格式输出。"
+_FORMULA_PROMPT = "请识别图片中的数学公式,并以 LaTeX 格式输出。"
+
+
+class OCRService:
+    """Process-wide singleton that loads GLM-OCR and runs inference."""
+
+    _instance: Optional["OCRService"] = None
+    _model = None
+    _tokenizer = None
+    # Serializes generate() calls; requests arrive on worker threads and
+    # all share the single model instance.
+    _inference_lock = threading.Lock()
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        # __init__ runs on every OCRService() call; load the model only once.
+        if self._model is None:
+            self._load_model()
+
+    @staticmethod
+    def _use_cuda() -> bool:
+        """Return True if CUDA is available and MODEL_DEVICE is not "cpu"."""
+        return settings.model_device.lower() != "cpu" and torch.cuda.is_available()
+
+    def _load_model(self):
+        """Load the model and tokenizer named by ``settings.model_path``.
+
+        Raises:
+            Exception: any loading error is logged and re-raised.
+        """
+        logger.info(f"正在加载 GLM-OCR 模型: {settings.model_path}")
+        use_cuda = self._use_cuda()
+
+        try:
+            # SECURITY: trust_remote_code runs code shipped with the model
+            # repository; only point MODEL_PATH at trusted sources.
+            # NOTE(review): image inputs usually go through the model's
+            # processor rather than a bare tokenizer; confirm these Auto
+            # classes against the GLM-OCR model card.
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                settings.model_path,
+                trust_remote_code=True
+            )
+            self._model = AutoModelForCausalLM.from_pretrained(
+                settings.model_path,
+                torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
+                device_map="auto" if use_cuda else None,
+                trust_remote_code=True
+            )
+            # Inference only: switch to evaluation mode.
+            self._model.eval()
+            logger.info("GLM-OCR 模型加载完成")
+        except Exception as e:
+            # Reset both so a later call retries a clean load.
+            self._model = None
+            self._tokenizer = None
+            logger.exception("模型加载失败: %s", e)
+            raise
+
+    def _process_image(self, image_input: Union[str, bytes, Image.Image]) -> Image.Image:
+        """Convert the supported input forms into a PIL image.
+
+        Args:
+            image_input: a PIL image, encoded image bytes, a path to an
+                existing file, a ``data:image/...;base64,`` URI, or a raw
+                base64 string.
+
+        Returns:
+            The opened PIL image.
+
+        Raises:
+            ValueError: if the input type is unsupported or a string input
+                cannot be decoded into an image.
+        """
+        if isinstance(image_input, Image.Image):
+            return image_input
+
+        if isinstance(image_input, bytes):
+            return Image.open(io.BytesIO(image_input))
+
+        if isinstance(image_input, str):
+            if image_input.startswith("data:image"):
+                # Strip the "data:image/xxx;base64," header.
+                encoded = image_input.partition(",")[2]
+            elif os.path.isfile(image_input):
+                # Paths are detected by existence, not by prefix: raw JPEG
+                # base64 starts with "/9j/" and is not a path.
+                # SECURITY: never pass untrusted strings here; callers
+                # handling remote input should decode to bytes first.
+                return Image.open(image_input)
+            else:
+                encoded = image_input
+            try:
+                return Image.open(io.BytesIO(base64.b64decode(encoded)))
+            except (binascii.Error, ValueError, OSError) as exc:
+                raise ValueError("无法识别的图像输入格式") from exc
+
+        raise ValueError("不支持的图像输入类型")
+
+    def _run_prompt(
+        self,
+        image_input: Union[str, bytes, Image.Image],
+        prompt: str,
+        max_new_tokens: int,
+        temperature: float = 0.01,
+        top_p: float = 0.9,
+        failure_label: str = "OCR 识别失败"
+    ) -> dict:
+        """Run one image + prompt through the model.
+
+        Shared by the public recognize* methods, which differ only in the
+        prompt and defaults.
+
+        Returns:
+            ``{"success", "text", "error"}``; failures are logged and
+            reported in the dict rather than raised.
+        """
+        try:
+            image = self._process_image(image_input)
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": image},
+                        {"type": "text", "text": prompt}
+                    ]
+                }
+            ]
+
+            inputs = self._tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                return_tensors="pt",
+                return_dict=True,
+                add_generation_prompt=True
+            )
+
+            # Follow the model's device instead of assuming CUDA, so a
+            # CPU-only configuration works on a GPU host too.
+            device = self._model.device
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+
+            with self._inference_lock, torch.no_grad():
+                outputs = self._model.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=True,
+                    pad_token_id=self._tokenizer.eos_token_id
+                )
+
+            # Keep only the tokens generated after the prompt.
+            generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
+            result_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+            return {
+                "success": True,
+                "text": result_text,
+                "error": None
+            }
+
+        except Exception as e:
+            logger.exception("%s: %s", failure_label, e)
+            return {
+                "success": False,
+                "text": None,
+                "error": str(e)
+            }
+
+    def recognize(
+        self,
+        image_input: Union[str, bytes, Image.Image],
+        max_new_tokens: int = 4096,
+        temperature: float = 0.01,
+        top_p: float = 0.9
+    ) -> dict:
+        """Recognize all text in the image, preserving its layout.
+
+        Args:
+            image_input: image in any form accepted by ``_process_image``.
+            max_new_tokens: upper bound on generated tokens.
+            temperature: sampling temperature.
+            top_p: nucleus-sampling threshold.
+
+        Returns:
+            Dict with ``success``, ``text`` and ``error`` keys.
+        """
+        return self._run_prompt(
+            image_input, _TEXT_PROMPT, max_new_tokens, temperature, top_p,
+            failure_label="OCR 识别失败"
+        )
+
+    def recognize_table(
+        self,
+        image_input: Union[str, bytes, Image.Image],
+        max_new_tokens: int = 4096
+    ) -> dict:
+        """Recognize a table and return it as a Markdown table.
+
+        Args:
+            image_input: image in any form accepted by ``_process_image``.
+            max_new_tokens: upper bound on generated tokens.
+
+        Returns:
+            Dict with ``success``, ``text`` and ``error`` keys.
+        """
+        return self._run_prompt(
+            image_input, _TABLE_PROMPT, max_new_tokens,
+            failure_label="表格识别失败"
+        )
+
+    def recognize_formula(
+        self,
+        image_input: Union[str, bytes, Image.Image],
+        max_new_tokens: int = 2048
+    ) -> dict:
+        """Recognize a math formula and return it as LaTeX.
+
+        Args:
+            image_input: image in any form accepted by ``_process_image``.
+            max_new_tokens: upper bound on generated tokens.
+
+        Returns:
+            Dict with ``success``, ``text`` and ``error`` keys.
+        """
+        return self._run_prompt(
+            image_input, _FORMULA_PROMPT, max_new_tokens,
+            failure_label="公式识别失败"
+        )
+
+
+# Module-level singleton, created lazily by get_ocr_service().
+ocr_service: Optional[OCRService] = None
+# Guards creation: concurrent first requests must not load the model twice.
+_service_lock = threading.Lock()
+
+
+def get_ocr_service() -> OCRService:
+    """Return the shared OCRService, loading the model on first use."""
+    global ocr_service
+    with _service_lock:
+        if ocr_service is None:
+            ocr_service = OCRService()
+    return ocr_service