feat: 添加 GLM-OCR 识别功能
- 新增 OCR 服务模块支持文字/表格/公式识别 - 添加 OCR API 路由(文件上传和 base64 方式) - 更新配置以支持 GLM-OCR 模型 - 添加必要的依赖项(torch, transformers, accelerate)
This commit is contained in:
parent
52db191cc6
commit
17ef9cc836
|
|
@ -11,4 +11,10 @@ PORT=8000
|
||||||
DATABASE_URL=sqlite:///./app.db
|
DATABASE_URL=sqlite:///./app.db
|
||||||
|
|
||||||
# Security
|
# Security
|
||||||
SECRET_KEY=your-secret-key-here-change-in-production
|
SECRET_KEY=your-secret-key-here-change-in-production
|
||||||
|
|
||||||
|
# GLM-OCR Model
|
||||||
|
# 模型路径: 可以是 Hugging Face 模型 ID (如 zai-org/GLM-OCR)
|
||||||
|
# 或本地路径 (如 /path/to/local/model)
|
||||||
|
MODEL_PATH=zai-org/GLM-OCR
|
||||||
|
MODEL_DEVICE=auto
|
||||||
|
|
|
||||||
|
|
@ -20,4 +20,10 @@ python-multipart==0.0.12
|
||||||
httpx==0.27.2
|
httpx==0.27.2
|
||||||
|
|
||||||
# Utils
|
# Utils
|
||||||
orjson==3.10.12
|
orjson==3.10.12
|
||||||
|
|
||||||
|
# GLM-OCR Dependencies
|
||||||
|
torch>=2.0.0
|
||||||
|
transformers>=4.40.0
|
||||||
|
Pillow>=10.0.0
|
||||||
|
accelerate>=0.30.0
|
||||||
|
|
|
||||||
|
|
@ -1 +1,4 @@
|
||||||
# api package
|
# api package
|
||||||
|
from src.api.ocr import router as ocr_router
|
||||||
|
|
||||||
|
__all__ = ["ocr_router"]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,191 @@
|
||||||
|
"""OCR API endpoints."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from src.services.ocr_service import get_ocr_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/ocr", tags=["OCR"])
|
||||||
|
|
||||||
|
|
||||||
|
class OCRTType(str, Enum):
|
||||||
|
"""OCR 识别类型枚举。"""
|
||||||
|
text = "text"
|
||||||
|
table = "table"
|
||||||
|
formula = "formula"
|
||||||
|
|
||||||
|
|
||||||
|
class OCRRequest(BaseModel):
|
||||||
|
"""OCR 请求模型(用于 base64 图像)。"""
|
||||||
|
image_base64: str
|
||||||
|
ocr_type: OCRTType = OCRTType.text
|
||||||
|
max_new_tokens: int = 4096
|
||||||
|
|
||||||
|
|
||||||
|
class OCRResponse(BaseModel):
|
||||||
|
"""OCR 响应模型。"""
|
||||||
|
success: bool
|
||||||
|
text: Optional[str] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/recognize", response_model=OCRResponse)
|
||||||
|
async def recognize_text(
|
||||||
|
file: UploadFile = File(..., description="要识别的图像文件"),
|
||||||
|
max_new_tokens: int = Form(default=4096, description="最大生成 token 数")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
识别图像中的文字内容。
|
||||||
|
|
||||||
|
支持的图片格式: PNG, JPG, JPEG, WEBP
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 上传的图像文件
|
||||||
|
max_new_tokens: 最大生成 token 数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OCRResponse: 包含识别结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 读取图像
|
||||||
|
image_bytes = await file.read()
|
||||||
|
|
||||||
|
# 调用 OCR 服务
|
||||||
|
ocr_service = get_ocr_service()
|
||||||
|
result = ocr_service.recognize(image_bytes, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
|
return OCRResponse(**result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文字识别失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/recognize/table", response_model=OCRResponse)
|
||||||
|
async def recognize_table(
|
||||||
|
file: UploadFile = File(..., description="包含表格的图像文件"),
|
||||||
|
max_new_tokens: int = Form(default=4096, description="最大生成 token 数")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
识别图像中的表格内容。
|
||||||
|
|
||||||
|
以 Markdown 表格格式输出识别结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 上传的图像文件
|
||||||
|
max_new_tokens: 最大生成 token 数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OCRResponse: 包含表格识别结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
image_bytes = await file.read()
|
||||||
|
|
||||||
|
ocr_service = get_ocr_service()
|
||||||
|
result = ocr_service.recognize_table(image_bytes, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
|
return OCRResponse(**result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"表格识别失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/recognize/formula", response_model=OCRResponse)
|
||||||
|
async def recognize_formula(
|
||||||
|
file: UploadFile = File(..., description="包含数学公式的图像文件"),
|
||||||
|
max_new_tokens: int = Form(default=2048, description="最大生成 token 数")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
识别图像中的数学公式。
|
||||||
|
|
||||||
|
以 LaTeX 格式输出识别结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 上传的图像文件
|
||||||
|
max_new_tokens: 最大生成 token 数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OCRResponse: 包含公式识别结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
image_bytes = await file.read()
|
||||||
|
|
||||||
|
ocr_service = get_ocr_service()
|
||||||
|
result = ocr_service.recognize_formula(image_bytes, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
|
return OCRResponse(**result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"公式识别失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/recognize/base64", response_model=OCRResponse)
|
||||||
|
async def recognize_base64(request: OCRRequest):
|
||||||
|
"""
|
||||||
|
通过 base64 编码的图像进行 OCR 识别。
|
||||||
|
|
||||||
|
支持纯 base64 字符串或 data:image/xxx;base64,... 格式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: OCRRequest 包含 base64 图像和识别类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OCRResponse: 包含识别结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ocr_service = get_ocr_service()
|
||||||
|
|
||||||
|
if request.ocr_type == OCRTType.table:
|
||||||
|
result = ocr_service.recognize_table(
|
||||||
|
request.image_base64,
|
||||||
|
max_new_tokens=request.max_new_tokens
|
||||||
|
)
|
||||||
|
elif request.ocr_type == OCRTType.formula:
|
||||||
|
result = ocr_service.recognize_formula(
|
||||||
|
request.image_base64,
|
||||||
|
max_new_tokens=request.max_new_tokens
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = ocr_service.recognize(
|
||||||
|
request.image_base64,
|
||||||
|
max_new_tokens=request.max_new_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
return OCRResponse(**result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OCR 识别失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status")
|
||||||
|
async def get_status():
|
||||||
|
"""
|
||||||
|
获取 OCR 服务状态。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
服务状态信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ocr_service = get_ocr_service()
|
||||||
|
return {
|
||||||
|
"status": "ready",
|
||||||
|
"model_loaded": ocr_service._model is not None,
|
||||||
|
"tokenizer_loaded": ocr_service._tokenizer is not None
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
@ -13,9 +13,13 @@ class Settings(BaseSettings):
|
||||||
|
|
||||||
secret_key: str = "your-secret-key-here"
|
secret_key: str = "your-secret-key-here"
|
||||||
|
|
||||||
|
# GLM-OCR Model Settings
|
||||||
|
model_path: str = "zai-org/GLM-OCR" # Hugging Face 模型 ID 或本地路径
|
||||||
|
model_device: str = "auto" # auto, cuda, cpu
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
case_sensitive = False
|
case_sensitive = False
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from src.api.ocr import router as ocr_router
|
||||||
from src.core.config import settings
|
from src.core.config import settings
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
|
|
@ -18,13 +19,17 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 注册 OCR 路由
|
||||||
|
app.include_router(ocr_router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
return {
|
return {
|
||||||
"app": settings.app_name,
|
"app": settings.app_name,
|
||||||
"version": settings.app_version,
|
"version": settings.app_version,
|
||||||
"status": "running"
|
"status": "running",
|
||||||
|
"model": settings.model_path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -41,4 +46,4 @@ if __name__ == "__main__":
|
||||||
host=settings.host,
|
host=settings.host,
|
||||||
port=settings.port,
|
port=settings.port,
|
||||||
reload=settings.debug
|
reload=settings.debug
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1 +1,4 @@
|
||||||
# services package
|
# services package
|
||||||
|
from src.services.ocr_service import OCRService, get_ocr_service
|
||||||
|
|
||||||
|
__all__ = ["OCRService", "get_ocr_service"]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,325 @@
|
||||||
|
"""OCR service module for GLM-OCR model integration."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OCRService:
|
||||||
|
"""GLM-OCR 服务类,负责模型加载和推理。"""
|
||||||
|
|
||||||
|
_instance: Optional["OCRService"] = None
|
||||||
|
_model = None
|
||||||
|
_tokenizer = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if self._model is None:
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""加载 GLM-OCR 模型和 tokenizer。"""
|
||||||
|
logger.info(f"正在加载 GLM-OCR 模型: {settings.model_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载 tokenizer
|
||||||
|
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
settings.model_path,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
self._model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
settings.model_path,
|
||||||
|
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
||||||
|
device_map="auto" if torch.cuda.is_available() else None,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置为评估模式
|
||||||
|
self._model.eval()
|
||||||
|
|
||||||
|
logger.info("GLM-OCR 模型加载完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型加载失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _process_image(self, image_input: Union[str, bytes, Image.Image]) -> Image.Image:
|
||||||
|
"""
|
||||||
|
处理输入图像,转换为 PIL Image 格式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_input: 图像输入,支持文件路径、bytes 或 PIL Image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image 对象
|
||||||
|
"""
|
||||||
|
if isinstance(image_input, Image.Image):
|
||||||
|
return image_input
|
||||||
|
|
||||||
|
if isinstance(image_input, bytes):
|
||||||
|
return Image.open(io.BytesIO(image_input))
|
||||||
|
|
||||||
|
if isinstance(image_input, str):
|
||||||
|
# 检查是否为 base64 编码
|
||||||
|
if image_input.startswith("data:image"):
|
||||||
|
# 移除 data:image/xxx;base64, 前缀
|
||||||
|
base64_data = image_input.split(",", 1)[1]
|
||||||
|
image_bytes = base64.b64decode(base64_data)
|
||||||
|
return Image.open(io.BytesIO(image_bytes))
|
||||||
|
elif image_input.startswith("/") or image_input.startswith("."):
|
||||||
|
# 文件路径
|
||||||
|
return Image.open(image_input)
|
||||||
|
else:
|
||||||
|
# 尝试作为 base64 解码
|
||||||
|
try:
|
||||||
|
image_bytes = base64.b64decode(image_input)
|
||||||
|
return Image.open(io.BytesIO(image_bytes))
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("无法识别的图像输入格式")
|
||||||
|
|
||||||
|
raise ValueError("不支持的图像输入类型")
|
||||||
|
|
||||||
|
def recognize(
|
||||||
|
self,
|
||||||
|
image_input: Union[str, bytes, Image.Image],
|
||||||
|
max_new_tokens: int = 4096,
|
||||||
|
temperature: float = 0.01,
|
||||||
|
top_p: float = 0.9
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
执行 OCR 识别。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_input: 图像输入
|
||||||
|
max_new_tokens: 最大生成 token 数
|
||||||
|
temperature: 温度参数
|
||||||
|
top_p: top-p 采样参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含识别结果的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 处理图像
|
||||||
|
image = self._process_image(image_input)
|
||||||
|
|
||||||
|
# 确保图像为 RGB 模式
|
||||||
|
if image.mode != "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
# 构建对话消息
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "image": image},
|
||||||
|
{"type": "text", "text": "请识别图片中的所有文字内容,并按照原文排版格式输出。"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# 应用 chat template
|
||||||
|
inputs = self._tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_dict=True,
|
||||||
|
add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 移动到对应设备
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# 生成
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self._model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=self._tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解码结果
|
||||||
|
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
|
||||||
|
result_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"text": result_text,
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OCR 识别失败: {str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"text": None,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def recognize_table(
|
||||||
|
self,
|
||||||
|
image_input: Union[str, bytes, Image.Image],
|
||||||
|
max_new_tokens: int = 4096
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
识别表格内容。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_input: 图像输入
|
||||||
|
max_new_tokens: 最大生成 token 数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含表格识别结果的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
image = self._process_image(image_input)
|
||||||
|
|
||||||
|
if image.mode != "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "image": image},
|
||||||
|
{"type": "text", "text": "请识别图片中的表格内容,并以 Markdown 表格格式输出。"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = self._tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_dict=True,
|
||||||
|
add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self._model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
temperature=0.01,
|
||||||
|
top_p=0.9,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=self._tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
|
||||||
|
result_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"text": result_text,
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"表格识别失败: {str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"text": None,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def recognize_formula(
|
||||||
|
self,
|
||||||
|
image_input: Union[str, bytes, Image.Image],
|
||||||
|
max_new_tokens: int = 2048
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
识别数学公式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_input: 图像输入
|
||||||
|
max_new_tokens: 最大生成 token 数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含公式识别结果的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
image = self._process_image(image_input)
|
||||||
|
|
||||||
|
if image.mode != "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "image": image},
|
||||||
|
{"type": "text", "text": "请识别图片中的数学公式,并以 LaTeX 格式输出。"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = self._tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_dict=True,
|
||||||
|
add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self._model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
temperature=0.01,
|
||||||
|
top_p=0.9,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=self._tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
|
||||||
|
result_text = self._tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"text": result_text,
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"公式识别失败: {str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"text": None,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局 OCR 服务实例
|
||||||
|
ocr_service: Optional[OCRService] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ocr_service() -> OCRService:
|
||||||
|
"""获取 OCR 服务单例实例。"""
|
||||||
|
global ocr_service
|
||||||
|
if ocr_service is None:
|
||||||
|
ocr_service = OCRService()
|
||||||
|
return ocr_service
|
||||||
Loading…
Reference in New Issue