KNN算法
This commit is contained in:
parent
ba25e9e4c5
commit
43a94cf678
|
|
@ -0,0 +1,23 @@
|
|||
package io.lroyia.bean;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 分类模型数据
|
||||
*
|
||||
* @author lroyia
|
||||
* @since 2023/10/28 13:51
|
||||
**/
|
||||
@Data
|
||||
public class ClassifyModelData<ResultT> {
|
||||
|
||||
/**
|
||||
* 特征数据
|
||||
*/
|
||||
private double[] features;
|
||||
|
||||
/**
|
||||
* 分类结果
|
||||
*/
|
||||
private ResultT result;
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
package io.lroyia.model;
|
||||
|
||||
/**
 * Classification model.
 *
 * @param <ResultT> type of the classification result produced by the model
 * @author lroyia
 * @since 2023/10/28 13:49
 */
public interface ClassifyModel<ResultT> {

    /**
     * Predicts the classification result for one feature vector.
     *
     * @param featureData feature values to classify
     * @return the predicted result
     * @author lroyia
     * @since 2023-10-28 13:50:45
     */
    ResultT prediction(double[] featureData);

    /**
     * Tests the model against samples whose results are already known.
     *
     * @param testData   test samples, one feature vector per entry
     * @param testResult expected result for each test sample
     * @return the accuracy of the model on the given samples
     * @author lroyia
     * @since 2023-10-28 13:50:00
     */
    double test(double[][] testData, ResultT[] testResult);
}
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
package io.lroyia.model.knn;
|
||||
|
||||
import io.lroyia.bean.ClassifyModelData;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* KD树
|
||||
*
|
||||
* @author lroyia
|
||||
* @since 2023/10/28 13:24
|
||||
**/
|
||||
@Data
|
||||
public class KDTree<ResultT> {
|
||||
|
||||
/**
|
||||
* 根节点
|
||||
*/
|
||||
private KDTreeNode<ResultT> rootNode;
|
||||
|
||||
public KDTree(double[][] trainData, ResultT[] result) {
|
||||
build(trainData, result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建树
|
||||
*
|
||||
* @param trainData 训练数据
|
||||
* @param result 训练结果
|
||||
* @author lroyia
|
||||
* @since 2023年10月28日 16:15:37
|
||||
*/
|
||||
private void build(double[][] trainData, ResultT[] result) {
|
||||
// 构建模型数据
|
||||
List<ClassifyModelData<ResultT>> modelDataList = new ArrayList<>(trainData.length);
|
||||
for (int i = 0; i < trainData.length; i++) {
|
||||
double[] each = trainData[i];
|
||||
ClassifyModelData<ResultT> data = new ClassifyModelData<>();
|
||||
data.setFeatures(each);
|
||||
data.setResult(result[i]);
|
||||
modelDataList.add(data);
|
||||
}
|
||||
// 计算维度分散度排行
|
||||
int[] discretenessIndex = getDiscretenessIndex(modelDataList);
|
||||
// 构建节点
|
||||
buildNode(null, modelDataList, 0, discretenessIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建节点
|
||||
*
|
||||
* @param parentNode 父节点
|
||||
* @param data 构建数据
|
||||
* @param featureIndex 特征排序索引
|
||||
* @param featureIndexArray 特征排序清单
|
||||
* @return 构建节点
|
||||
* @author lroyia
|
||||
* @since 2023年10月28日 17:23:48
|
||||
*/
|
||||
private KDTreeNode<ResultT> buildNode(KDTreeNode<ResultT> parentNode, List<ClassifyModelData<ResultT>> data,
|
||||
int featureIndex, int[] featureIndexArray) {
|
||||
int fIdx = featureIndexArray[featureIndex];
|
||||
data.sort(Comparator.comparingDouble(each -> each.getFeatures()[fIdx]));
|
||||
int index = data.size() / 2;
|
||||
ClassifyModelData<ResultT> currentData = data.get(index);
|
||||
KDTreeNode<ResultT> left = null;
|
||||
KDTreeNode<ResultT> right = null;
|
||||
KDTreeNode<ResultT> node = new KDTreeNode<>(fIdx, parentNode, currentData, null, null);
|
||||
if (parentNode == null) {
|
||||
rootNode = node;
|
||||
}
|
||||
int nextIndex = featureIndex == featureIndexArray.length - 1 ? 0 : (featureIndex + 1);
|
||||
if (index > 0) {
|
||||
left = buildNode(node, new ArrayList<>(data.subList(0, index)), nextIndex, featureIndexArray);
|
||||
}
|
||||
if (index < data.size() - 1) {
|
||||
right = buildNode(node, new ArrayList<>(data.subList(index + 1, data.size())), nextIndex, featureIndexArray);
|
||||
}
|
||||
node.setLeftNode(left);
|
||||
node.setRightNode(right);
|
||||
return node;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 获取维度分散度排序索引
|
||||
*
|
||||
* @param modelDataList 模型数据
|
||||
* @return 排序
|
||||
* @author lroyia
|
||||
* @since 2023年10月28日 16:49:35
|
||||
*/
|
||||
private int[] getDiscretenessIndex(List<ClassifyModelData<ResultT>> modelDataList) {
|
||||
int resultLength = modelDataList.get(0).getFeatures().length;
|
||||
int[] result = new int[resultLength];
|
||||
double[] maxArray = new double[resultLength];
|
||||
double[] minArray = new double[resultLength];
|
||||
for (int i = 0; i < resultLength; i++) {
|
||||
maxArray[i] = Double.MIN_VALUE;
|
||||
minArray[i] = Double.MAX_VALUE;
|
||||
}
|
||||
for (ClassifyModelData<ResultT> each : modelDataList) {
|
||||
double[] features = each.getFeatures();
|
||||
for (int i = 0; i < resultLength; i++) {
|
||||
double data = features[i];
|
||||
if (data > maxArray[i]) {
|
||||
maxArray[i] = data;
|
||||
}
|
||||
if (data < minArray[i]) {
|
||||
minArray[i] = Double.MAX_VALUE;
|
||||
}
|
||||
}
|
||||
}
|
||||
Set<Integer> contentSet = new HashSet<>();
|
||||
int index = 0;
|
||||
double tempDiff = Double.MIN_VALUE;
|
||||
for (int i = 0; i < resultLength; i++) {
|
||||
for (int j = 0; j < resultLength; j++) {
|
||||
if (contentSet.contains(j)) {
|
||||
continue;
|
||||
}
|
||||
double diff = Math.abs(maxArray[i] - minArray[j]);
|
||||
if (diff > tempDiff) {
|
||||
index = j;
|
||||
tempDiff = diff;
|
||||
}
|
||||
}
|
||||
contentSet.add(index);
|
||||
result[i] = index;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public KDTreeNode<ResultT> getNearestNode(ClassifyModelData<ResultT> data) {
|
||||
KDTreeNode<ResultT> currentNode = rootNode;
|
||||
List<KDTreeNode<ResultT>> nodePath = new ArrayList<>();
|
||||
nodePath.add(currentNode);
|
||||
while (currentNode != null && (currentNode.getLeftNode() != null || currentNode.getRightNode() != null)) {
|
||||
int featureIndex = currentNode.getFeatureIndex();
|
||||
double[] features = data.getFeatures();
|
||||
ClassifyModelData<ResultT> currentData = currentNode.getData();
|
||||
double[] currentFeatures = currentData.getFeatures();
|
||||
if (features[featureIndex] < currentFeatures[featureIndex]) {
|
||||
currentNode = currentNode.getLeftNode();
|
||||
} else {
|
||||
currentNode = currentNode.getRightNode();
|
||||
}
|
||||
if (currentNode != null) {
|
||||
nodePath.add(currentNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
package io.lroyia.model.knn;
|
||||
|
||||
import io.lroyia.bean.ClassifyModelData;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* kd树节点
|
||||
*
|
||||
* @author lroyia
|
||||
* @since 2023/10/28 17:29
|
||||
**/
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class KDTreeNode<ResultT> {
|
||||
|
||||
/**
|
||||
* 特征维度索引
|
||||
*/
|
||||
private int featureIndex;
|
||||
|
||||
/**
|
||||
* 父节点
|
||||
*/
|
||||
private KDTreeNode<ResultT> parentNode;
|
||||
|
||||
/**
|
||||
* 节点数据
|
||||
*/
|
||||
private ClassifyModelData<ResultT> data;
|
||||
|
||||
/**
|
||||
* 左节点
|
||||
*/
|
||||
private KDTreeNode<ResultT> leftNode;
|
||||
|
||||
/**
|
||||
* 右节点
|
||||
*/
|
||||
private KDTreeNode<ResultT> rightNode;
|
||||
}
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
package io.lroyia.model.knn;
|
||||
|
||||
import io.lroyia.model.ClassifyModel;
|
||||
import io.lroyia.util.DataUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @author lroyia
|
||||
* @since 2023/10/28 12:27
|
||||
**/
|
||||
@Data
|
||||
public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
|
||||
|
||||
/**
|
||||
* k值
|
||||
*/
|
||||
private int kValue;
|
||||
|
||||
/**
|
||||
* 训练数据(比较数据)
|
||||
*/
|
||||
private double[][] trainData;
|
||||
|
||||
/**
|
||||
* 维度最大值记录(用于归一化)
|
||||
*/
|
||||
private double[] maxArray;
|
||||
|
||||
/**
|
||||
* 维度最小值记录(用于归一化)
|
||||
*/
|
||||
private double[] minArray;
|
||||
|
||||
/**
|
||||
* 训练结果
|
||||
*/
|
||||
private ResultT[] trainResult;
|
||||
|
||||
public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult) {
|
||||
this.kValue = kValue;
|
||||
this.trainResult = trainResult;
|
||||
double[][] nTrainData = new double[trainData.length][];
|
||||
this.maxArray = new double[trainData.length];
|
||||
for (int i = 0; i < trainData.length; i++) {
|
||||
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(trainData[i]);
|
||||
this.maxArray[i] = toOneResult.getMax();
|
||||
this.minArray[i] = toOneResult.getMin();
|
||||
nTrainData[i] = toOneResult.getResult();
|
||||
|
||||
}
|
||||
this.trainData = nTrainData;
|
||||
}
|
||||
|
||||
/**
|
||||
* 结果预测
|
||||
*
|
||||
* @param featureData 特征数据
|
||||
* @return 预测结果
|
||||
* @author lroyia
|
||||
* @since 2023年10月28日 12:48:00
|
||||
*/
|
||||
@Override
|
||||
public ResultT prediction(double[] featureData) {
|
||||
double[] tempData = new double[featureData.length];
|
||||
for (int i = 0; i < featureData.length; i++) {
|
||||
tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]);
|
||||
}
|
||||
featureData = tempData;
|
||||
List<CalcResult> calcResultList = new ArrayList<>(trainData.length);
|
||||
for (int i = 0; i < trainData.length; i++) {
|
||||
calcResultList.add(new CalcResult(i, DataUtil.calcDistance(trainData[i], featureData)));
|
||||
}
|
||||
calcResultList.sort(Comparator.comparing(CalcResult::getDistance));
|
||||
calcResultList.subList(0, kValue);
|
||||
Map<ResultT, Integer> countMap = new HashMap<>(kValue);
|
||||
for (CalcResult each : calcResultList) {
|
||||
ResultT key = trainResult[each.index];
|
||||
Integer count = countMap.computeIfAbsent(key, k -> 0);
|
||||
countMap.put(key, count + 1);
|
||||
}
|
||||
ResultT result = null;
|
||||
int count = 0;
|
||||
for (Map.Entry<ResultT, Integer> each : countMap.entrySet()) {
|
||||
if (each.getValue() > count) {
|
||||
count = each.getValue();
|
||||
result = each.getKey();
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 模型测试
|
||||
*
|
||||
* @param testData 测试数据
|
||||
* @param testResult 测试结果
|
||||
* @return 评估准确率
|
||||
* @author lroyia
|
||||
* @since 2023年10月28日 12:57:54
|
||||
*/
|
||||
@Override
|
||||
public double test(double[][] testData, ResultT[] testResult) {
|
||||
double hitCount = 0;
|
||||
for (int i = 0; i < testData.length; i++) {
|
||||
double[] eachTestData = testData[i];
|
||||
ResultT prediction = prediction(eachTestData);
|
||||
if (prediction.equals(testResult[i])) {
|
||||
hitCount++;
|
||||
}
|
||||
}
|
||||
return hitCount / testResult.length;
|
||||
}
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
private static class CalcResult {
|
||||
|
||||
/**
|
||||
* 位置索引
|
||||
*/
|
||||
private int index;
|
||||
|
||||
/**
|
||||
* 距离
|
||||
*/
|
||||
private double distance;
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,8 @@ package io.lroyia.util;
|
|||
|
||||
import io.lroyia.bean.EntCalInfo;
|
||||
import io.lroyia.bean.EntInfo;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import org.apache.commons.csv.CSVFormat;
|
||||
import org.apache.commons.csv.CSVParser;
|
||||
import org.apache.commons.csv.CSVRecord;
|
||||
|
|
@ -84,70 +86,66 @@ public abstract class DataUtil {
|
|||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 归一
|
||||
*
|
||||
* @param list list
|
||||
* @return 归一结果
|
||||
* @author lroyia
|
||||
* @since 2023年10月24日 10:36:08
|
||||
*/
|
||||
public static List<EntCalInfo> toOne(List<EntCalInfo> list) {
|
||||
List<Double> estDateList = new ArrayList<>(list.size());
|
||||
List<Double> entTypeList = new ArrayList<>(list.size());
|
||||
List<Double> regCapList = new ArrayList<>(list.size());
|
||||
List<Double> industryCoList = new ArrayList<>(list.size());
|
||||
List<Double> regStateList = new ArrayList<>(list.size());
|
||||
for (EntCalInfo each : list) {
|
||||
estDateList.add(each.getEstDate());
|
||||
entTypeList.add(each.getEntType());
|
||||
regCapList.add(each.getRegCap());
|
||||
industryCoList.add(each.getIndustryCo());
|
||||
regStateList.add(each.getRegState());
|
||||
}
|
||||
estDateList = columnToOne(estDateList);
|
||||
entTypeList = columnToOne(entTypeList);
|
||||
regCapList = columnToOne(regCapList);
|
||||
industryCoList = columnToOne(industryCoList);
|
||||
regStateList = columnToOne(regStateList);
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
EntCalInfo each = list.get(i);
|
||||
each.setEstDate(estDateList.get(i))
|
||||
.setEntType(entTypeList.get(i))
|
||||
.setRegCap(regCapList.get(i))
|
||||
.setIndustryCo(industryCoList.get(i))
|
||||
.setRegState(regStateList.get(i));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
/**
|
||||
* 列归一
|
||||
*
|
||||
* @param list 列数据
|
||||
* @param arr 列数据
|
||||
* @return 归一
|
||||
* @author lroyia
|
||||
* @since 2023年10月24日 10:27:05
|
||||
*/
|
||||
private static List<Double> columnToOne(List<Double> list) {
|
||||
public static ToOneResult columnToOne(double[] arr) {
|
||||
double max = Double.MIN_VALUE;
|
||||
double min = Double.MAX_VALUE;
|
||||
for (Double each : list) {
|
||||
for (Double each : arr) {
|
||||
max = Math.max(each, max);
|
||||
min = Math.min(each, min);
|
||||
}
|
||||
double divisor = max - min;
|
||||
List<Double> result = new ArrayList<>(list.size());
|
||||
for (Double each : list) {
|
||||
result.add((each - min) / divisor);
|
||||
double[] result = new double[arr.length];
|
||||
for (int i = 0; i < arr.length; i++) {
|
||||
double each = arr[i];
|
||||
result[i] = ((each - min) / divisor);
|
||||
}
|
||||
return result;
|
||||
return new ToOneResult(max, min, result);
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
List<EntCalInfo> one = toOne(getAllCalInfo());
|
||||
for (EntCalInfo each : one) {
|
||||
System.out.println(each);
|
||||
/**
|
||||
* 计算距离
|
||||
*
|
||||
* @param point1 坐标点1
|
||||
* @param point2 坐标点2
|
||||
* @return 计算结果
|
||||
* @author lroyia
|
||||
* @since 2023年10月28日 18:30:04
|
||||
*/
|
||||
public static double calcDistance(double[] point1, double[] point2) {
|
||||
double stdDev = 0;
|
||||
for (int i1 = 0; i1 < point1.length; i1++) {
|
||||
stdDev += Math.pow(point1[i1] - point2[i1], 2);
|
||||
}
|
||||
return Math.sqrt(stdDev);
|
||||
}
|
||||
|
||||
/**
|
||||
* 归一结果
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public static class ToOneResult {
|
||||
/**
|
||||
* 最大值
|
||||
*/
|
||||
private double max;
|
||||
|
||||
/**
|
||||
* 最小值
|
||||
*/
|
||||
private double min;
|
||||
|
||||
/**
|
||||
* 归一结果
|
||||
*/
|
||||
private double[] result;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue