调整报错
This commit is contained in:
parent
55dc107acd
commit
39fe636ff2
|
|
@ -11,6 +11,7 @@ import ai.djl.pytorch.engine.PtModel;
|
|||
import ai.djl.pytorch.engine.PtNDArray;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.training.util.ProgressBar;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
|
|
@ -76,7 +77,7 @@ public class GlmOcrModel {
|
|||
.optModelPath(modelPath)
|
||||
.optEngine("PyTorch")
|
||||
.optTranslator(new GlmOcrTranslator(config))
|
||||
.optProgress(new ProgressListener())
|
||||
.optProgress(new ProgressBar())
|
||||
.build();
|
||||
|
||||
model = criteria.loadModel();
|
||||
|
|
@ -100,10 +101,10 @@ public class GlmOcrModel {
|
|||
init();
|
||||
}
|
||||
|
||||
private static class ProgressListener implements ai.djl.training.listener.ProgressListener {
|
||||
@Override
|
||||
public void progressUpdate(int progress, String message) {
|
||||
log.debug("Model loading progress: {}% - {}", progress, message);
|
||||
}
|
||||
}
|
||||
// private static class ProgressListener implements ai.djl.training.listener.ProgressListener {
|
||||
// @Override
|
||||
// public void progressUpdate(int progress, String message) {
|
||||
// log.debug("Model loading progress: {}% - {}", progress, message);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ public class GlmOcrTranslator implements Translator<BufferedImage, String> {
|
|||
int width = processedImage.getWidth();
|
||||
int height = processedImage.getHeight();
|
||||
|
||||
float[] pixels = new float[3 * height * width];
|
||||
double[] pixels = new double[3 * height * width];
|
||||
int idx = 0;
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ public class TokenizerService {
|
|||
return new long[0];
|
||||
}
|
||||
Encoding encoding = tokenizer.encode(text);
|
||||
return Arrays.stream(encoding.getIds()).asLongStream().toArray();
|
||||
return Arrays.stream(encoding.getIds()).toArray();
|
||||
}
|
||||
|
||||
public String decode(long[] ids) {
|
||||
|
|
@ -51,8 +51,7 @@ public class TokenizerService {
|
|||
.filter(id -> id > 0 && id < 151936) // GLM token范围
|
||||
.toArray();
|
||||
|
||||
String[] tokens = tokenizer.decode(filteredIds);
|
||||
return String.join("", tokens);
|
||||
return tokenizer.decode(filteredIds);
|
||||
}
|
||||
|
||||
public String decodeWithPrompt(long[] ids, String prompt) {
|
||||
|
|
@ -65,7 +64,8 @@ public class TokenizerService {
|
|||
}
|
||||
|
||||
public int getVocabSize() {
|
||||
return tokenizer != null ? tokenizer.getVocabularySize() : 0;
|
||||
// return tokenizer != null ? tokenizer.getVocabularySize() : 0;
|
||||
return tokenizer != null ? tokenizer.getMaxLength() : 0;
|
||||
}
|
||||
|
||||
public boolean isAvailable() {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ spring:
|
|||
# GLM-OCR配置(纯Java本地部署)
|
||||
glm-ocr:
|
||||
# 模型本地路径(支持相对路径或绝对路径)
|
||||
model-path: ./models/GLM-OCR
|
||||
model-path: 'D:/development/community/GLM-OCR'
|
||||
# 推理设备: cpu, gpu(0), gpu(1)
|
||||
device: cpu
|
||||
# 精度: fp32, fp16, bf16, int8
|
||||
|
|
|
|||
Loading…
Reference in New Issue