glmocrdemojava/src/main/java/com/example/glmocr/tokenizer/TokenizerService.java

75 lines
2.1 KiB
Java
Raw Normal View History

package com.example.glmocr.tokenizer;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
@Slf4j
public class TokenizerService {
private HuggingFaceTokenizer tokenizer;
private final String modelPath;
public TokenizerService(String modelPath) {
this.modelPath = modelPath;
init();
}
private void init() {
try {
Path tokenizerPath = Paths.get(modelPath, "tokenizer.json");
if (tokenizerPath.toFile().exists()) {
tokenizer = HuggingFaceTokenizer.newInstance(tokenizerPath);
log.info("Tokenizer loaded from: {}", tokenizerPath);
} else {
log.warn("Tokenizer file not found at: {}, using default tokenizer", tokenizerPath);
}
} catch (Exception e) {
log.error("Failed to load tokenizer", e);
}
}
public long[] encode(String text) {
if (tokenizer == null) {
return new long[0];
}
Encoding encoding = tokenizer.encode(text);
2026-02-25 10:33:28 +08:00
return Arrays.stream(encoding.getIds()).toArray();
}
public String decode(long[] ids) {
if (tokenizer == null) {
return Arrays.toString(ids);
}
// 过滤特殊token
long[] filteredIds = Arrays.stream(ids)
.filter(id -> id > 0 && id < 151936) // GLM token范围
.toArray();
2026-02-25 10:33:28 +08:00
return tokenizer.decode(filteredIds);
}
public String decodeWithPrompt(long[] ids, String prompt) {
String decoded = decode(ids);
// 移除prompt部分
if (decoded.startsWith(prompt)) {
return decoded.substring(prompt.length()).trim();
}
return decoded;
}
public int getVocabSize() {
2026-02-25 10:33:28 +08:00
// return tokenizer != null ? tokenizer.getVocabularySize() : 0;
return tokenizer != null ? tokenizer.getMaxLength() : 0;
}
public boolean isAvailable() {
return tokenizer != null;
}
}