diff --git a/src/main/java/io/lroyia/bean/ClassifyModelData.java b/src/main/java/io/lroyia/bean/ClassifyModelData.java new file mode 100644 index 0000000..25a8544 --- /dev/null +++ b/src/main/java/io/lroyia/bean/ClassifyModelData.java @@ -0,0 +1,23 @@ +package io.lroyia.bean; + +import lombok.Data; + +/** + * 分类模型数据 + * + * @author lroyia + * @since 2023/10/28 13:51 + **/ +@Data +public class ClassifyModelData { + + /** + * 特征数据 + */ + private double[] features; + + /** + * 分类结果 + */ + private ResultT result; +} diff --git a/src/main/java/io/lroyia/model/ClassifyModel.java b/src/main/java/io/lroyia/model/ClassifyModel.java new file mode 100644 index 0000000..2a72a71 --- /dev/null +++ b/src/main/java/io/lroyia/model/ClassifyModel.java @@ -0,0 +1,30 @@ +package io.lroyia.model; + +/** + * 分类模型 + * + * @author lroyia + * @since 2023/10/28 13:49 + **/ +public interface ClassifyModel { + + /** + * 分类预测 + * @param featureData 预测数据 + * @return 预测结果 + * @author lroyia + * @since 2023年10月28日 13:50:45 + */ + ResultT prediction(double[] featureData); + + /** + * 测试模型 + * + * @param testData 测试数据 + * @param testResult 测试结果 + * @return 准确率 + * @author lroyia + * @since 2023年10月28日 13:50:00 + */ + double test(double[][] testData, ResultT[] testResult); +} diff --git a/src/main/java/io/lroyia/model/knn/KDTree.java b/src/main/java/io/lroyia/model/knn/KDTree.java new file mode 100644 index 0000000..2b35a6a --- /dev/null +++ b/src/main/java/io/lroyia/model/knn/KDTree.java @@ -0,0 +1,154 @@ +package io.lroyia.model.knn; + +import io.lroyia.bean.ClassifyModelData; +import lombok.Data; + +import java.util.*; + +/** + * KD树 + * + * @author lroyia + * @since 2023/10/28 13:24 + **/ +@Data +public class KDTree { + + /** + * 根节点 + */ + private KDTreeNode rootNode; + + public KDTree(double[][] trainData, ResultT[] result) { + build(trainData, result); + } + + /** + * 构建树 + * + * @param trainData 训练数据 + * @param result 训练结果 + * @author lroyia + * @since 2023年10月28日 16:15:37 + */ + private void build(double[][] trainData, ResultT[] result) { + // 构建模型数据 + List> modelDataList = new ArrayList<>(trainData.length); + for (int i = 0; i < trainData.length; i++) { + double[] each = trainData[i]; + ClassifyModelData data = new ClassifyModelData<>(); + data.setFeatures(each); + data.setResult(result[i]); + modelDataList.add(data); + } + // 计算维度分散度排行 + int[] discretenessIndex = getDiscretenessIndex(modelDataList); + // 构建节点 + buildNode(null, modelDataList, 0, discretenessIndex); + } + + /** + * 构建节点 + * + * @param parentNode 父节点 + * @param data 构建数据 + * @param featureIndex 特征排序索引 + * @param featureIndexArray 特征排序清单 + * @return 构建节点 + * @author lroyia + * @since 2023年10月28日 17:23:48 + */ + private KDTreeNode buildNode(KDTreeNode parentNode, List> data, + int featureIndex, int[] featureIndexArray) { + int fIdx = featureIndexArray[featureIndex]; + data.sort(Comparator.comparingDouble(each -> each.getFeatures()[fIdx])); + int index = data.size() / 2; + ClassifyModelData currentData = data.get(index); + KDTreeNode left = null; + KDTreeNode right = null; + KDTreeNode node = new KDTreeNode<>(fIdx, parentNode, currentData, null, null); + if (parentNode == null) { + rootNode = node; + } + int nextIndex = featureIndex == featureIndexArray.length - 1 ? 0 : (featureIndex + 1); + if (index > 0) { + left = buildNode(node, new ArrayList<>(data.subList(0, index)), nextIndex, featureIndexArray); + } + if (index < data.size() - 1) { + right = buildNode(node, new ArrayList<>(data.subList(index + 1, data.size())), nextIndex, featureIndexArray); + } + node.setLeftNode(left); + node.setRightNode(right); + return node; + } + + + /** + * 获取维度分散度排序索引 + * + * @param modelDataList 模型数据 + * @return 排序 + * @author lroyia + * @since 2023年10月28日 16:49:35 + */ + private int[] getDiscretenessIndex(List> modelDataList) { + int resultLength = modelDataList.get(0).getFeatures().length; + int[] result = new int[resultLength]; + double[] maxArray = new double[resultLength]; + double[] minArray = new double[resultLength]; + for (int i = 0; i < resultLength; i++) { + maxArray[i] = Double.MIN_VALUE; + minArray[i] = Double.MAX_VALUE; + } + for (ClassifyModelData each : modelDataList) { + double[] features = each.getFeatures(); + for (int i = 0; i < resultLength; i++) { + double data = features[i]; + if (data > maxArray[i]) { + maxArray[i] = data; + } + if (data < minArray[i]) { + minArray[i] = Double.MAX_VALUE; + } + } + } + Set contentSet = new HashSet<>(); + int index = 0; + double tempDiff = Double.MIN_VALUE; + for (int i = 0; i < resultLength; i++) { + for (int j = 0; j < resultLength; j++) { + if (contentSet.contains(j)) { + continue; + } + double diff = Math.abs(maxArray[i] - minArray[j]); + if (diff > tempDiff) { + index = j; + tempDiff = diff; + } + } + contentSet.add(index); + result[i] = index; + } + return result; + } + + public KDTreeNode getNearestNode(ClassifyModelData data) { + KDTreeNode currentNode = rootNode; + List> nodePath = new ArrayList<>(); + nodePath.add(currentNode); + while (currentNode != null && (currentNode.getLeftNode() != null || currentNode.getRightNode() != null)) { + int featureIndex = currentNode.getFeatureIndex(); + double[] features = data.getFeatures(); + ClassifyModelData currentData = currentNode.getData(); + double[] currentFeatures = currentData.getFeatures(); + if (features[featureIndex] < currentFeatures[featureIndex]) { + currentNode = currentNode.getLeftNode(); + } else { + currentNode = currentNode.getRightNode(); + } + if (currentNode != null) { + nodePath.add(currentNode); + } + } + } +} diff --git a/src/main/java/io/lroyia/model/knn/KDTreeNode.java b/src/main/java/io/lroyia/model/knn/KDTreeNode.java new file mode 100644 index 0000000..9d0afac --- /dev/null +++ b/src/main/java/io/lroyia/model/knn/KDTreeNode.java @@ -0,0 +1,41 @@ +package io.lroyia.model.knn; + +import io.lroyia.bean.ClassifyModelData; +import lombok.AllArgsConstructor; +import lombok.Data; + +/** + * kd树节点 + * + * @author lroyia + * @since 2023/10/28 17:29 + **/ +@Data +@AllArgsConstructor +public class KDTreeNode { + + /** + * 特征维度索引 + */ + private int featureIndex; + + /** + * 父节点 + */ + private KDTreeNode parentNode; + + /** + * 节点数据 + */ + private ClassifyModelData data; + + /** + * 左节点 + */ + private KDTreeNode leftNode; + + /** + * 右节点 + */ + private KDTreeNode rightNode; +} diff --git a/src/main/java/io/lroyia/model/knn/KNNModel.java b/src/main/java/io/lroyia/model/knn/KNNModel.java new file mode 100644 index 0000000..9b043d7 --- /dev/null +++ b/src/main/java/io/lroyia/model/knn/KNNModel.java @@ -0,0 +1,131 @@ +package io.lroyia.model.knn; + +import io.lroyia.model.ClassifyModel; +import io.lroyia.util.DataUtil; +import lombok.AllArgsConstructor; +import lombok.Data; + +import java.util.*; + +/** + * @author lroyia + * @since 2023/10/28 12:27 + **/ +@Data +public class KNNModel implements ClassifyModel { + + /** + * k值 + */ + private int kValue; + + /** + * 训练数据(比较数据) + */ + private double[][] trainData; + + /** + * 维度最大值记录(用于归一化) + */ + private double[] maxArray; + + /** + * 维度最小值记录(用于归一化) + */ + private double[] minArray; + + /** + * 训练结果 + */ + private ResultT[] trainResult; + + public KNNModel(int kValue, double[][] trainData, ResultT[] trainResult) { + this.kValue = kValue; + this.trainResult = trainResult; + double[][] nTrainData = new double[trainData.length][]; + this.maxArray = new double[trainData.length]; + for (int i = 0; i < trainData.length; i++) { + DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(trainData[i]); + this.maxArray[i] = toOneResult.getMax(); + this.minArray[i] = toOneResult.getMin(); + nTrainData[i] = toOneResult.getResult(); + + } + this.trainData = nTrainData; + } + + /** + * 结果预测 + * + * @param featureData 特征数据 + * @return 预测结果 + * @author lroyia + * @since 2023年10月28日 12:48:00 + */ + @Override + public ResultT prediction(double[] featureData) { + double[] tempData = new double[featureData.length]; + for (int i = 0; i < featureData.length; i++) { + tempData[i] = (featureData[i] - minArray[i]) / (maxArray[i] - minArray[i]); + } + featureData = tempData; + List calcResultList = new ArrayList<>(trainData.length); + for (int i = 0; i < trainData.length; i++) { + calcResultList.add(new CalcResult(i, DataUtil.calcDistance(trainData[i], featureData))); + } + calcResultList.sort(Comparator.comparing(CalcResult::getDistance)); + calcResultList.subList(0, kValue); + Map countMap = new HashMap<>(kValue); + for (CalcResult each : calcResultList) { + ResultT key = trainResult[each.index]; + Integer count = countMap.computeIfAbsent(key, k -> 0); + countMap.put(key, count + 1); + } + ResultT result = null; + int count = 0; + for (Map.Entry each : countMap.entrySet()) { + if (each.getValue() > count) { + count = each.getValue(); + result = each.getKey(); + } + } + return result; + } + + /** + * 模型测试 + * + * @param testData 测试数据 + * @param testResult 测试结果 + * @return 评估准确率 + * @author lroyia + * @since 2023年10月28日 12:57:54 + */ + @Override + public double test(double[][] testData, ResultT[] testResult) { + double hitCount = 0; + for (int i = 0; i < testData.length; i++) { + double[] eachTestData = testData[i]; + ResultT prediction = prediction(eachTestData); + if (prediction.equals(testResult[i])) { + hitCount++; + } + } + return hitCount / testResult.length; + } + + @Data + @AllArgsConstructor + private static class CalcResult { + + /** + * 位置索引 + */ + private int index; + + /** + * 距离 + */ + private double distance; + } +} diff --git a/src/main/java/io/lroyia/util/DataUtil.java b/src/main/java/io/lroyia/util/DataUtil.java index 09e8e3e..6317e1e 100644 --- a/src/main/java/io/lroyia/util/DataUtil.java +++ b/src/main/java/io/lroyia/util/DataUtil.java @@ -2,6 +2,8 @@ package io.lroyia.util; import io.lroyia.bean.EntCalInfo; import io.lroyia.bean.EntInfo; +import lombok.AllArgsConstructor; +import lombok.Getter; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVRecord; @@ -84,70 +86,66 @@ public abstract class DataUtil { return result; } - /** - * 归一 - * - * @param list list - * @return 归一结果 - * @author lroyia - * @since 2023年10月24日 10:36:08 - */ - public static List toOne(List list) { - List estDateList = new ArrayList<>(list.size()); - List entTypeList = new ArrayList<>(list.size()); - List regCapList = new ArrayList<>(list.size()); - List industryCoList = new ArrayList<>(list.size()); - List regStateList = new ArrayList<>(list.size()); - for (EntCalInfo each : list) { - estDateList.add(each.getEstDate()); - entTypeList.add(each.getEntType()); - regCapList.add(each.getRegCap()); - industryCoList.add(each.getIndustryCo()); - regStateList.add(each.getRegState()); - } - estDateList = columnToOne(estDateList); - entTypeList = columnToOne(entTypeList); - regCapList = columnToOne(regCapList); - industryCoList = columnToOne(industryCoList); - regStateList = columnToOne(regStateList); - for (int i = 0; i < list.size(); i++) { - EntCalInfo each = list.get(i); - each.setEstDate(estDateList.get(i)) - .setEntType(entTypeList.get(i)) - .setRegCap(regCapList.get(i)) - .setIndustryCo(industryCoList.get(i)) - .setRegState(regStateList.get(i)); - } - return list; - } - /** * 列归一 * - * @param list 列数据 + * @param arr 列数据 * @return 归一 * @author lroyia * @since 2023年10月24日 10:27:05 */ - private static List columnToOne(List list) { + public static ToOneResult columnToOne(double[] arr) { double max = Double.MIN_VALUE; double min = Double.MAX_VALUE; - for (Double each : list) { + for (Double each : arr) { max = Math.max(each, max); min = Math.min(each, min); } double divisor = max - min; - List result = new ArrayList<>(list.size()); - for (Double each : list) { - result.add((each - min) / divisor); + double[] result = new double[arr.length]; + for (int i = 0; i < arr.length; i++) { + double each = arr[i]; + result[i] = ((each - min) / divisor); } - return result; + return new ToOneResult(max, min, result); } - public static void main(String[] args) { - List one = toOne(getAllCalInfo()); - for (EntCalInfo each : one) { - System.out.println(each); + /** + * 计算距离 + * + * @param point1 坐标点1 + * @param point2 坐标点2 + * @return 计算结果 + * @author lroyia + * @since 2023年10月28日 18:30:04 + */ + public static double calcDistance(double[] point1, double[] point2) { + double stdDev = 0; + for (int i1 = 0; i1 < point1.length; i1++) { + stdDev += Math.pow(point1[i1] - point2[i1], 2); } + return Math.sqrt(stdDev); + } + + /** + * 归一结果 + */ + @Getter + @AllArgsConstructor + public static class ToOneResult { + /** + * 最大值 + */ + private double max; + + /** + * 最小值 + */ + private double min; + + /** + * 归一结果 + */ + private double[] result; } }