From 0c5afedc415d0244a8a85fbfd7f79ed6750593a0 Mon Sep 17 00:00:00 2001 From: lroyia <814876716@qq.com> Date: Sun, 29 Oct 2023 15:37:38 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=85kd=E6=A0=91=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../io/lroyia/bean/ClassifyModelData.java | 2 + src/main/java/io/lroyia/model/knn/KDTree.java | 52 ++++++++++++++++++- .../java/io/lroyia/model/knn/KNNModel.java | 25 ++++++++- 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/src/main/java/io/lroyia/bean/ClassifyModelData.java b/src/main/java/io/lroyia/bean/ClassifyModelData.java index 25a8544..1137a51 100644 --- a/src/main/java/io/lroyia/bean/ClassifyModelData.java +++ b/src/main/java/io/lroyia/bean/ClassifyModelData.java @@ -1,6 +1,7 @@ package io.lroyia.bean; import lombok.Data; +import lombok.experimental.Accessors; /** * 分类模型数据 @@ -9,6 +10,7 @@ import lombok.Data; * @since 2023/10/28 13:51 **/ @Data +@Accessors(chain = true) public class ClassifyModelData { /** diff --git a/src/main/java/io/lroyia/model/knn/KDTree.java b/src/main/java/io/lroyia/model/knn/KDTree.java index 2b35a6a..6ffd37e 100644 --- a/src/main/java/io/lroyia/model/knn/KDTree.java +++ b/src/main/java/io/lroyia/model/knn/KDTree.java @@ -1,6 +1,7 @@ package io.lroyia.model.knn; import io.lroyia.bean.ClassifyModelData; +import io.lroyia.util.DataUtil; import lombok.Data; import java.util.*; @@ -132,7 +133,16 @@ public class KDTree { return result; } - public KDTreeNode getNearestNode(ClassifyModelData data) { + /** + * 最近邻点查找 + * + * @param data 搜索数据 + * @return 最近邻点 + * @author lroyia + * @since 2023年10月29日 15:33:25 + */ + public KDTreeNode searchNearestNode(ClassifyModelData data) { + // 创建搜索路径 KDTreeNode currentNode = rootNode; List> nodePath = new ArrayList<>(); nodePath.add(currentNode); @@ -150,5 +160,45 @@ public class KDTree { nodePath.add(currentNode); } } + // 回溯路径找出最优点 + double[] dataFeatures = data.getFeatures(); + double distance = Double.MAX_VALUE; + Set> searchedSet = new HashSet<>(); + KDTreeNode lastNode = nodePath.get(nodePath.size() - 1); + while (!nodePath.isEmpty()) { + KDTreeNode tempNode = nodePath.remove(nodePath.size() - 1); + searchedSet.add(tempNode); + double d = DataUtil.calcDistance(tempNode.getData().getFeatures(), dataFeatures); + if (d > distance) { + continue; + } + distance = d; + lastNode = tempNode; + if (tempNode.getParentNode() != null) { + int featureIndex = tempNode.getParentNode().getFeatureIndex(); + double pf = tempNode.getParentNode().getData().getFeatures()[featureIndex]; + // 判断是否跨界 + if (dataFeatures[featureIndex] > pf ? (dataFeatures[featureIndex] - distance < pf) : + (dataFeatures[featureIndex] + distance > pf)) { + KDTreeNode addNode; + if (tempNode.getParentNode().getLeftNode().equals(tempNode)) { + addNode = tempNode.getParentNode().getRightNode(); + } else { + addNode = tempNode.getParentNode().getLeftNode(); + } + if (searchedSet.contains(addNode)) { + continue; + } + if (nodePath.isEmpty()) { + nodePath.add(addNode); + continue; + } + tempNode = nodePath.get(nodePath.size() - 1); + nodePath.set(nodePath.size() - 1, addNode); + nodePath.add(tempNode); + } + } + } + return lastNode; } } diff --git a/src/main/java/io/lroyia/model/knn/KNNModel.java b/src/main/java/io/lroyia/model/knn/KNNModel.java index 9b043d7..073313a 100644 --- a/src/main/java/io/lroyia/model/knn/KNNModel.java +++ b/src/main/java/io/lroyia/model/knn/KNNModel.java @@ -1,5 +1,6 @@ package io.lroyia.model.knn; +import io.lroyia.bean.ClassifyModelData; import io.lroyia.model.ClassifyModel; import io.lroyia.util.DataUtil; import lombok.AllArgsConstructor; @@ -39,7 +40,17 @@ public class KNNModel implements ClassifyModel { */ private ResultT[] trainResult; - public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult) { + /** + * 使用kd树 + */ + private boolean useKdTree; + + /** + * kd树 + */ + private KDTree kdTree; + + public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult, boolean useKdTree) { this.kValue = kValue; this.trainResult = trainResult; double[][] nTrainData = new double[trainData.length][]; @@ -52,6 +63,14 @@ public class KNNModel implements ClassifyModel { } this.trainData = nTrainData; + setUseKdTree(useKdTree); + } + + public void setUseKdTree(boolean useKdTree) { + this.useKdTree = useKdTree; + if (useKdTree) { + kdTree = new KDTree<>(trainData, trainResult); + } } /** @@ -64,6 +83,10 @@ public class KNNModel implements ClassifyModel { */ @Override public ResultT prediction(double[] featureData) { + if (useKdTree) { + KDTreeNode resultTKDTreeNode = kdTree.searchNearestNode(new ClassifyModelData().setFeatures(featureData)); + return resultTKDTreeNode.getData().getResult(); + } double[] tempData = new double[featureData.length]; for (int i = 0; i < featureData.length; i++) { tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]);