补充kd树优化

This commit is contained in:
黎润豪 2023-10-29 15:37:38 +08:00
parent 43a94cf678
commit 0c5afedc41
3 changed files with 77 additions and 2 deletions

View File

@ -1,6 +1,7 @@
package io.lroyia.bean;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* 分类模型数据
@ -9,6 +10,7 @@ import lombok.Data;
* @since 2023/10/28 13:51
**/
@Data
@Accessors(chain = true)
public class ClassifyModelData<ResultT> {
/**

View File

@ -1,6 +1,7 @@
package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData;
import io.lroyia.util.DataUtil;
import lombok.Data;
import java.util.*;
@ -132,7 +133,16 @@ public class KDTree<ResultT> {
return result;
}
public KDTreeNode<ResultT> getNearestNode(ClassifyModelData<ResultT> data) {
/**
* 最近邻点查找
*
* @param data 搜索数据
* @return 最近邻点
* @author lroyia
* @since 2023年10月29日 15:33:25
*/
public KDTreeNode<ResultT> searchNearestNode(ClassifyModelData<ResultT> data) {
// 创建搜索路径
KDTreeNode<ResultT> currentNode = rootNode;
List<KDTreeNode<ResultT>> nodePath = new ArrayList<>();
nodePath.add(currentNode);
@ -150,5 +160,45 @@ public class KDTree<ResultT> {
nodePath.add(currentNode);
}
}
// 回溯路径找出最优点
double[] dataFeatures = data.getFeatures();
double distance = Double.MAX_VALUE;
Set<KDTreeNode<ResultT>> searchedSet = new HashSet<>();
KDTreeNode<ResultT> lastNode = nodePath.get(nodePath.size() - 1);
while (!nodePath.isEmpty()) {
KDTreeNode<ResultT> tempNode = nodePath.remove(nodePath.size() - 1);
searchedSet.add(tempNode);
double d = DataUtil.calcDistance(tempNode.getData().getFeatures(), dataFeatures);
if (d > distance) {
continue;
}
distance = d;
lastNode = tempNode;
if (tempNode.getParentNode() != null) {
int featureIndex = tempNode.getParentNode().getFeatureIndex();
double pf = tempNode.getParentNode().getData().getFeatures()[featureIndex];
// 判断是否跨界
if (dataFeatures[featureIndex] > pf ? (dataFeatures[featureIndex] - distance < pf) :
(dataFeatures[featureIndex] + distance > pf)) {
KDTreeNode<ResultT> addNode;
if (tempNode.getParentNode().getLeftNode().equals(tempNode)) {
addNode = tempNode.getParentNode().getRightNode();
} else {
addNode = tempNode.getParentNode().getLeftNode();
}
if (searchedSet.contains(addNode)) {
continue;
}
if (nodePath.isEmpty()) {
nodePath.add(addNode);
continue;
}
tempNode = nodePath.get(nodePath.size() - 1);
nodePath.set(nodePath.size() - 1, addNode);
nodePath.add(tempNode);
}
}
}
return lastNode;
}
}

View File

@ -1,5 +1,6 @@
package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData;
import io.lroyia.model.ClassifyModel;
import io.lroyia.util.DataUtil;
import lombok.AllArgsConstructor;
@ -39,7 +40,17 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
*/
private ResultT[] trainResult;
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult) {
/**
* 使用kd树
*/
private boolean useKdTree;
/**
* kd树
*/
private KDTree<ResultT> kdTree;
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult, boolean useKdTree) {
this.kValue = kValue;
this.trainResult = trainResult;
double[][] nTrainData = new double[trainData.length][];
@ -52,6 +63,14 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
}
this.trainData = nTrainData;
setUseKdTree(useKdTree);
}
public void setUseKdTree(boolean useKdTree) {
this.useKdTree = useKdTree;
if (useKdTree) {
kdTree = new KDTree<>(trainData, trainResult);
}
}
/**
@ -64,6 +83,10 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
*/
@Override
public ResultT prediction(double[] featureData) {
if (useKdTree) {
KDTreeNode<ResultT> resultTKDTreeNode = kdTree.searchNearestNode(new ClassifyModelData<ResultT>().setFeatures(featureData));
return resultTKDTreeNode.getData().getResult();
}
double[] tempData = new double[featureData.length];
for (int i = 0; i < featureData.length; i++) {
tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]);