KNN算法
This commit is contained in:
parent
ba25e9e4c5
commit
43a94cf678
|
|
@ -0,0 +1,23 @@
|
||||||
|
package io.lroyia.bean;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分类模型数据
|
||||||
|
*
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023/10/28 13:51
|
||||||
|
**/
|
||||||
|
@Data
|
||||||
|
public class ClassifyModelData<ResultT> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 特征数据
|
||||||
|
*/
|
||||||
|
private double[] features;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分类结果
|
||||||
|
*/
|
||||||
|
private ResultT result;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
package io.lroyia.model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分类模型
|
||||||
|
*
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023/10/28 13:49
|
||||||
|
**/
|
||||||
|
public interface ClassifyModel<ResultT> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分类预测
|
||||||
|
* @param featureData 预测数据
|
||||||
|
* @return 预测结果
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023年10月28日 13:50:45
|
||||||
|
*/
|
||||||
|
ResultT prediction(double[] featureData);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 测试模型
|
||||||
|
*
|
||||||
|
* @param testData 测试数据
|
||||||
|
* @param testResult 测试结果
|
||||||
|
* @return 准确率
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023年10月28日 13:50:00
|
||||||
|
*/
|
||||||
|
double test(double[][] testData, ResultT[] testResult);
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,154 @@
|
||||||
|
package io.lroyia.model.knn;
|
||||||
|
|
||||||
|
import io.lroyia.bean.ClassifyModelData;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* KD树
|
||||||
|
*
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023/10/28 13:24
|
||||||
|
**/
|
||||||
|
@Data
|
||||||
|
public class KDTree<ResultT> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根节点
|
||||||
|
*/
|
||||||
|
private KDTreeNode<ResultT> rootNode;
|
||||||
|
|
||||||
|
public KDTree(double[][] trainData, ResultT[] result) {
|
||||||
|
build(trainData, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建树
|
||||||
|
*
|
||||||
|
* @param trainData 训练数据
|
||||||
|
* @param result 训练结果
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023年10月28日 16:15:37
|
||||||
|
*/
|
||||||
|
private void build(double[][] trainData, ResultT[] result) {
|
||||||
|
// 构建模型数据
|
||||||
|
List<ClassifyModelData<ResultT>> modelDataList = new ArrayList<>(trainData.length);
|
||||||
|
for (int i = 0; i < trainData.length; i++) {
|
||||||
|
double[] each = trainData[i];
|
||||||
|
ClassifyModelData<ResultT> data = new ClassifyModelData<>();
|
||||||
|
data.setFeatures(each);
|
||||||
|
data.setResult(result[i]);
|
||||||
|
modelDataList.add(data);
|
||||||
|
}
|
||||||
|
// 计算维度分散度排行
|
||||||
|
int[] discretenessIndex = getDiscretenessIndex(modelDataList);
|
||||||
|
// 构建节点
|
||||||
|
buildNode(null, modelDataList, 0, discretenessIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建节点
|
||||||
|
*
|
||||||
|
* @param parentNode 父节点
|
||||||
|
* @param data 构建数据
|
||||||
|
* @param featureIndex 特征排序索引
|
||||||
|
* @param featureIndexArray 特征排序清单
|
||||||
|
* @return 构建节点
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023年10月28日 17:23:48
|
||||||
|
*/
|
||||||
|
private KDTreeNode<ResultT> buildNode(KDTreeNode<ResultT> parentNode, List<ClassifyModelData<ResultT>> data,
|
||||||
|
int featureIndex, int[] featureIndexArray) {
|
||||||
|
int fIdx = featureIndexArray[featureIndex];
|
||||||
|
data.sort(Comparator.comparingDouble(each -> each.getFeatures()[fIdx]));
|
||||||
|
int index = data.size() / 2;
|
||||||
|
ClassifyModelData<ResultT> currentData = data.get(index);
|
||||||
|
KDTreeNode<ResultT> left = null;
|
||||||
|
KDTreeNode<ResultT> right = null;
|
||||||
|
KDTreeNode<ResultT> node = new KDTreeNode<>(fIdx, parentNode, currentData, null, null);
|
||||||
|
if (parentNode == null) {
|
||||||
|
rootNode = node;
|
||||||
|
}
|
||||||
|
int nextIndex = featureIndex == featureIndexArray.length - 1 ? 0 : (featureIndex + 1);
|
||||||
|
if (index > 0) {
|
||||||
|
left = buildNode(node, new ArrayList<>(data.subList(0, index)), nextIndex, featureIndexArray);
|
||||||
|
}
|
||||||
|
if (index < data.size() - 1) {
|
||||||
|
right = buildNode(node, new ArrayList<>(data.subList(index + 1, data.size())), nextIndex, featureIndexArray);
|
||||||
|
}
|
||||||
|
node.setLeftNode(left);
|
||||||
|
node.setRightNode(right);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取维度分散度排序索引
|
||||||
|
*
|
||||||
|
* @param modelDataList 模型数据
|
||||||
|
* @return 排序
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023年10月28日 16:49:35
|
||||||
|
*/
|
||||||
|
private int[] getDiscretenessIndex(List<ClassifyModelData<ResultT>> modelDataList) {
|
||||||
|
int resultLength = modelDataList.get(0).getFeatures().length;
|
||||||
|
int[] result = new int[resultLength];
|
||||||
|
double[] maxArray = new double[resultLength];
|
||||||
|
double[] minArray = new double[resultLength];
|
||||||
|
for (int i = 0; i < resultLength; i++) {
|
||||||
|
maxArray[i] = Double.MIN_VALUE;
|
||||||
|
minArray[i] = Double.MAX_VALUE;
|
||||||
|
}
|
||||||
|
for (ClassifyModelData<ResultT> each : modelDataList) {
|
||||||
|
double[] features = each.getFeatures();
|
||||||
|
for (int i = 0; i < resultLength; i++) {
|
||||||
|
double data = features[i];
|
||||||
|
if (data > maxArray[i]) {
|
||||||
|
maxArray[i] = data;
|
||||||
|
}
|
||||||
|
if (data < minArray[i]) {
|
||||||
|
minArray[i] = Double.MAX_VALUE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Set<Integer> contentSet = new HashSet<>();
|
||||||
|
int index = 0;
|
||||||
|
double tempDiff = Double.MIN_VALUE;
|
||||||
|
for (int i = 0; i < resultLength; i++) {
|
||||||
|
for (int j = 0; j < resultLength; j++) {
|
||||||
|
if (contentSet.contains(j)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
double diff = Math.abs(maxArray[i] - minArray[j]);
|
||||||
|
if (diff > tempDiff) {
|
||||||
|
index = j;
|
||||||
|
tempDiff = diff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contentSet.add(index);
|
||||||
|
result[i] = index;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public KDTreeNode<ResultT> getNearestNode(ClassifyModelData<ResultT> data) {
|
||||||
|
KDTreeNode<ResultT> currentNode = rootNode;
|
||||||
|
List<KDTreeNode<ResultT>> nodePath = new ArrayList<>();
|
||||||
|
nodePath.add(currentNode);
|
||||||
|
while (currentNode != null && (currentNode.getLeftNode() != null || currentNode.getRightNode() != null)) {
|
||||||
|
int featureIndex = currentNode.getFeatureIndex();
|
||||||
|
double[] features = data.getFeatures();
|
||||||
|
ClassifyModelData<ResultT> currentData = currentNode.getData();
|
||||||
|
double[] currentFeatures = currentData.getFeatures();
|
||||||
|
if (features[featureIndex] < currentFeatures[featureIndex]) {
|
||||||
|
currentNode = currentNode.getLeftNode();
|
||||||
|
} else {
|
||||||
|
currentNode = currentNode.getRightNode();
|
||||||
|
}
|
||||||
|
if (currentNode != null) {
|
||||||
|
nodePath.add(currentNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
package io.lroyia.model.knn;
|
||||||
|
|
||||||
|
import io.lroyia.bean.ClassifyModelData;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* kd树节点
|
||||||
|
*
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023/10/28 17:29
|
||||||
|
**/
|
||||||
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class KDTreeNode<ResultT> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 特征维度索引
|
||||||
|
*/
|
||||||
|
private int featureIndex;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 父节点
|
||||||
|
*/
|
||||||
|
private KDTreeNode<ResultT> parentNode;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 节点数据
|
||||||
|
*/
|
||||||
|
private ClassifyModelData<ResultT> data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 左节点
|
||||||
|
*/
|
||||||
|
private KDTreeNode<ResultT> leftNode;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 右节点
|
||||||
|
*/
|
||||||
|
private KDTreeNode<ResultT> rightNode;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,131 @@
|
||||||
|
package io.lroyia.model.knn;
|
||||||
|
|
||||||
|
import io.lroyia.model.ClassifyModel;
|
||||||
|
import io.lroyia.util.DataUtil;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023/10/28 12:27
|
||||||
|
**/
|
||||||
|
@Data
|
||||||
|
public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* k值
|
||||||
|
*/
|
||||||
|
private int kValue;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练数据(比较数据)
|
||||||
|
*/
|
||||||
|
private double[][] trainData;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 维度最大值记录(用于归一化)
|
||||||
|
*/
|
||||||
|
private double[] maxArray;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 维度最小值记录(用于归一化)
|
||||||
|
*/
|
||||||
|
private double[] minArray;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练结果
|
||||||
|
*/
|
||||||
|
private ResultT[] trainResult;
|
||||||
|
|
||||||
|
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult) {
|
||||||
|
this.kValue = kValue;
|
||||||
|
this.trainResult = trainResult;
|
||||||
|
double[][] nTrainData = new double[trainData.length][];
|
||||||
|
this.maxArray = new double[trainData.length];
|
||||||
|
for (int i = 0; i < trainData.length; i++) {
|
||||||
|
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(trainData[i]);
|
||||||
|
this.maxArray[i] = toOneResult.getMax();
|
||||||
|
this.minArray[i] = toOneResult.getMin();
|
||||||
|
nTrainData[i] = toOneResult.getResult();
|
||||||
|
|
||||||
|
}
|
||||||
|
this.trainData = nTrainData;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 结果预测
|
||||||
|
*
|
||||||
|
* @param featureData 特征数据
|
||||||
|
* @return 预测结果
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023年10月28日 12:48:00
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public ResultT prediction(double[] featureData) {
|
||||||
|
double[] tempData = new double[featureData.length];
|
||||||
|
for (int i = 0; i < featureData.length; i++) {
|
||||||
|
tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]);
|
||||||
|
}
|
||||||
|
featureData = tempData;
|
||||||
|
List<CalcResult> calcResultList = new ArrayList<>(trainData.length);
|
||||||
|
for (int i = 0; i < trainData.length; i++) {
|
||||||
|
calcResultList.add(new CalcResult(i, DataUtil.calcDistance(trainData[i], featureData)));
|
||||||
|
}
|
||||||
|
calcResultList.sort(Comparator.comparing(CalcResult::getDistance));
|
||||||
|
calcResultList.subList(0, kValue);
|
||||||
|
Map<ResultT, Integer> countMap = new HashMap<>(kValue);
|
||||||
|
for (CalcResult each : calcResultList) {
|
||||||
|
ResultT key = trainResult[each.index];
|
||||||
|
Integer count = countMap.computeIfAbsent(key, k -> 0);
|
||||||
|
countMap.put(key, count + 1);
|
||||||
|
}
|
||||||
|
ResultT result = null;
|
||||||
|
int count = 0;
|
||||||
|
for (Map.Entry<ResultT, Integer> each : countMap.entrySet()) {
|
||||||
|
if (each.getValue() > count) {
|
||||||
|
count = each.getValue();
|
||||||
|
result = each.getKey();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 模型测试
|
||||||
|
*
|
||||||
|
* @param testData 测试数据
|
||||||
|
* @param testResult 测试结果
|
||||||
|
* @return 评估准确率
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023年10月28日 12:57:54
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public double test(double[][] testData, ResultT[] testResult) {
|
||||||
|
double hitCount = 0;
|
||||||
|
for (int i = 0; i < testData.length; i++) {
|
||||||
|
double[] eachTestData = testData[i];
|
||||||
|
ResultT prediction = prediction(eachTestData);
|
||||||
|
if (prediction.equals(testResult[i])) {
|
||||||
|
hitCount++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hitCount / testResult.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
private static class CalcResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 位置索引
|
||||||
|
*/
|
||||||
|
private int index;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 距离
|
||||||
|
*/
|
||||||
|
private double distance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,6 +2,8 @@ package io.lroyia.util;
|
||||||
|
|
||||||
import io.lroyia.bean.EntCalInfo;
|
import io.lroyia.bean.EntCalInfo;
|
||||||
import io.lroyia.bean.EntInfo;
|
import io.lroyia.bean.EntInfo;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
import org.apache.commons.csv.CSVFormat;
|
import org.apache.commons.csv.CSVFormat;
|
||||||
import org.apache.commons.csv.CSVParser;
|
import org.apache.commons.csv.CSVParser;
|
||||||
import org.apache.commons.csv.CSVRecord;
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
@ -84,70 +86,66 @@ public abstract class DataUtil {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 归一
|
|
||||||
*
|
|
||||||
* @param list list
|
|
||||||
* @return 归一结果
|
|
||||||
* @author lroyia
|
|
||||||
* @since 2023年10月24日 10:36:08
|
|
||||||
*/
|
|
||||||
public static List<EntCalInfo> toOne(List<EntCalInfo> list) {
|
|
||||||
List<Double> estDateList = new ArrayList<>(list.size());
|
|
||||||
List<Double> entTypeList = new ArrayList<>(list.size());
|
|
||||||
List<Double> regCapList = new ArrayList<>(list.size());
|
|
||||||
List<Double> industryCoList = new ArrayList<>(list.size());
|
|
||||||
List<Double> regStateList = new ArrayList<>(list.size());
|
|
||||||
for (EntCalInfo each : list) {
|
|
||||||
estDateList.add(each.getEstDate());
|
|
||||||
entTypeList.add(each.getEntType());
|
|
||||||
regCapList.add(each.getRegCap());
|
|
||||||
industryCoList.add(each.getIndustryCo());
|
|
||||||
regStateList.add(each.getRegState());
|
|
||||||
}
|
|
||||||
estDateList = columnToOne(estDateList);
|
|
||||||
entTypeList = columnToOne(entTypeList);
|
|
||||||
regCapList = columnToOne(regCapList);
|
|
||||||
industryCoList = columnToOne(industryCoList);
|
|
||||||
regStateList = columnToOne(regStateList);
|
|
||||||
for (int i = 0; i < list.size(); i++) {
|
|
||||||
EntCalInfo each = list.get(i);
|
|
||||||
each.setEstDate(estDateList.get(i))
|
|
||||||
.setEntType(entTypeList.get(i))
|
|
||||||
.setRegCap(regCapList.get(i))
|
|
||||||
.setIndustryCo(industryCoList.get(i))
|
|
||||||
.setRegState(regStateList.get(i));
|
|
||||||
}
|
|
||||||
return list;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 列归一
|
* 列归一
|
||||||
*
|
*
|
||||||
* @param list 列数据
|
* @param arr 列数据
|
||||||
* @return 归一
|
* @return 归一
|
||||||
* @author lroyia
|
* @author lroyia
|
||||||
* @since 2023年10月24日 10:27:05
|
* @since 2023年10月24日 10:27:05
|
||||||
*/
|
*/
|
||||||
private static List<Double> columnToOne(List<Double> list) {
|
public static ToOneResult columnToOne(double[] arr) {
|
||||||
double max = Double.MIN_VALUE;
|
double max = Double.MIN_VALUE;
|
||||||
double min = Double.MAX_VALUE;
|
double min = Double.MAX_VALUE;
|
||||||
for (Double each : list) {
|
for (Double each : arr) {
|
||||||
max = Math.max(each, max);
|
max = Math.max(each, max);
|
||||||
min = Math.min(each, min);
|
min = Math.min(each, min);
|
||||||
}
|
}
|
||||||
double divisor = max - min;
|
double divisor = max - min;
|
||||||
List<Double> result = new ArrayList<>(list.size());
|
double[] result = new double[arr.length];
|
||||||
for (Double each : list) {
|
for (int i = 0; i < arr.length; i++) {
|
||||||
result.add((each - min) / divisor);
|
double each = arr[i];
|
||||||
|
result[i] = ((each - min) / divisor);
|
||||||
}
|
}
|
||||||
return result;
|
return new ToOneResult(max, min, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String[] args) {
|
/**
|
||||||
List<EntCalInfo> one = toOne(getAllCalInfo());
|
* 计算距离
|
||||||
for (EntCalInfo each : one) {
|
*
|
||||||
System.out.println(each);
|
* @param point1 坐标点1
|
||||||
|
* @param point2 坐标点2
|
||||||
|
* @return 计算结果
|
||||||
|
* @author lroyia
|
||||||
|
* @since 2023年10月28日 18:30:04
|
||||||
|
*/
|
||||||
|
public static double calcDistance(double[] point1, double[] point2) {
|
||||||
|
double stdDev = 0;
|
||||||
|
for (int i1 = 0; i1 < point1.length; i1++) {
|
||||||
|
stdDev += Math.pow(point1[i1] - point2[i1], 2);
|
||||||
}
|
}
|
||||||
|
return Math.sqrt(stdDev);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 归一结果
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
public static class ToOneResult {
|
||||||
|
/**
|
||||||
|
* 最大值
|
||||||
|
*/
|
||||||
|
private double max;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 最小值
|
||||||
|
*/
|
||||||
|
private double min;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 归一结果
|
||||||
|
*/
|
||||||
|
private double[] result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue