package com.example.glmocr.tokenizer; import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import lombok.extern.slf4j.Slf4j; import java.io.IOException; import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; @Slf4j public class TokenizerService { private HuggingFaceTokenizer tokenizer; private final String modelPath; public TokenizerService(String modelPath) { this.modelPath = modelPath; init(); } private void init() { try { Path tokenizerPath = Paths.get(modelPath, "tokenizer.json"); if (tokenizerPath.toFile().exists()) { tokenizer = HuggingFaceTokenizer.newInstance(tokenizerPath); log.info("Tokenizer loaded from: {}", tokenizerPath); } else { log.warn("Tokenizer file not found at: {}, using default tokenizer", tokenizerPath); } } catch (Exception e) { log.error("Failed to load tokenizer", e); } } public long[] encode(String text) { if (tokenizer == null) { return new long[0]; } Encoding encoding = tokenizer.encode(text); return Arrays.stream(encoding.getIds()).toArray(); } public String decode(long[] ids) { if (tokenizer == null) { return Arrays.toString(ids); } // 过滤特殊token long[] filteredIds = Arrays.stream(ids) .filter(id -> id > 0 && id < 151936) // GLM token范围 .toArray(); return tokenizer.decode(filteredIds); } public String decodeWithPrompt(long[] ids, String prompt) { String decoded = decode(ids); // 移除prompt部分 if (decoded.startsWith(prompt)) { return decoded.substring(prompt.length()).trim(); } return decoded; } public int getVocabSize() { // return tokenizer != null ? tokenizer.getVocabularySize() : 0; return tokenizer != null ? tokenizer.getMaxLength() : 0; } public boolean isAvailable() { return tokenizer != null; } }