diff --git a/src/main/java/io/lroyia/model/knn/KNNModel.java b/src/main/java/io/lroyia/model/knn/KNNModel.java index 073313a..970bc58 100644 --- a/src/main/java/io/lroyia/model/knn/KNNModel.java +++ b/src/main/java/io/lroyia/model/knn/KNNModel.java @@ -56,9 +56,10 @@ public class KNNModel implements ClassifyModel { double[][] nTrainData = new double[trainData.length][]; this.maxArray = new double[trainData.length]; for (int i = 0; i < trainData.length; i++) { - DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(trainData[i]); - this.maxArray[i] = toOneResult.getMax(); - this.minArray[i] = toOneResult.getMin(); + DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]); + this.maxArray[i] = maxMin.getMax(); + this.minArray[i] = maxMin.getMin(); + DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i])); nTrainData[i] = toOneResult.getResult(); } diff --git a/src/main/java/io/lroyia/util/DataUtil.java b/src/main/java/io/lroyia/util/DataUtil.java index 6317e1e..fc0a798 100644 --- a/src/main/java/io/lroyia/util/DataUtil.java +++ b/src/main/java/io/lroyia/util/DataUtil.java @@ -86,6 +86,7 @@ public abstract class DataUtil { return result; } + /** * 列归一 * @@ -95,12 +96,9 @@ public abstract class DataUtil { * @since 2023年10月24日 10:27:05 */ public static ToOneResult columnToOne(double[] arr) { - double max = Double.MIN_VALUE; - double min = Double.MAX_VALUE; - for (Double each : arr) { - max = Math.max(each, max); - min = Math.min(each, min); - } + MaxMin maxMin = getMaxMin(arr); + double max = maxMin.max; + double min = maxMin.min; double divisor = max - min; double[] result = new double[arr.length]; for (int i = 0; i < arr.length; i++) { @@ -110,6 +108,50 @@ public abstract class DataUtil { return new ToOneResult(max, min, result); } + /** + * 获取最大最小值 + * + * @param arr 数组 + * @return 最大最小值 + * @author lroyia + * @since 2023年10月29日 16:09:38 + */ + public static MaxMin getMaxMin(double[] arr) { + double max = Double.MIN_VALUE; + double min = Double.MAX_VALUE; + for (Double each : arr) { + max = Math.max(each, max); + min = Math.min(each, min); + } + return new MaxMin(max, min); + } + + /** + * 列归一 + * + * @param arr 列数据 + * @return 归一 + * @author lroyia + * @since 2023年10月24日 10:27:05 + */ + public static double[] standardData(double[] arr) { + double[] result = new double[arr.length]; + double sum = 0; + for (double each : arr) { + sum += each; + } + double avg = sum / arr.length; + double variance = 0; + for (double each : arr) { + variance += Math.pow(each - avg, 2); + } + variance = Math.sqrt(variance); + for (int i = 0; i < arr.length; i++) { + result[i] = (arr[i] - avg) / variance; + } + return result; + } + /** * 计算距离 * @@ -127,12 +169,9 @@ public abstract class DataUtil { return Math.sqrt(stdDev); } - /** - * 归一结果 - */ @Getter @AllArgsConstructor - public static class ToOneResult { + public static class MaxMin { /** * 最大值 */ @@ -142,10 +181,22 @@ public abstract class DataUtil { * 最小值 */ private double min; + } + + /** + * 归一结果 + */ + @Getter + public static class ToOneResult extends MaxMin { /** * 归一结果 */ - private double[] result; + private final double[] result; + + public ToOneResult(double max, double min, double[] result) { + super(max, min); + this.result = result; + } } }