feat: 添加 GLM-OCR 识别功能

- 新增 OCR 服务模块支持文字/表格/公式识别
- 添加 OCR API 路由(文件上传和 base64 方式)
- 更新配置以支持 GLM-OCR 模型
- 添加必要的依赖项(torch, transformers, accelerate)
This commit is contained in:
黎润豪 2026-02-25 15:55:50 +08:00
parent 52db191cc6
commit 17ef9cc836
8 changed files with 550 additions and 7 deletions

View File

@ -11,4 +11,10 @@ PORT=8000
DATABASE_URL=sqlite:///./app.db
# Security
SECRET_KEY=your-secret-key-here-change-in-production
SECRET_KEY=your-secret-key-here-change-in-production
# GLM-OCR Model
# 模型路径: 可以是 Hugging Face 模型 ID (如 zai-org/GLM-OCR)
# 或本地路径 (如 /path/to/local/model)
MODEL_PATH=zai-org/GLM-OCR
MODEL_DEVICE=auto

View File

@ -20,4 +20,10 @@ python-multipart==0.0.12
httpx==0.27.2
# Utils
orjson==3.10.12
orjson==3.10.12
# GLM-OCR Dependencies
torch>=2.0.0
transformers>=4.40.0
Pillow>=10.0.0
accelerate>=0.30.0

View File

@ -1 +1,4 @@
# api package
# api package
from src.api.ocr import router as ocr_router
__all__ = ["ocr_router"]

191
src/api/ocr.py Normal file
View File

@ -0,0 +1,191 @@
"""OCR API endpoints."""
import base64
import logging
from enum import Enum
from io import BytesIO
from typing import Optional
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from src.services.ocr_service import get_ocr_service
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Every route declared in this module is mounted under the "/ocr" prefix
# and tagged "OCR".
router = APIRouter(prefix="/ocr", tags=["OCR"])
class OCRTType(str, Enum):
    """Enumeration of the OCR recognition types a request can select.

    Members are also ``str`` instances, so their values are the plain
    strings "text", "table" and "formula".
    """

    text = "text"  # default type used by OCRRequest
    table = "table"
    formula = "formula"
class OCRRequest(BaseModel):
    """Request body for OCR on a base64-encoded image."""

    # Base64 image payload supplied by the client.
    image_base64: str
    # Which kind of recognition to run; defaults to plain text.
    ocr_type: OCRTType = OCRTType.text
    # Upper limit on generated tokens.
    # NOTE(review): no lower/upper bound is declared here — confirm whether
    # client-supplied values should be range-checked.
    max_new_tokens: int = 4096
class OCRResponse(BaseModel):
    """Response body returned by the OCR recognition endpoints."""

    # Whether recognition succeeded.
    success: bool
    # Recognized output; defaults to None.
    text: Optional[str] = None
    # Error description; defaults to None.
    error: Optional[str] = None
