KNN算法

This commit is contained in:
黎润豪 2023-10-28 20:33:54 +08:00
parent ba25e9e4c5
commit 43a94cf678
6 changed files with 425 additions and 48 deletions

View File

@ -0,0 +1,23 @@
package io.lroyia.bean;
import lombok.Data;
/**
* 分类模型数据
*
* @author lroyia
* @since 2023/10/28 13:51
**/
@Data
public class ClassifyModelData<ResultT> {
/**
* 特征数据
*/
private double[] features;
/**
* 分类结果
*/
private ResultT result;
}

View File

@ -0,0 +1,30 @@
package io.lroyia.model;
/**
* 分类模型
*
* @author lroyia
* @since 2023/10/28 13:49
**/
public interface ClassifyModel<ResultT> {
/**
* 分类预测
* @param featureData 预测数据
* @return 预测结果
* @author lroyia
* @since 2023年10月28日 13:50:45
*/
ResultT prediction(double[] featureData);
/**
* 测试模型
*
* @param testData 测试数据
* @param testResult 测试结果
* @return 准确率
* @author lroyia
* @since 2023年10月28日 13:50:00
*/
double test(double[][] testData, ResultT[] testResult);
}

View File

@ -0,0 +1,154 @@
package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData;
import lombok.Data;
import java.util.*;
/**
* KD树
*
* @author lroyia
* @since 2023/10/28 13:24
**/
@Data
public class KDTree<ResultT> {
/**
* 根节点
*/
private KDTreeNode<ResultT> rootNode;
public KDTree(double[][] trainData, ResultT[] result) {
build(trainData, result);
}
/**
* 构建树
*
* @param trainData 训练数据
* @param result 训练结果
* @author lroyia
* @since 2023年10月28日 16:15:37
*/
private void build(double[][] trainData, ResultT[] result) {
// 构建模型数据
List<ClassifyModelData<ResultT>> modelDataList = new ArrayList<>(trainData.length);
for (int i = 0; i < trainData.length; i++) {
double[] each = trainData[i];
ClassifyModelData<ResultT> data = new ClassifyModelData<>();
data.setFeatures(each);
data.setResult(result[i]);
modelDataList.add(data);
}
// 计算维度分散度排行
int[] discretenessIndex = getDiscretenessIndex(modelDataList);
// 构建节点
buildNode(null, modelDataList, 0, discretenessIndex);
}
/**
* 构建节点
*
* @param parentNode 父节点
* @param data 构建数据
* @param featureIndex 特征排序索引
* @param featureIndexArray 特征排序清单
* @return 构建节点
* @author lroyia
* @since 2023年10月28日 17:23:48
*/
private KDTreeNode<ResultT> buildNode(KDTreeNode<ResultT> parentNode, List<ClassifyModelData<ResultT>> data,
int featureIndex, int[] featureIndexArray) {
int fIdx = featureIndexArray[featureIndex];
data.sort(Comparator.comparingDouble(each -> each.getFeatures()[fIdx]));
int index = data.size() / 2;
ClassifyModelData<ResultT> currentData = data.get(index);
KDTreeNode<ResultT> left = null;
KDTreeNode<ResultT> right = null;
KDTreeNode<ResultT> node = new KDTreeNode<>(fIdx, parentNode, currentData, null, null);
if (parentNode == null) {
rootNode = node;
}
int nextIndex = featureIndex == featureIndexArray.length - 1 ? 0 : (featureIndex + 1);
if (index > 0) {
left = buildNode(node, new ArrayList<>(data.subList(0, index)), nextIndex, featureIndexArray);
}
if (index < data.size() - 1) {
right = buildNode(node, new ArrayList<>(data.subList(index + 1, data.size())), nextIndex, featureIndexArray);
}
node.setLeftNode(left);
node.setRightNode(right);
return node;
}
/**
* 获取维度分散度排序索引
*
* @param modelDataList 模型数据
* @return 排序
* @author lroyia
* @since 2023年10月28日 16:49:35
*/
private int[] getDiscretenessIndex(List<ClassifyModelData<ResultT>> modelDataList) {
int resultLength = modelDataList.get(0).getFeatures().length;
int[] result = new int[resultLength];
double[] maxArray = new double[resultLength];
double[] minArray = new double[resultLength];
for (int i = 0; i < resultLength; i++) {
maxArray[i] = Double.MIN_VALUE;
minArray[i] = Double.MAX_VALUE;
}
for (ClassifyModelData<ResultT> each : modelDataList) {
double[] features = each.getFeatures();
for (int i = 0; i < resultLength; i++) {
double data = features[i];
if (data > maxArray[i]) {
maxArray[i] = data;
}
if (data < minArray[i]) {
minArray[i] = Double.MAX_VALUE;
}
}
}
Set<Integer> contentSet = new HashSet<>();
int index = 0;
double tempDiff = Double.MIN_VALUE;
for (int i = 0; i < resultLength; i++) {
for (int j = 0; j < resultLength; j++) {
if (contentSet.contains(j)) {
continue;
}
double diff = Math.abs(maxArray[i] - minArray[j]);
if (diff > tempDiff) {
index = j;
tempDiff = diff;
}
}
contentSet.add(index);
result[i] = index;
}
return result;
}
public KDTreeNode<ResultT> getNearestNode(ClassifyModelData<ResultT> data) {
KDTreeNode<ResultT> currentNode = rootNode;
List<KDTreeNode<ResultT>> nodePath = new ArrayList<>();
nodePath.add(currentNode);
while (currentNode != null && (currentNode.getLeftNode() != null || currentNode.getRightNode() != null)) {
int featureIndex = currentNode.getFeatureIndex();
double[] features = data.getFeatures();
ClassifyModelData<ResultT> currentData = currentNode.getData();
double[] currentFeatures = currentData.getFeatures();
if (features[featureIndex] < currentFeatures[featureIndex]) {
currentNode = currentNode.getLeftNode();
} else {
currentNode = currentNode.getRightNode();
}
if (currentNode != null) {
nodePath.add(currentNode);
}
}
}
}

View File

@ -0,0 +1,41 @@
package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData;
import lombok.AllArgsConstructor;
import lombok.Data;
/**
* kd树节点
*
* @author lroyia
* @since 2023/10/28 17:29
**/
@Data
@AllArgsConstructor
public class KDTreeNode<ResultT> {
/**
* 特征维度索引
*/
private int featureIndex;
/**
* 父节点
*/
private KDTreeNode<ResultT> parentNode;
/**
* 节点数据
*/
private ClassifyModelData<ResultT> data;
/**
* 左节点
*/
private KDTreeNode<ResultT> leftNode;
/**
* 右节点
*/
private KDTreeNode<ResultT> rightNode;
}

View File

@ -0,0 +1,131 @@
package io.lroyia.model.knn;
import io.lroyia.model.ClassifyModel;
import io.lroyia.util.DataUtil;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.util.*;
/**
* @author lroyia
* @since 2023/10/28 12:27
**/
@Data
public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
/**
* k值
*/
private int kValue;
/**
* 训练数据比较数据
*/
private double[][] trainData;
/**
* 维度最大值记录用于归一化
*/
private double[] maxArray;
/**
* 维度最小值记录用于归一化
*/
private double[] minArray;
/**
* 训练结果
*/
private ResultT[] trainResult;
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult) {
this.kValue = kValue;
this.trainResult = trainResult;
double[][] nTrainData = new double[trainData.length][];
this.maxArray = new double[trainData.length];
for (int i = 0; i < trainData.length; i++) {
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(trainData[i]);
this.maxArray[i] = toOneResult.getMax();
this.minArray[i] = toOneResult.getMin();
nTrainData[i] = toOneResult.getResult();
}
this.trainData = nTrainData;
}
/**
* 结果预测
*
* @param featureData 特征数据
* @return 预测结果
* @author lroyia
* @since 2023年10月28日 12:48:00
*/
@Override
public ResultT prediction(double[] featureData) {
double[] tempData = new double[featureData.length];
for (int i = 0; i < featureData.length; i++) {
tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]);
}
featureData = tempData;
List<CalcResult> calcResultList = new ArrayList<>(trainData.length);
for (int i = 0; i < trainData.length; i++) {
calcResultList.add(new CalcResult(i, DataUtil.calcDistance(trainData[i], featureData)));
}
calcResultList.sort(Comparator.comparing(CalcResult::getDistance));
calcResultList.subList(0, kValue);
Map<ResultT, Integer> countMap = new HashMap<>(kValue);
for (CalcResult each : calcResultList) {
ResultT key = trainResult[each.index];
Integer count = countMap.computeIfAbsent(key, k -> 0);
countMap.put(key, count + 1);
}
ResultT result = null;
int count = 0;
for (Map.Entry<ResultT, Integer> each : countMap.entrySet()) {
if (each.getValue() > count) {
count = each.getValue();
result = each.getKey();
}
}
return result;
}
/**
* 模型测试
*
* @param testData 测试数据
* @param testResult 测试结果
* @return 评估准确率
* @author lroyia
* @since 2023年10月28日 12:57:54
*/
@Override
public double test(double[][] testData, ResultT[] testResult) {
double hitCount = 0;
for (int i = 0; i < testData.length; i++) {
double[] eachTestData = testData[i];
ResultT prediction = prediction(eachTestData);
if (prediction.equals(testResult[i])) {
hitCount++;
}
}
return hitCount / testResult.length;
}
@Data
@AllArgsConstructor
private static class CalcResult {
/**
* 位置索引
*/
private int index;
/**
* 距离
*/
private double distance;
}
}

