调整报错

This commit is contained in:
黎润豪 2026-02-25 10:33:28 +08:00
parent 55dc107acd
commit 39fe636ff2
4 changed files with 14 additions and 13 deletions

View File

@ -11,6 +11,7 @@ import ai.djl.pytorch.engine.PtModel;
import ai.djl.pytorch.engine.PtNDArray; import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel; import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException; import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator; import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext; import ai.djl.translate.TranslatorContext;
@ -76,7 +77,7 @@ public class GlmOcrModel {
.optModelPath(modelPath) .optModelPath(modelPath)
.optEngine("PyTorch") .optEngine("PyTorch")
.optTranslator(new GlmOcrTranslator(config)) .optTranslator(new GlmOcrTranslator(config))
.optProgress(new ProgressListener()) .optProgress(new ProgressBar())
.build(); .build();
model = criteria.loadModel(); model = criteria.loadModel();
@ -100,10 +101,10 @@ public class GlmOcrModel {
init(); init();
} }
private static class ProgressListener implements ai.djl.training.listener.ProgressListener { // private static class ProgressListener implements ai.djl.training.listener.ProgressListener {
@Override // @Override
public void progressUpdate(int progress, String message) { // public void progressUpdate(int progress, String message) {
log.debug("Model loading progress: {}% - {}", progress, message); // log.debug("Model loading progress: {}% - {}", progress, message);
} // }
} // }
} }

View File

@ -49,7 +49,7 @@ public class GlmOcrTranslator implements Translator<BufferedImage, String> {
int width = processedImage.getWidth(); int width = processedImage.getWidth();
int height = processedImage.getHeight(); int height = processedImage.getHeight();
float[] pixels = new float[3 * height * width]; double[] pixels = new double[3 * height * width];
int idx = 0; int idx = 0;
for (int c = 0; c < 3; c++) { for (int c = 0; c < 3; c++) {

View File

@ -39,7 +39,7 @@ public class TokenizerService {
return new long[0]; return new long[0];
} }
Encoding encoding = tokenizer.encode(text); Encoding encoding = tokenizer.encode(text);
return Arrays.stream(encoding.getIds()).asLongStream().toArray(); return Arrays.stream(encoding.getIds()).toArray();
} }
public String decode(long[] ids) { public String decode(long[] ids) {
@ -51,8 +51,7 @@ public class TokenizerService {
.filter(id -> id > 0 && id < 151936) // GLM token范围 .filter(id -> id > 0 && id < 151936) // GLM token范围
.toArray(); .toArray();
String[] tokens = tokenizer.decode(filteredIds); return tokenizer.decode(filteredIds);
return String.join("", tokens);
} }
public String decodeWithPrompt(long[] ids, String prompt) { public String decodeWithPrompt(long[] ids, String prompt) {
@ -65,7 +64,8 @@ public class TokenizerService {
} }
public int getVocabSize() { public int getVocabSize() {
return tokenizer != null ? tokenizer.getVocabularySize() : 0; // return tokenizer != null ? tokenizer.getVocabularySize() : 0;
return tokenizer != null ? tokenizer.getMaxLength() : 0;
} }
public boolean isAvailable() { public boolean isAvailable() {

View File

@ -12,7 +12,7 @@ spring:
# GLM-OCR配置纯Java本地部署 # GLM-OCR配置纯Java本地部署
glm-ocr: glm-ocr:
# 模型本地路径(支持相对路径或绝对路径) # 模型本地路径(支持相对路径或绝对路径)
model-path: ./models/GLM-OCR model-path: 'D:/development/community/GLM-OCR'
# 推理设备: cpu, gpu(0), gpu(1) # 推理设备: cpu, gpu(0), gpu(1)
device: cpu device: cpu
# 精度: fp32, fp16, bf16, int8 # 精度: fp32, fp16, bf16, int8