补充kd树优化

This commit is contained in:
黎润豪 2023-10-29 15:37:38 +08:00
parent 43a94cf678
commit 0c5afedc41
3 changed files with 77 additions and 2 deletions

View File

@ -1,6 +1,7 @@
package io.lroyia.bean; package io.lroyia.bean;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors;
/** /**
* 分类模型数据 * 分类模型数据
@ -9,6 +10,7 @@ import lombok.Data;
* @since 2023/10/28 13:51 * @since 2023/10/28 13:51
**/ **/
@Data @Data
@Accessors(chain = true)
public class ClassifyModelData<ResultT> { public class ClassifyModelData<ResultT> {
/** /**

View File

@ -1,6 +1,7 @@
package io.lroyia.model.knn; package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData; import io.lroyia.bean.ClassifyModelData;
import io.lroyia.util.DataUtil;
import lombok.Data; import lombok.Data;
import java.util.*; import java.util.*;
@ -132,7 +133,16 @@ public class KDTree<ResultT> {
return result; return result;
} }
public KDTreeNode<ResultT> getNearestNode(ClassifyModelData<ResultT> data) { /**
* 最近邻点查找
*
* @param data 搜索数据
* @return 最近邻点
* @author lroyia
* @since 2023年10月29日 15:33:25
*/
public KDTreeNode<ResultT> searchNearestNode(ClassifyModelData<ResultT> data) {
// 创建搜索路径
KDTreeNode<ResultT> currentNode = rootNode; KDTreeNode<ResultT> currentNode = rootNode;
List<KDTreeNode<ResultT>> nodePath = new ArrayList<>(); List<KDTreeNode<ResultT>> nodePath = new ArrayList<>();
nodePath.add(currentNode); nodePath.add(currentNode);
@ -150,5 +160,45 @@ public class KDTree<ResultT> {
nodePath.add(currentNode); nodePath.add(currentNode);
} }
} }
// 回溯路径找出最优点
double[] dataFeatures = data.getFeatures();
double distance = Double.MAX_VALUE;
Set<KDTreeNode<ResultT>> searchedSet = new HashSet<>();
KDTreeNode<ResultT> lastNode = nodePath.get(nodePath.size() - 1);
while (!nodePath.isEmpty()) {
KDTreeNode<ResultT> tempNode = nodePath.remove(nodePath.size() - 1);
searchedSet.add(tempNode);
double d = DataUtil.calcDistance(tempNode.getData().getFeatures(), dataFeatures);
if (d > distance) {
continue;
}
distance = d;
lastNode = tempNode;
if (tempNode.getParentNode() != null) {
int featureIndex = tempNode.getParentNode().getFeatureIndex();
double pf = tempNode.getParentNode().getData().getFeatures()[featureIndex];
// 判断是否跨界
if (dataFeatures[featureIndex] > pf ? (dataFeatures[featureIndex] - distance < pf) :
(dataFeatures[featureIndex] + distance > pf)) {
KDTreeNode<ResultT> addNode;
if (tempNode.getParentNode().getLeftNode().equals(tempNode)) {
addNode = tempNode.getParentNode().getRightNode();
} else {
addNode = tempNode.getParentNode().getLeftNode();
}
if (searchedSet.contains(addNode)) {
continue;
}
if (nodePath.isEmpty()) {
nodePath.add(addNode);
continue;
}
tempNode = nodePath.get(nodePath.size() - 1);
nodePath.set(nodePath.size() - 1, addNode);
nodePath.add(tempNode);
}
}
}
return lastNode;
} }
} }

View File

@ -1,5 +1,6 @@
package io.lroyia.model.knn; package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData;
import io.lroyia.model.ClassifyModel; import io.lroyia.model.ClassifyModel;
import io.lroyia.util.DataUtil; import io.lroyia.util.DataUtil;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@ -39,7 +40,17 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
*/ */
private ResultT[] trainResult; private ResultT[] trainResult;
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult) { /**
* 使用kd树
*/
private boolean useKdTree;
/**
* kd树
*/
private KDTree<ResultT> kdTree;
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult, boolean useKdTree) {
this.kValue = kValue; this.kValue = kValue;
this.trainResult = trainResult; this.trainResult = trainResult;
double[][] nTrainData = new double[trainData.length][]; double[][] nTrainData = new double[trainData.length][];
@ -52,6 +63,14 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
} }
this.trainData = nTrainData; this.trainData = nTrainData;
setUseKdTree(useKdTree);
}
public void setUseKdTree(boolean useKdTree) {
this.useKdTree = useKdTree;
if (useKdTree) {
kdTree = new KDTree<>(trainData, trainResult);
}
} }
/** /**
@ -64,6 +83,10 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
*/ */
@Override @Override
public ResultT prediction(double[] featureData) { public ResultT prediction(double[] featureData) {
if (useKdTree) {
KDTreeNode<ResultT> resultTKDTreeNode = kdTree.searchNearestNode(new ClassifyModelData<ResultT>().setFeatures(featureData));
return resultTKDTreeNode.getData().getResult();
}
double[] tempData = new double[featureData.length]; double[] tempData = new double[featureData.length];
for (int i = 0; i < featureData.length; i++) { for (int i = 0; i < featureData.length; i++) {
tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]); tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]);