调整报错

This commit is contained in:
黎润豪 2026-02-25 10:33:28 +08:00
parent 55dc107acd
commit 39fe636ff2
4 changed files with 14 additions and 13 deletions

View File

@ -11,6 +11,7 @@ import ai.djl.pytorch.engine.PtModel;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
@ -76,7 +77,7 @@ public class GlmOcrModel {
.optModelPath(modelPath)
.optEngine("PyTorch")
.optTranslator(new GlmOcrTranslator(config))
.optProgress(new ProgressListener())
.optProgress(new ProgressBar())
.build();
model = criteria.loadModel();
@ -100,10 +101,10 @@ public class GlmOcrModel {
init();
}
private static class ProgressListener implements ai.djl.training.listener.ProgressListener {
@Override
public void progressUpdate(int progress, String message) {
log.debug("Model loading progress: {}% - {}", progress, message);
}
}
// private static class ProgressListener implements ai.djl.training.listener.ProgressListener {
// @Override
// public void progressUpdate(int progress, String message) {
// log.debug("Model loading progress: {}% - {}", progress, message);
// }
// }
}

View File

@ -49,7 +49,7 @@ public class GlmOcrTranslator implements Translator<BufferedImage, String> {
int width = processedImage.getWidth();
int height = processedImage.getHeight();
float[] pixels = new float[3 * height * width];
double[] pixels = new double[3 * height * width];
int idx = 0;
for (int c = 0; c < 3; c++) {

View File

@ -39,7 +39,7 @@ public class TokenizerService {
return new long[0];
}
Encoding encoding = tokenizer.encode(text);
return Arrays.stream(encoding.getIds()).asLongStream().toArray();
return Arrays.stream(encoding.getIds()).toArray();
}
public String decode(long[] ids) {
@ -51,8 +51,7 @@ public class TokenizerService {
.filter(id -> id > 0 && id < 151936) // GLM token范围
.toArray();
String[] tokens = tokenizer.decode(filteredIds);
return String.join("", tokens);
return tokenizer.decode(filteredIds);
}
public String decodeWithPrompt(long[] ids, String prompt) {
@ -65,7 +64,8 @@ public class TokenizerService {
}
public int getVocabSize() {
return tokenizer != null ? tokenizer.getVocabularySize() : 0;
// return tokenizer != null ? tokenizer.getVocabularySize() : 0;
return tokenizer != null ? tokenizer.getMaxLength() : 0;
}
public boolean isAvailable() {

View File

@ -12,7 +12,7 @@ spring:
# GLM-OCR配置纯Java本地部署
glm-ocr:
# 模型本地路径(支持相对路径或绝对路径)
model-path: ./models/GLM-OCR
model-path: 'D:/development/community/GLM-OCR'
# 推理设备: cpu, gpu(0), gpu(1)
device: cpu
# 精度: fp32, fp16, bf16, int8