调整报错
This commit is contained in:
parent
55dc107acd
commit
39fe636ff2
|
|
@ -11,6 +11,7 @@ import ai.djl.pytorch.engine.PtModel;
|
||||||
import ai.djl.pytorch.engine.PtNDArray;
|
import ai.djl.pytorch.engine.PtNDArray;
|
||||||
import ai.djl.repository.zoo.Criteria;
|
import ai.djl.repository.zoo.Criteria;
|
||||||
import ai.djl.repository.zoo.ZooModel;
|
import ai.djl.repository.zoo.ZooModel;
|
||||||
|
import ai.djl.training.util.ProgressBar;
|
||||||
import ai.djl.translate.TranslateException;
|
import ai.djl.translate.TranslateException;
|
||||||
import ai.djl.translate.Translator;
|
import ai.djl.translate.Translator;
|
||||||
import ai.djl.translate.TranslatorContext;
|
import ai.djl.translate.TranslatorContext;
|
||||||
|
|
@ -76,7 +77,7 @@ public class GlmOcrModel {
|
||||||
.optModelPath(modelPath)
|
.optModelPath(modelPath)
|
||||||
.optEngine("PyTorch")
|
.optEngine("PyTorch")
|
||||||
.optTranslator(new GlmOcrTranslator(config))
|
.optTranslator(new GlmOcrTranslator(config))
|
||||||
.optProgress(new ProgressListener())
|
.optProgress(new ProgressBar())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
model = criteria.loadModel();
|
model = criteria.loadModel();
|
||||||
|
|
@ -100,10 +101,10 @@ public class GlmOcrModel {
|
||||||
init();
|
init();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class ProgressListener implements ai.djl.training.listener.ProgressListener {
|
// private static class ProgressListener implements ai.djl.training.listener.ProgressListener {
|
||||||
@Override
|
// @Override
|
||||||
public void progressUpdate(int progress, String message) {
|
// public void progressUpdate(int progress, String message) {
|
||||||
log.debug("Model loading progress: {}% - {}", progress, message);
|
// log.debug("Model loading progress: {}% - {}", progress, message);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ public class GlmOcrTranslator implements Translator<BufferedImage, String> {
|
||||||
int width = processedImage.getWidth();
|
int width = processedImage.getWidth();
|
||||||
int height = processedImage.getHeight();
|
int height = processedImage.getHeight();
|
||||||
|
|
||||||
float[] pixels = new float[3 * height * width];
|
double[] pixels = new double[3 * height * width];
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
|
|
||||||
for (int c = 0; c < 3; c++) {
|
for (int c = 0; c < 3; c++) {
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ public class TokenizerService {
|
||||||
return new long[0];
|
return new long[0];
|
||||||
}
|
}
|
||||||
Encoding encoding = tokenizer.encode(text);
|
Encoding encoding = tokenizer.encode(text);
|
||||||
return Arrays.stream(encoding.getIds()).asLongStream().toArray();
|
return Arrays.stream(encoding.getIds()).toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
public String decode(long[] ids) {
|
public String decode(long[] ids) {
|
||||||
|
|
@ -51,8 +51,7 @@ public class TokenizerService {
|
||||||
.filter(id -> id > 0 && id < 151936) // GLM token范围
|
.filter(id -> id > 0 && id < 151936) // GLM token范围
|
||||||
.toArray();
|
.toArray();
|
||||||
|
|
||||||
String[] tokens = tokenizer.decode(filteredIds);
|
return tokenizer.decode(filteredIds);
|
||||||
return String.join("", tokens);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public String decodeWithPrompt(long[] ids, String prompt) {
|
public String decodeWithPrompt(long[] ids, String prompt) {
|
||||||
|
|
@ -65,7 +64,8 @@ public class TokenizerService {
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getVocabSize() {
|
public int getVocabSize() {
|
||||||
return tokenizer != null ? tokenizer.getVocabularySize() : 0;
|
// return tokenizer != null ? tokenizer.getVocabularySize() : 0;
|
||||||
|
return tokenizer != null ? tokenizer.getMaxLength() : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isAvailable() {
|
public boolean isAvailable() {
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ spring:
|
||||||
# GLM-OCR配置(纯Java本地部署)
|
# GLM-OCR配置(纯Java本地部署)
|
||||||
glm-ocr:
|
glm-ocr:
|
||||||
# 模型本地路径(支持相对路径或绝对路径)
|
# 模型本地路径(支持相对路径或绝对路径)
|
||||||
model-path: ./models/GLM-OCR
|
model-path: 'D:/development/community/GLM-OCR'
|
||||||
# 推理设备: cpu, gpu(0), gpu(1)
|
# 推理设备: cpu, gpu(0), gpu(1)
|
||||||
device: cpu
|
device: cpu
|
||||||
# 精度: fp32, fp16, bf16, int8
|
# 精度: fp32, fp16, bf16, int8
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue