补充标准化

This commit is contained in:
黎润豪 2023-10-29 19:22:24 +08:00
parent 0c5afedc41
commit 7602982178
2 changed files with 66 additions and 14 deletions

View File

@ -56,9 +56,10 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
double[][] nTrainData = new double[trainData.length][]; double[][] nTrainData = new double[trainData.length][];
this.maxArray = new double[trainData.length]; this.maxArray = new double[trainData.length];
for (int i = 0; i < trainData.length; i++) { for (int i = 0; i < trainData.length; i++) {
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(trainData[i]); DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]);
this.maxArray[i] = toOneResult.getMax(); this.maxArray[i] = maxMin.getMax();
this.minArray[i] = toOneResult.getMin(); this.minArray[i] = maxMin.getMin();
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i]));
nTrainData[i] = toOneResult.getResult(); nTrainData[i] = toOneResult.getResult();
} }

View File

@ -86,6 +86,7 @@ public abstract class DataUtil {
return result; return result;
} }
/** /**
* 列归一 * 列归一
* *
@ -95,12 +96,9 @@ public abstract class DataUtil {
* @since 2023年10月24日 10:27:05 * @since 2023年10月24日 10:27:05
*/ */
public static ToOneResult columnToOne(double[] arr) { public static ToOneResult columnToOne(double[] arr) {
double max = Double.MIN_VALUE; MaxMin maxMin = getMaxMin(arr);
double min = Double.MAX_VALUE; double max = maxMin.max;
for (Double each : arr) { double min = maxMin.min;
max = Math.max(each, max);
min = Math.min(each, min);
}
double divisor = max - min; double divisor = max - min;
double[] result = new double[arr.length]; double[] result = new double[arr.length];
for (int i = 0; i < arr.length; i++) { for (int i = 0; i < arr.length; i++) {
@ -110,6 +108,50 @@ public abstract class DataUtil {
return new ToOneResult(max, min, result); return new ToOneResult(max, min, result);
} }
/**
* 获取最大最小值
*
* @param arr 数组
* @return 最大最小值
* @author lroyia
* @since 2023年10月29日 16:09:38
*/
public static MaxMin getMaxMin(double[] arr) {
double max = Double.MIN_VALUE;
double min = Double.MAX_VALUE;
for (Double each : arr) {
max = Math.max(each, max);
min = Math.min(each, min);
}
return new MaxMin(max, min);
}
/**
* 列归一
*
* @param arr 列数据
* @return 归一
* @author lroyia
* @since 2023年10月24日 10:27:05
*/
public static double[] standardData(double[] arr) {
double[] result = new double[arr.length];
double sum = 0;
for (double each : arr) {
sum += each;
}
double avg = sum / arr.length;
double variance = 0;
for (double each : arr) {
variance += Math.pow(each - avg, 2);
}
variance = Math.sqrt(variance);
for (int i = 0; i < arr.length; i++) {
result[i] = (arr[i] - avg) / variance;
}
return result;
}
/** /**
* 计算距离 * 计算距离
* *
@ -127,12 +169,9 @@ public abstract class DataUtil {
return Math.sqrt(stdDev); return Math.sqrt(stdDev);
} }
/**
* 归一结果
*/
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public static class ToOneResult { public static class MaxMin {
/** /**
* 最大值 * 最大值
*/ */
@ -142,10 +181,22 @@ public abstract class DataUtil {
* 最小值 * 最小值
*/ */
private double min; private double min;
}
/**
* 归一结果
*/
@Getter
public static class ToOneResult extends MaxMin {
/** /**
* 归一结果 * 归一结果
*/ */
private double[] result; private final double[] result;
public ToOneResult(double max, double min, double[] result) {
super(max, min);
this.result = result;
}
} }
} }