@router.post("/recognize", response_model=OCRResponse)
async def recognize_text(
    file: UploadFile = File(..., description="要识别的图像文件"),
    max_new_tokens: int = Form(default=4096, description="最大生成 token 数"),
):
    """Recognize the text content of an uploaded image.

    Supported image formats (as stated by the original endpoint docs):
    PNG, JPG, JPEG, WEBP.

    Args:
        file: The uploaded image file.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        OCRResponse built from the dict returned by the OCR service.

    Raises:
        HTTPException: 400 when the upload is empty; 500 when obtaining the
            service or building the response fails unexpectedly.
    """
    try:
        image_bytes = await file.read()
        # Reject an empty upload up front rather than handing zero bytes to
        # the model pipeline.
        if not image_bytes:
            raise HTTPException(status_code=400, detail="上传的图像文件为空")

        ocr_service = get_ocr_service()
        # NOTE(review): this call is synchronous inside a coroutine, so the
        # event loop is blocked for the duration of inference — consider
        # offloading once the service's thread-safety is confirmed.
        result = ocr_service.recognize(image_bytes, max_new_tokens=max_new_tokens)
        return OCRResponse(**result)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must not be rewritten
        # into a 500 by the generic handler below.
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("文字识别失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/recognize/table", response_model=OCRResponse)
async def recognize_table(
    file: UploadFile = File(..., description="包含表格的图像文件"),
    max_new_tokens: int = Form(default=4096, description="最大生成 token 数"),
):
    """Recognize table content in an uploaded image.

    The recognition result is produced by the OCR service's table method.

    Args:
        file: The uploaded image file.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        OCRResponse built from the dict returned by the OCR service.

    Raises:
        HTTPException: 400 when the upload is empty; 500 when obtaining the
            service or building the response fails unexpectedly.
    """
    try:
        image_bytes = await file.read()
        # Reject an empty upload up front rather than handing zero bytes to
        # the model pipeline.
        if not image_bytes:
            raise HTTPException(status_code=400, detail="上传的图像文件为空")

        ocr_service = get_ocr_service()
        # NOTE(review): synchronous inference inside a coroutine blocks the
        # event loop — consider offloading once thread-safety is confirmed.
        result = ocr_service.recognize_table(image_bytes, max_new_tokens=max_new_tokens)
        return OCRResponse(**result)
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of converting to 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("表格识别失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/recognize/formula", response_model=OCRResponse)
async def recognize_formula(
    file: UploadFile = File(..., description="包含数学公式的图像文件"),
    max_new_tokens: int = Form(default=2048, description="最大生成 token 数"),
):
    """Recognize mathematical formulas in an uploaded image.

    The recognition result is produced by the OCR service's formula method.

    Args:
        file: The uploaded image file.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        OCRResponse built from the dict returned by the OCR service.

    Raises:
        HTTPException: 400 when the upload is empty; 500 when obtaining the
            service or building the response fails unexpectedly.
    """
    try:
        image_bytes = await file.read()
        # Reject an empty upload up front rather than handing zero bytes to
        # the model pipeline.
        if not image_bytes:
            raise HTTPException(status_code=400, detail="上传的图像文件为空")

        ocr_service = get_ocr_service()
        # NOTE(review): synchronous inference inside a coroutine blocks the
        # event loop — consider offloading once thread-safety is confirmed.
        result = ocr_service.recognize_formula(image_bytes, max_new_tokens=max_new_tokens)
        return OCRResponse(**result)
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of converting to 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("公式识别失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/recognize/base64", response_model=OCRResponse)
async def recognize_base64(request: OCRRequest):
    """Run OCR on a base64-encoded image.

    Accepts either a bare base64 string or a ``data:image/xxx;base64,...``
    data URI. The payload is decoded to raw bytes here, before it reaches the
    OCR service: the service interprets strings beginning with "/" or "." as
    local file paths, which must never be reachable from client input (and
    which also misroutes base64-encoded JPEGs, whose text starts with "/9j/").

    Args:
        request: OCRRequest carrying the base64 image, the recognition type
            and the generation limit.

    Returns:
        OCRResponse built from the dict returned by the OCR service.

    Raises:
        HTTPException: 400 when the payload is not decodable or is empty;
            500 when obtaining the service or building the response fails.
    """
    try:
        payload = request.image_base64.strip()
        if payload.startswith("data:"):
            # Drop the "data:<mime>;base64," header of a data URI.
            _, separator, payload = payload.partition(",")
            if not separator:
                raise HTTPException(status_code=400, detail="无效的 data URI 图像数据")

        try:
            image_bytes = base64.b64decode(payload)
        except ValueError as decode_error:
            # binascii.Error (bad padding) is a ValueError subclass.
            raise HTTPException(
                status_code=400, detail="无效的 base64 图像数据"
            ) from decode_error

        if not image_bytes:
            raise HTTPException(status_code=400, detail="图像数据为空")

        ocr_service = get_ocr_service()
        # Select the service method for the requested type; anything that is
        # not table/formula falls back to plain text recognition.
        handlers = {
            OCRTType.table: ocr_service.recognize_table,
            OCRTType.formula: ocr_service.recognize_formula,
        }
        recognize = handlers.get(request.ocr_type, ocr_service.recognize)
        # NOTE(review): synchronous inference inside a coroutine blocks the
        # event loop — consider offloading once thread-safety is confirmed.
        result = recognize(image_bytes, max_new_tokens=request.max_new_tokens)
        return OCRResponse(**result)
    except HTTPException:
        # Keep the deliberate 400 responses above intact.
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("OCR 识别失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/status")
async def get_status():
    """Report the state of the OCR service.

    Obtaining the service through ``get_ocr_service()`` creates it when it
    does not exist yet, so a failure during that step is reported as an
    "error" status rather than raised.

    Returns:
        A dict with ``status`` "ready" plus the ``model_loaded`` and
        ``tokenizer_loaded`` flags, or ``status`` "error" plus the ``error``
        message when obtaining the service fails.
    """
    try:
        service = get_ocr_service()
        has_model = service._model is not None
        has_tokenizer = service._tokenizer is not None
    except Exception as failure:
        return {"status": "error", "error": str(failure)}

    return {
        "status": "ready",
        "model_loaded": has_model,
        "tokenizer_loaded": has_tokenizer,
    }

