补充标准化
This commit is contained in:
parent
0c5afedc41
commit
7602982178
|
|
@ -56,9 +56,10 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
|
|||
double[][] nTrainData = new double[trainData.length][];
|
||||
this.maxArray = new double[trainData.length];
|
||||
for (int i = 0; i < trainData.length; i++) {
|
||||
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(trainData[i]);
|
||||
this.maxArray[i] = toOneResult.getMax();
|
||||
this.minArray[i] = toOneResult.getMin();
|
||||
DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]);
|
||||
this.maxArray[i] = maxMin.getMax();
|
||||
this.minArray[i] = maxMin.getMin();
|
||||
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i]));
|
||||
nTrainData[i] = toOneResult.getResult();
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ public abstract class DataUtil {
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 列归一
|
||||
*
|
||||
|
|
@ -95,12 +96,9 @@ public abstract class DataUtil {
|
|||
* @since 2023年10月24日 10:27:05
|
||||
*/
|
||||
public static ToOneResult columnToOne(double[] arr) {
|
||||
double max = Double.MIN_VALUE;
|
||||
double min = Double.MAX_VALUE;
|
||||
for (Double each : arr) {
|
||||
max = Math.max(each, max);
|
||||
min = Math.min(each, min);
|
||||
}
|
||||
MaxMin maxMin = getMaxMin(arr);
|
||||
double max = maxMin.max;
|
||||
double min = maxMin.min;
|
||||
double divisor = max - min;
|
||||
double[] result = new double[arr.length];
|
||||
for (int i = 0; i < arr.length; i++) {
|
||||
|
|
@ -110,6 +108,50 @@ public abstract class DataUtil {
|
|||
return new ToOneResult(max, min, result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取最大最小值
|
||||
*
|
||||
* @param arr 数组
|
||||
* @return 最大最小值
|
||||
* @author lroyia
|
||||
* @since 2023年10月29日 16:09:38
|
||||
*/
|
||||
public static MaxMin getMaxMin(double[] arr) {
|
||||
double max = Double.MIN_VALUE;
|
||||
double min = Double.MAX_VALUE;
|
||||
for (Double each : arr) {
|
||||
max = Math.max(each, max);
|
||||
min = Math.min(each, min);
|
||||
}
|
||||
return new MaxMin(max, min);
|
||||
}
|
||||
|
||||
/**
|
||||
* 列归一
|
||||
*
|
||||
* @param arr 列数据
|
||||
* @return 归一
|
||||
* @author lroyia
|
||||
* @since 2023年10月24日 10:27:05
|
||||
*/
|
||||
public static double[] standardData(double[] arr) {
|
||||
double[] result = new double[arr.length];
|
||||
double sum = 0;
|
||||
for (double each : arr) {
|
||||
sum += each;
|
||||
}
|
||||
double avg = sum / arr.length;
|
||||
double variance = 0;
|
||||
for (double each : arr) {
|
||||
variance += Math.pow(each - avg, 2);
|
||||
}
|
||||
variance = Math.sqrt(variance);
|
||||
for (int i = 0; i < arr.length; i++) {
|
||||
result[i] = (arr[i] - avg) / variance;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算距离
|
||||
*
|
||||
|
|
@ -127,12 +169,9 @@ public abstract class DataUtil {
|
|||
return Math.sqrt(stdDev);
|
||||
}
|
||||
|
||||
/**
|
||||
* 归一结果
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public static class ToOneResult {
|
||||
public static class MaxMin {
|
||||
/**
|
||||
* 最大值
|
||||
*/
|
||||
|
|
@ -142,10 +181,22 @@ public abstract class DataUtil {
|
|||
* 最小值
|
||||
*/
|
||||
private double min;
|
||||
}
|
||||
|
||||
/**
|
||||
* 归一结果
|
||||
*/
|
||||
private double[] result;
|
||||
@Getter
|
||||
public static class ToOneResult extends MaxMin {
|
||||
|
||||
/**
|
||||
* 归一结果
|
||||
*/
|
||||
private final double[] result;
|
||||
|
||||
public ToOneResult(double max, double min, double[] result) {
|
||||
super(max, min);
|
||||
this.result = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue