补充kd树优化
This commit is contained in:
parent
43a94cf678
commit
0c5afedc41
|
|
@ -1,6 +1,7 @@
|
|||
package io.lroyia.bean;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
/**
|
||||
* 分类模型数据
|
||||
|
|
@ -9,6 +10,7 @@ import lombok.Data;
|
|||
* @since 2023/10/28 13:51
|
||||
**/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class ClassifyModelData<ResultT> {
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package io.lroyia.model.knn;
|
||||
|
||||
import io.lroyia.bean.ClassifyModelData;
|
||||
import io.lroyia.util.DataUtil;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.*;
|
||||
|
|
@ -132,7 +133,16 @@ public class KDTree<ResultT> {
|
|||
return result;
|
||||
}
|
||||
|
||||
public KDTreeNode<ResultT> getNearestNode(ClassifyModelData<ResultT> data) {
|
||||
/**
|
||||
* 最近邻点查找
|
||||
*
|
||||
* @param data 搜索数据
|
||||
* @return 最近邻点
|
||||
* @author lroyia
|
||||
* @since 2023年10月29日 15:33:25
|
||||
*/
|
||||
public KDTreeNode<ResultT> searchNearestNode(ClassifyModelData<ResultT> data) {
|
||||
// 创建搜索路径
|
||||
KDTreeNode<ResultT> currentNode = rootNode;
|
||||
List<KDTreeNode<ResultT>> nodePath = new ArrayList<>();
|
||||
nodePath.add(currentNode);
|
||||
|
|
@ -150,5 +160,45 @@ public class KDTree<ResultT> {
|
|||
nodePath.add(currentNode);
|
||||
}
|
||||
}
|
||||
// 回溯路径找出最优点
|
||||
double[] dataFeatures = data.getFeatures();
|
||||
double distance = Double.MAX_VALUE;
|
||||
Set<KDTreeNode<ResultT>> searchedSet = new HashSet<>();
|
||||
KDTreeNode<ResultT> lastNode = nodePath.get(nodePath.size() - 1);
|
||||
while (!nodePath.isEmpty()) {
|
||||
KDTreeNode<ResultT> tempNode = nodePath.remove(nodePath.size() - 1);
|
||||
searchedSet.add(tempNode);
|
||||
double d = DataUtil.calcDistance(tempNode.getData().getFeatures(), dataFeatures);
|
||||
if (d > distance) {
|
||||
continue;
|
||||
}
|
||||
distance = d;
|
||||
lastNode = tempNode;
|
||||
if (tempNode.getParentNode() != null) {
|
||||
int featureIndex = tempNode.getParentNode().getFeatureIndex();
|
||||
double pf = tempNode.getParentNode().getData().getFeatures()[featureIndex];
|
||||
// 判断是否跨界
|
||||
if (dataFeatures[featureIndex] > pf ? (dataFeatures[featureIndex] - distance < pf) :
|
||||
(dataFeatures[featureIndex] + distance > pf)) {
|
||||
KDTreeNode<ResultT> addNode;
|
||||
if (tempNode.getParentNode().getLeftNode().equals(tempNode)) {
|
||||
addNode = tempNode.getParentNode().getRightNode();
|
||||
} else {
|
||||
addNode = tempNode.getParentNode().getLeftNode();
|
||||
}
|
||||
if (searchedSet.contains(addNode)) {
|
||||
continue;
|
||||
}
|
||||
if (nodePath.isEmpty()) {
|
||||
nodePath.add(addNode);
|
||||
continue;
|
||||
}
|
||||
tempNode = nodePath.get(nodePath.size() - 1);
|
||||
nodePath.set(nodePath.size() - 1, addNode);
|
||||
nodePath.add(tempNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
return lastNode;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
package io.lroyia.model.knn;
|
||||
|
||||
import io.lroyia.bean.ClassifyModelData;
|
||||
import io.lroyia.model.ClassifyModel;
|
||||
import io.lroyia.util.DataUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
|
|
@ -39,7 +40,17 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
|
|||
*/
|
||||
private ResultT[] trainResult;
|
||||
|
||||
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult) {
|
||||
/**
|
||||
* 使用kd树
|
||||
*/
|
||||
private boolean useKdTree;
|
||||
|
||||
/**
|
||||
* kd树
|
||||
*/
|
||||
private KDTree<ResultT> kdTree;
|
||||
|
||||
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult, boolean useKdTree) {
|
||||
this.kValue = kValue;
|
||||
this.trainResult = trainResult;
|
||||
double[][] nTrainData = new double[trainData.length][];
|
||||
|
|
@ -52,6 +63,14 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
|
|||
|
||||
}
|
||||
this.trainData = nTrainData;
|
||||
setUseKdTree(useKdTree);
|
||||
}
|
||||
|
||||
public void setUseKdTree(boolean useKdTree) {
|
||||
this.useKdTree = useKdTree;
|
||||
if (useKdTree) {
|
||||
kdTree = new KDTree<>(trainData, trainResult);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -64,6 +83,10 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
|
|||
*/
|
||||
@Override
|
||||
public ResultT prediction(double[] featureData) {
|
||||
if (useKdTree) {
|
||||
KDTreeNode<ResultT> resultTKDTreeNode = kdTree.searchNearestNode(new ClassifyModelData<ResultT>().setFeatures(featureData));
|
||||
return resultTKDTreeNode.getData().getResult();
|
||||
}
|
||||
double[] tempData = new double[featureData.length];
|
||||
for (int i = 0; i < featureData.length; i++) {
|
||||
tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]);
|
||||
|
|
|
|||
Loading…
Reference in New Issue