补充标准化

This commit is contained in:
黎润豪 2023-10-29 19:22:24 +08:00
parent 0c5afedc41
commit 7602982178
2 changed files with 66 additions and 14 deletions

View File

@ -56,9 +56,10 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
double[][] nTrainData = new double[trainData.length][];
this.maxArray = new double[trainData.length];
for (int i = 0; i < trainData.length; i++) {
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(trainData[i]);
this.maxArray[i] = toOneResult.getMax();
this.minArray[i] = toOneResult.getMin();
DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]);
this.maxArray[i] = maxMin.getMax();
this.minArray[i] = maxMin.getMin();
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i]));
nTrainData[i] = toOneResult.getResult();
}

View File

@ -86,6 +86,7 @@ public abstract class DataUtil {
return result;
}
/**
* 列归一
*
@ -95,12 +96,9 @@ public abstract class DataUtil {
* @since 2023年10月24日 10:27:05
*/
public static ToOneResult columnToOne(double[] arr) {
double max = Double.MIN_VALUE;
double min = Double.MAX_VALUE;
for (Double each : arr) {
max = Math.max(each, max);
min = Math.min(each, min);
}
MaxMin maxMin = getMaxMin(arr);
double max = maxMin.max;
double min = maxMin.min;
double divisor = max - min;
double[] result = new double[arr.length];
for (int i = 0; i < arr.length; i++) {
@ -110,6 +108,50 @@ public abstract class DataUtil {
return new ToOneResult(max, min, result);
}
/**
 * Finds the largest and smallest values in an array.
 *
 * <p>The scan is seeded with negative/positive infinity. {@code Double.MIN_VALUE}
 * must not be used as the initial maximum: it is the smallest <em>positive</em>
 * double, so an array holding only negative values would report a maximum of
 * roughly 4.9e-324 instead of its real maximum.
 *
 * @param arr values to scan
 * @return the maximum and minimum of {@code arr}; for an empty array the maximum is
 *         {@code Double.NEGATIVE_INFINITY} and the minimum is
 *         {@code Double.POSITIVE_INFINITY}
 * @author lroyia
 * @since 2023-10-29 16:09:38
 */
public static MaxMin getMaxMin(double[] arr) {
    double max = Double.NEGATIVE_INFINITY;
    double min = Double.POSITIVE_INFINITY;
    // Iterate over primitives; a boxed loop variable would allocate per element.
    for (double each : arr) {
        max = Math.max(each, max);
        min = Math.min(each, min);
    }
    return new MaxMin(max, min);
}
/**
 * Centers a column of data on its mean and rescales it by its spread.
 *
 * <p>Each output element is {@code (arr[i] - mean) / spread}, where {@code spread}
 * is the square root of the sum of squared deviations from the mean.
 *
 * <p>NOTE(review): the sum of squared deviations is not divided by the element
 * count before the square root, so the result differs from a textbook z-score by a
 * factor of {@code sqrt(arr.length)}. The scale is kept as-is to preserve existing
 * outputs; confirm whether a true standard deviation was intended.
 *
 * @param arr column data
 * @return a new array of the same length holding the centered and rescaled values;
 *         all zeros when every input value is identical (zero spread), and an empty
 *         array for empty input
 * @author lroyia
 * @since 2023-10-24 10:27:05
 */
public static double[] standardData(double[] arr) {
    double[] result = new double[arr.length];
    double sum = 0;
    for (double each : arr) {
        sum += each;
    }
    double avg = sum / arr.length;
    double squaredDeviationSum = 0;
    for (double each : arr) {
        squaredDeviationSum += Math.pow(each - avg, 2);
    }
    double spread = Math.sqrt(squaredDeviationSum);
    if (spread == 0) {
        // Every value equals the mean (or the array is empty): dividing would
        // yield 0/0 = NaN for each element, so return the zero-filled array.
        return result;
    }
    for (int i = 0; i < arr.length; i++) {
        result[i] = (arr[i] - avg) / spread;
    }
    return result;
}
/**
* 计算距离
*
@ -127,12 +169,9 @@ public abstract class DataUtil {
return Math.sqrt(stdDev);
}
/**
* 归一结果
*/
@Getter
@AllArgsConstructor
public static class ToOneResult {
public static class MaxMin {
/**
* 最大值
*/
@ -142,10 +181,22 @@ public abstract class DataUtil {
* 最小值
*/
private double min;
}
/**
* 归一结果
*/
@Getter
public static class ToOneResult extends MaxMin {
/**
* 归一结果
*/
private double[] result;
private final double[] result;
public ToOneResult(double max, double min, double[] result) {
super(max, min);
this.result = result;
}
}
}