feat: 添加 GLM-OCR 识别功能
- 新增 OCR 服务模块支持文字/表格/公式识别 - 添加 OCR API 路由(文件上传和 base64 方式) - 更新配置以支持 GLM-OCR 模型 - 添加必要的依赖项(torch, transformers, accelerate)
This commit is contained in:
parent
52db191cc6
commit
17ef9cc836
|
|
@ -12,3 +12,9 @@ DATABASE_URL=sqlite:///./app.db
|
|||
|
||||
# Security
|
||||
SECRET_KEY=your-secret-key-here-change-in-production
|
||||
|
||||
# GLM-OCR Model
|
||||
# 模型路径: 可以是 Hugging Face 模型 ID (如 zai-org/GLM-OCR)
|
||||
# 或本地路径 (如 /path/to/local/model)
|
||||
MODEL_PATH=zai-org/GLM-OCR
|
||||
MODEL_DEVICE=auto
|
||||
|
|
|
|||
|
|
@ -21,3 +21,9 @@ httpx==0.27.2
|
|||
|
||||
# Utils
|
||||
orjson==3.10.12
|
||||
|
||||
# GLM-OCR Dependencies
|
||||
torch>=2.0.0
|
||||
transformers>=4.40.0
|
||||
Pillow>=10.0.0
|
||||
accelerate>=0.30.0
|
||||
|
|
|
|||
|
|
@ -1 +1,4 @@
|
|||
# api package
|
||||
from src.api.ocr import router as ocr_router
|
||||
|
||||
__all__ = ["ocr_router"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,191 @@
|
|||
"""OCR API endpoints."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.services.ocr_service import get_ocr_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ocr", tags=["OCR"])
|
||||
|
||||
|
||||
class OCRTType(str, Enum):
|
||||
"""OCR 识别类型枚举。"""
|
||||
text = "text"
|
||||
table = "table"
|
||||
formula = "formula"
|
||||
|
||||
|
||||
class OCRRequest(BaseModel):
|
||||
"""OCR 请求模型(用于 base64 图像)。"""
|
||||
image_base64: str
|
||||
ocr_type: OCRTType = OCRTType.text
|
||||
max_new_tokens: int = 4096
|
||||
|
||||
|
||||
class OCRResponse(BaseModel):
|
||||
"""OCR 响应模型。"""
|
||||
success: bool
|
||||
text: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/recognize", response_model=OCRResponse)
|
||||
async def recognize_text(
|
||||
file: UploadFile = File(..., description="要识别的图像文件"),
|
||||
max_new_tokens: int = Form(default=4096, description="最大生成 token 数")
|
||||
):
|
||||
"""
|
||||
识别图像中的文字内容。
|
||||
|
||||
支持的图片格式: PNG, JPG, JPEG, WEBP
|
||||
|
||||
Args:
|
||||
file: 上传的图像文件
|
||||
max_new_tokens: 最大生成 token 数
|
||||
|
||||
Returns:
|
||||
OCRResponse: 包含识别结果
|
||||
"""
|
||||
try:
|
||||
# 读取图像
|
||||
image_bytes = await file.read()
|
||||
|
||||
# 调用 OCR 服务
|
||||
ocr_service = get_ocr_service()
|
||||
result = ocr_service.recognize(image_bytes, max_new_tokens=max_new_tokens)
|
||||
|
||||
return OCRResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文字识别失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/recognize/table", response_model=OCRResponse)
|
||||
async def recognize_table(
|
||||
file: UploadFile = File(..., description="包含表格的图像文件"),
|
||||
max_new_tokens: int = Form(default=4096, description="最大生成 token 数")
|
||||
):
|
||||
"""
|
||||
识别图像中的表格内容。
|
||||
|
||||
以 Markdown 表格格式输出识别结果。
|
||||
|
||||
Args:
|
||||
file: 上传的图像文件
|
||||
max_new_tokens: 最大生成 token 数
|
||||
|
||||
Returns:
|
||||
OCRResponse: 包含表格识别结果
|
||||
"""
|
||||
try:
|
||||
image_bytes = await file.read()
|
||||
|
||||
ocr_service = get_ocr_service()
|
||||
result = ocr_service.recognize_table(image_bytes, max_new_tokens=max_new_tokens)
|
||||
|
||||
return OCRResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"表格识别失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/recognize/formula", response_model=OCRResponse)
|
||||
async def recognize_formula(
|
||||
file: UploadFile = File(..., description="包含数学公式的图像文件"),
|
||||
max_new_tokens: int = Form(default=2048, description="最大生成 token 数")
|
||||
):
|
||||
"""
|
||||
识别图像中的数学公式。
|
||||
|
||||
以 LaTeX 格式输出识别结果。
|
||||
|
||||
Args:
|
||||
file: 上传的图像文件
|
||||
max_new_tokens: 最大生成 token 数
|
||||
|
||||
Returns:
|
||||
OCRResponse: 包含公式识别结果
|
||||
"""
|
||||
try:
|
||||
image_bytes = await file.read()
|
||||
|
||||
ocr_service = get_ocr_service()
|
||||
result = ocr_service.recognize_formula(image_bytes, max_new_tokens=max_new_tokens)
|
||||
|
||||
return OCRResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"公式识别失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/recognize/base64", response_model=OCRResponse)
|
||||
async def recognize_base64(request: OCRRequest):
|
||||
"""
|
||||
通过 base64 编码的图像进行 OCR 识别。
|
||||
|
||||
支持纯 base64 字符串或 data:image/xxx;base64,... 格式。
|
||||
|
||||
Args:
|
||||
request: OCRRequest 包含 base64 图像和识别类型
|
||||
|
||||
Returns:
|
||||
OCRResponse: 包含识别结果
|
||||
"""
|
||||
try:
|
||||
ocr_service = get_ocr_service()
|
||||
|
||||
if request.ocr_type == OCRTType.table:
|
||||
result = ocr_service.recognize_table(
|
||||
request.image_base64,
|
||||
max_new_tokens=request.max_new_tokens
|
||||
)
|
||||
elif request.ocr_type == OCRTType.formula:
|
||||
result = ocr_service.recognize_formula(
|
||||
request.image_base64,
|
||||
max_new_tokens=request.max_new_tokens
|
||||
)
|
||||
else:
|
||||
result = ocr_service.recognize(
|
||||
request.image_base64,
|
||||
max_new_tokens=request.max_new_tokens
|
||||
)
|
||||
|
||||
return OCRResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OCR 识别失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_status():
|
||||
"""
|
||||
获取 OCR 服务状态。
|
||||
|
||||
Returns:
|
||||
服务状态信息
|
||||
"""
|
||||
try:
|
||||
ocr_service = get_ocr_service()
|
||||
return {
|
||||
"status": "ready",
|
||||
"model_loaded": ocr_service._model is not None,
|
||||
"tokenizer_loaded": ocr_service._tokenizer is not None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
|
@ -13,6 +13,10 @@ class Settings(BaseSettings):
|
|||
|
||||
secret_key: str = "your-secret-key-here"
|
||||
|
||||
# GLM-OCR Model Settings
|
||||
model_path: str = "zai-org/GLM-OCR" # Hugging Face 模型 ID 或本地路径
|
||||
model_device: str = "auto" # auto, cuda, cpu
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from src.api.ocr import router as ocr_router
|
||||
from src.core.config import settings
|
||||
|
||||
app = FastAPI(
|
||||
|
|
@ -18,13 +19,17 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册 OCR 路由
|
||||
app.include_router(ocr_router)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"app": settings.app_name,
|
||||
"version": settings.app_version,
|
||||
"status": "running"
|
||||
"status": "running",
|
||||
"model": settings.model_path
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1,4 @@
|
|||
# services package
|
||||
from src.services.ocr_service import OCRService, get_ocr_service
|
||||
|
||||
__all__ = ["OCRService", "get_ocr_service"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,325 @@
|
|||
"""OCR service module for GLM-OCR model integration."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OCRService:
|
||||
"""GLM-OCR 服务类,负责模型加载和推理。"""
|
||||
|
||||
_instance: Optional["OCRService"] = None
|
||||
_model = None
|
||||
_tokenizer = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._model is None:
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""加载 GLM-OCR 模型和 tokenizer。"""
|
||||
logger.info(f"正在加载 GLM-OCR 模型: {settings.model_path}")
|
||||
|
||||
try:
|
||||
# 加载 tokenizer
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
settings.model_path,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# 加载模型
|
||||
self._model = AutoModelForCausalLM.from_pretrained(
|
||||
settings.model_path,
|
||||
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
||||
device_map="auto" if torch.cuda.is_available() else None,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# 设置为评估模式
|
||||
self._model.eval()
|
||||
|
||||
logger.info("GLM-OCR 模型加载完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def _process_image(self, image_input: Union[str, bytes, Image.Image]) -> Image.Image:
|
||||
"""
|
||||
处理输入图像,转换为 PIL Image 格式。
|
||||
|
||||
Args:
|
||||
image_input: 图像输入,支持文件路径、bytes 或 PIL Image
|
||||
|
||||
Returns:
|
||||
PIL Image 对象
|
||||
"""
|
||||
if isinstance(image_input, Image.Image):
|
||||
return image_input
|
||||
|
||||
if isinstance(image_input, bytes):
|
||||
return Image.open(io.BytesIO(image_input))
|
||||
|
||||
if isinstance(image_input, str):
|
||||
# 检查是否为 base64 编码
|
||||
if image_input.startswith("data:image"):
|
||||
# 移除 data:image/xxx;base64, 前缀
|
||||
base64_data = image_input.split(",", 1)[1]
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
return Image.open(io.BytesIO(image_bytes))
|
||||
elif image_input.startswith("/") or image_input.startswith("."):
|
||||
# 文件路径
|
||||
return Image.open(image_input)
|
||||
else:
|
||||
# 尝试作为 base64 解码
|
||||
try:
|
||||
image_bytes = base64.b64decode(image_input)
|
||||
return Image.open(io.BytesIO(image_bytes))
|
||||
except Exception:
|
||||
raise ValueError("无法识别的图像输入格式")
|
||||
|
||||
raise ValueError("不支持的图像输入类型")
|
||||
|
||||
def recognize(
|
||||
self,
|
||||
image_input: Union[str, bytes, Image.Image],
|
||||
max_new_tokens: int = 4096,
|
||||
temperature: float = 0.01,
|
||||
top_p: float = 0.9
|
||||
) -> dict:
|
||||
"""
|
||||
执行 OCR 识别。
|
||||
|
||||
Args:
|
||||
image_input: 图像输入
|
||||
max_new_tokens: 最大生成 token 数
|
||||
temperature: 温度参数
|
||||
top_p: top-p 采样参数
|
||||
|
||||
Returns:
|
||||
包含识别结果的字典
|
||||
"""
|
||||
try:
|
||||
# 处理图像
|
||||
image = self._process_image(image_input)
|
||||
|
||||
# 确保图像为 RGB 模式
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# 构建对话消息
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{"type": "text", "text": "请识别图片中的所有文字内容,并按照原文排版格式输出。"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# 应用 chat template
|
||||
inputs = self._tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
|
||||
# 移动到对应设备
|
||||
if torch.cuda.is_available():
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
|
||||
# 生成
|
||||
with torch.no_grad():
|
||||
outputs = self._model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
do_sample=True,
|
||||
pad_token_id=self._tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# 解码结果
|
||||
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
|
||||
result_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"text": result_text,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OCR 识别失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"text": None,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def recognize_table(
|
||||
self,
|
||||
image_input: Union[str, bytes, Image.Image],
|
||||
max_new_tokens: int = 4096
|
||||
) -> dict:
|
||||
"""
|
||||
识别表格内容。
|
||||
|
||||
Args:
|
||||
image_input: 图像输入
|
||||
max_new_tokens: 最大生成 token 数
|
||||
|
||||
Returns:
|
||||
包含表格识别结果的字典
|
||||
"""
|
||||
try:
|
||||
image = self._process_image(image_input)
|
||||
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{"type": "text", "text": "请识别图片中的表格内容,并以 Markdown 表格格式输出。"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
inputs = self._tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self._model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=0.01,
|
||||
top_p=0.9,
|
||||
do_sample=True,
|
||||
pad_token_id=self._tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
|
||||
result_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"text": result_text,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"表格识别失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"text": None,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def recognize_formula(
|
||||
self,
|
||||
image_input: Union[str, bytes, Image.Image],
|
||||
max_new_tokens: int = 2048
|
||||
) -> dict:
|
||||
"""
|
||||
识别数学公式。
|
||||
|
||||
Args:
|
||||
image_input: 图像输入
|
||||
max_new_tokens: 最大生成 token 数
|
||||
|
||||
Returns:
|
||||
包含公式识别结果的字典
|
||||
"""
|
||||
try:
|
||||
image = self._process_image(image_input)
|
||||
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{"type": "text", "text": "请识别图片中的数学公式,并以 LaTeX 格式输出。"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
inputs = self._tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self._model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=0.01,
|
||||
top_p=0.9,
|
||||
do_sample=True,
|
||||
pad_token_id=self._tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
|
||||
result_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"text": result_text,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"公式识别失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"text": None,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 全局 OCR 服务实例
|
||||
ocr_service: Optional[OCRService] = None
|
||||
|
||||
|
||||
def get_ocr_service() -> OCRService:
|
||||
"""获取 OCR 服务单例实例。"""
|
||||
global ocr_service
|
||||
if ocr_service is None:
|
||||
ocr_service = OCRService()
|
||||
return ocr_service
|
||||
Loading…
Reference in New Issue