View File

@ -2,6 +2,8 @@ package io.lroyia.util;
import io.lroyia.bean.EntCalInfo;
import io.lroyia.bean.EntInfo;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
@ -84,70 +86,66 @@ public abstract class DataUtil {
return result;
}
/**
* 归一
*
* @param list list
* @return 归一结果
* @author lroyia
* @since 2023年10月24日 10:36:08
*/
public static List<EntCalInfo> toOne(List<EntCalInfo> list) {
List<Double> estDateList = new ArrayList<>(list.size());
List<Double> entTypeList = new ArrayList<>(list.size());
List<Double> regCapList = new ArrayList<>(list.size());
List<Double> industryCoList = new ArrayList<>(list.size());
List<Double> regStateList = new ArrayList<>(list.size());
for (EntCalInfo each : list) {
estDateList.add(each.getEstDate());
entTypeList.add(each.getEntType());
regCapList.add(each.getRegCap());
industryCoList.add(each.getIndustryCo());
regStateList.add(each.getRegState());
}
estDateList = columnToOne(estDateList);
entTypeList = columnToOne(entTypeList);
regCapList = columnToOne(regCapList);
industryCoList = columnToOne(industryCoList);
regStateList = columnToOne(regStateList);
for (int i = 0; i < list.size(); i++) {
EntCalInfo each = list.get(i);
each.setEstDate(estDateList.get(i))
.setEntType(entTypeList.get(i))
.setRegCap(regCapList.get(i))
.setIndustryCo(industryCoList.get(i))
.setRegState(regStateList.get(i));
}
return list;
}
/**
* 列归一
*
* @param list 列数据
* @param arr 列数据
* @return 归一
* @author lroyia
* @since 2023年10月24日 10:27:05
*/
private static List<Double> columnToOne(List<Double> list) {
public static ToOneResult columnToOne(double[] arr) {
double max = Double.MIN_VALUE;
double min = Double.MAX_VALUE;
for (Double each : list) {
for (Double each : arr) {
max = Math.max(each, max);
min = Math.min(each, min);
}
double divisor = max - min;
List<Double> result = new ArrayList<>(list.size());
for (Double each : list) {
result.add((each - min) / divisor);
double[] result = new double[arr.length];
for (int i = 0; i < arr.length; i++) {
double each = arr[i];
result[i] = ((each - min) / divisor);
}
return result;
return new ToOneResult(max, min, result);
}
public static void main(String[] args) {
List<EntCalInfo> one = toOne(getAllCalInfo());
for (EntCalInfo each : one) {
System.out.println(each);
/**
* 计算距离
*
* @param point1 坐标点1
* @param point2 坐标点2
* @return 计算结果
* @author lroyia
* @since 2023年10月28日 18:30:04
*/
public static double calcDistance(double[] point1, double[] point2) {
double stdDev = 0;
for (int i1 = 0; i1 < point1.length; i1++) {
stdDev += Math.pow(point1[i1] - point2[i1], 2);
}
return Math.sqrt(stdDev);
}
/**
* 归一结果
*/
@Getter
@AllArgsConstructor
public static class ToOneResult {
/**
* 最大值
*/
private double max;
/**
* 最小值
*/
private double min;
/**
* 归一结果
*/
private double[] result;
}
}