View File

@ -13,9 +13,13 @@ class Settings(BaseSettings):
secret_key: str = "your-secret-key-here"
# GLM-OCR Model Settings
model_path: str = "zai-org/GLM-OCR" # Hugging Face 模型 ID 或本地路径
model_device: str = "auto" # auto, cuda, cpu
class Config:
env_file = ".env"
case_sensitive = False
settings = Settings()
settings = Settings()

View File

@ -1,6 +1,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from src.api.ocr import router as ocr_router
from src.core.config import settings
app = FastAPI(
@ -18,13 +19,17 @@ app.add_middleware(
allow_headers=["*"],
)
# 注册 OCR 路由
app.include_router(ocr_router)
@app.get("/")
async def root():
return {
"app": settings.app_name,
"version": settings.app_version,
"status": "running"
"status": "running",
"model": settings.model_path
}
@ -41,4 +46,4 @@ if __name__ == "__main__":
host=settings.host,
port=settings.port,
reload=settings.debug
)
)

View File

@ -1 +1,4 @@
# services package
# services package
from src.services.ocr_service import OCRService, get_ocr_service
__all__ = ["OCRService", "get_ocr_service"]

325
src/services/ocr_service.py Normal file
View File

@ -0,0 +1,325 @@
"""OCR service module for GLM-OCR model integration."""
import base64
import io
import logging
from typing import Optional, Union
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.core.config import settings
logger = logging.getLogger(__name__)
class OCRService:
    """Singleton service that loads the GLM-OCR model and runs inference.

    The first instantiation loads the tokenizer and model named by
    ``settings.model_path``; later instantiations return the same object and
    do not reload. The public ``recognize*`` methods differ only in the prompt
    they send to the model, and each returns a dict with the keys
    ``success``, ``text`` and ``error`` instead of raising.
    """

    _instance: Optional["OCRService"] = None
    _model = None
    _tokenizer = None

    # Prompts sent to the model, one per recognition type.
    _TEXT_PROMPT = "请识别图片中的所有文字内容,并按照原文排版格式输出。"
    _TABLE_PROMPT = "请识别图片中的表格内容,并以 Markdown 表格格式输出。"
    _FORMULA_PROMPT = "请识别图片中的数学公式,并以 LaTeX 格式输出。"

    def __new__(cls):
        # Hand out one shared instance for the whole process.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every OCRService() call; only load when no model
        # is present (first call, or a previous load failed).
        if self._model is None:
            self._load_model()

    @staticmethod
    def _use_cuda() -> bool:
        """Decide whether to place the model on CUDA.

        Honors ``settings.model_device``: "cpu" forces CPU; "cuda" and "auto"
        (and any other value) use CUDA when it is available and fall back to
        CPU otherwise.
        """
        requested = str(getattr(settings, "model_device", "auto")).strip().lower()
        if requested == "cpu":
            return False
        available = torch.cuda.is_available()
        if requested == "cuda" and not available:
            logger.warning("MODEL_DEVICE=cuda 但 CUDA 不可用,回退到 CPU")
        return available

    def _load_model(self):
        """Load the GLM-OCR tokenizer and model and switch to eval mode.

        Raises:
            Exception: Whatever the underlying loaders raise; it is logged
                with its traceback and re-raised unchanged.
        """
        logger.info("正在加载 GLM-OCR 模型: %s", settings.model_path)
        use_cuda = self._use_cuda()
        try:
            # NOTE: trust_remote_code=True allows code shipped with the
            # checkpoint to run; model_path must point at a trusted source.
            self._tokenizer = AutoTokenizer.from_pretrained(
                settings.model_path,
                trust_remote_code=True,
            )
            self._model = AutoModelForCausalLM.from_pretrained(
                settings.model_path,
                torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
            )
            # Inference only.
            self._model.eval()
            logger.info("GLM-OCR 模型加载完成")
        except Exception as exc:
            logger.exception("模型加载失败: %s", exc)
            raise

    def _process_image(self, image_input: Union[str, bytes, Image.Image]) -> Image.Image:
        """Convert the supported input forms into a PIL image.

        Args:
            image_input: A PIL image (returned as is), raw image bytes, or a
                string holding a ``data:image...`` URI, a file path starting
                with "/" or ".", or bare base64 text.

        Returns:
            The opened PIL image.

        Raises:
            ValueError: If a string input can be read neither as a file nor
                as base64 image data, or the input type is unsupported.

        NOTE: strings starting with "/" or "." are tried as local file paths.
        Callers handling untrusted input should decode it to bytes first.
        """
        if isinstance(image_input, Image.Image):
            return image_input

        if isinstance(image_input, bytes):
            return Image.open(io.BytesIO(image_input))

        if not isinstance(image_input, str):
            raise ValueError("不支持的图像输入类型")

        if image_input.startswith("data:image"):
            # Strip the "data:image/xxx;base64," header.
            _, separator, encoded = image_input.partition(",")
            if not separator:
                raise ValueError("无法识别的图像输入格式")
            return Image.open(io.BytesIO(base64.b64decode(encoded)))

        path_error = None
        if image_input.startswith(("/", ".")):
            try:
                return Image.open(image_input)
            except (OSError, ValueError) as err:
                # "/" is a valid base64 character (base64 JPEG data begins
                # with "/9j/"), so a string that is not a readable image file
                # is still tried as base64 below.
                path_error = err

        try:
            return Image.open(io.BytesIO(base64.b64decode(image_input)))
        except Exception as err:
            raise ValueError("无法识别的图像输入格式") from (path_error or err)

    def _run_prompt(
        self,
        image_input: Union[str, bytes, Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        failure_label: str,
    ) -> dict:
        """Run one image + text prompt through the model; shared by recognize*.

        Returns:
            ``{"success": True, "text": <decoded output>, "error": None}`` on
            success, otherwise ``{"success": False, "text": None,
            "error": <message>}``. Exceptions are logged, not propagated.
        """
        try:
            image = self._process_image(image_input)
            if image.mode != "RGB":
                image = image.convert("RGB")

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            # NOTE(review): the chat template is applied through the
            # tokenizer with an image entry in the message; whether this
            # tokenizer consumes image content is not visible here — confirm
            # against the model card (a processor class may be required).
            inputs = self._tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
                add_generation_prompt=True,
            )

            # Follow the model's actual placement instead of assuming CUDA.
            device = self._model.device
            inputs = {
                key: value.to(device) if hasattr(value, "to") else value
                for key, value in inputs.items()
            }

            with torch.no_grad():
                outputs = self._model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=self._tokenizer.eos_token_id,
                )

            # Keep only the tokens generated after the prompt.
            prompt_length = inputs["input_ids"].shape[1]
            generated_ids = outputs[0][prompt_length:]
            result_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True)

            return {"success": True, "text": result_text, "error": None}
        except Exception as exc:
            logger.exception("%s: %s", failure_label, exc)
            return {"success": False, "text": None, "error": str(exc)}

    def recognize(
        self,
        image_input: Union[str, bytes, Image.Image],
        max_new_tokens: int = 4096,
        temperature: float = 0.01,
        top_p: float = 0.9,
    ) -> dict:
        """Recognize all text in an image.

        Args:
            image_input: Image in any form accepted by ``_process_image``.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Top-p sampling parameter.

        Returns:
            Dict with ``success``, ``text`` and ``error`` keys.
        """
        return self._run_prompt(
            image_input,
            self._TEXT_PROMPT,
            max_new_tokens,
            temperature,
            top_p,
            "OCR 识别失败",
        )

    def recognize_table(
        self,
        image_input: Union[str, bytes, Image.Image],
        max_new_tokens: int = 4096,
    ) -> dict:
        """Recognize table content, prompting for Markdown table output.

        Args:
            image_input: Image in any form accepted by ``_process_image``.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            Dict with ``success``, ``text`` and ``error`` keys.
        """
        return self._run_prompt(
            image_input,
            self._TABLE_PROMPT,
            max_new_tokens,
            0.01,
            0.9,
            "表格识别失败",
        )

    def recognize_formula(
        self,
        image_input: Union[str, bytes, Image.Image],
        max_new_tokens: int = 2048,
    ) -> dict:
        """Recognize mathematical formulas, prompting for LaTeX output.

        Args:
            image_input: Image in any form accepted by ``_process_image``.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            Dict with ``success``, ``text`` and ``error`` keys.
        """
        return self._run_prompt(
            image_input,
            self._FORMULA_PROMPT,
            max_new_tokens,
            0.01,
            0.9,
            "公式识别失败",
        )
# Module-wide OCR service instance, created on first request.
ocr_service: Optional[OCRService] = None


def get_ocr_service() -> OCRService:
    """Return the shared OCRService, constructing it on the first call."""
    global ocr_service
    if ocr_service is not None:
        return ocr_service
    ocr_service = OCRService()
    return ocr_service