75 lines
2.1 KiB
Java
75 lines
2.1 KiB
Java
|
|
package com.example.glmocr.tokenizer;
|
||
|
|
|
||
|
|
import ai.djl.huggingface.tokenizers.Encoding;
|
||
|
|
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||
|
|
import lombok.extern.slf4j.Slf4j;
|
||
|
|
|
||
|
|
import java.io.IOException;
|
||
|
|
import java.nio.file.Path;
|
||
|
|
import java.nio.file.Paths;
|
||
|
|
import java.util.Arrays;
|
||
|
|
|
||
|
|
@Slf4j
|
||
|
|
public class TokenizerService {
|
||
|
|
|
||
|
|
private HuggingFaceTokenizer tokenizer;
|
||
|
|
private final String modelPath;
|
||
|
|
|
||
|
|
public TokenizerService(String modelPath) {
|
||
|
|
this.modelPath = modelPath;
|
||
|
|
init();
|
||
|
|
}
|
||
|
|
|
||
|
|
private void init() {
|
||
|
|
try {
|
||
|
|
Path tokenizerPath = Paths.get(modelPath, "tokenizer.json");
|
||
|
|
if (tokenizerPath.toFile().exists()) {
|
||
|
|
tokenizer = HuggingFaceTokenizer.newInstance(tokenizerPath);
|
||
|
|
log.info("Tokenizer loaded from: {}", tokenizerPath);
|
||
|
|
} else {
|
||
|
|
log.warn("Tokenizer file not found at: {}, using default tokenizer", tokenizerPath);
|
||
|
|
}
|
||
|
|
} catch (Exception e) {
|
||
|
|
log.error("Failed to load tokenizer", e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
public long[] encode(String text) {
|
||
|
|
if (tokenizer == null) {
|
||
|
|
return new long[0];
|
||
|
|
}
|
||
|
|
Encoding encoding = tokenizer.encode(text);
|
||
|
|
return Arrays.stream(encoding.getIds()).asLongStream().toArray();
|
||
|
|
}
|
||
|
|
|
||
|
|
public String decode(long[] ids) {
|
||
|
|
if (tokenizer == null) {
|
||
|
|
return Arrays.toString(ids);
|
||
|
|
}
|
||
|
|
// 过滤特殊token
|
||
|
|
long[] filteredIds = Arrays.stream(ids)
|
||
|
|
.filter(id -> id > 0 && id < 151936) // GLM token范围
|
||
|
|
.toArray();
|
||
|
|
|
||
|
|
String[] tokens = tokenizer.decode(filteredIds);
|
||
|
|
return String.join("", tokens);
|
||
|
|
}
|
||
|
|
|
||
|
|
public String decodeWithPrompt(long[] ids, String prompt) {
|
||
|
|
String decoded = decode(ids);
|
||
|
|
// 移除prompt部分
|
||
|
|
if (decoded.startsWith(prompt)) {
|
||
|
|
return decoded.substring(prompt.length()).trim();
|
||
|
|
}
|
||
|
|
return decoded;
|
||
|
|
}
|
||
|
|
|
||
|
|
public int getVocabSize() {
|
||
|
|
return tokenizer != null ? tokenizer.getVocabularySize() : 0;
|
||
|
|
}
|
||
|
|
|
||
|
|
public boolean isAvailable() {
|
||
|
|
return tokenizer != null;
|
||
|
|
}
|
||
|
|
}
|