diff --git a/pom.xml b/pom.xml
index 97d8dda..8c94eb4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -5,7 +5,7 @@
4.0.0
io.lroyia
- market-ent-classift-demo
+ market-ent-classify-demo
1.0.0-SNAPSHOT
jar
${project.artifactId}
diff --git a/src/main/java/io/lroyia/KnnRun.java b/src/main/java/io/lroyia/KnnRun.java
index 7151d84..64a871c 100644
--- a/src/main/java/io/lroyia/KnnRun.java
+++ b/src/main/java/io/lroyia/KnnRun.java
@@ -1,5 +1,11 @@
package io.lroyia;
+import io.lroyia.bean.EntCalInfo;
+import io.lroyia.model.knn.KNNModel;
+import io.lroyia.util.DataUtil;
+
+import java.util.List;
+
/**
* 程序入口
*
@@ -9,6 +15,33 @@ package io.lroyia;
public class KnnRun {
public static void main(String[] args) {
+ List allCalInfo = DataUtil.getAllCalInfo();
+ int endIndex = allCalInfo.size() / 3 * 2;
+ double[][] trainData = new double[endIndex][];
+ String[] trainResult = new String[endIndex];
+ double[][] testData = new double[allCalInfo.size() - endIndex][];
+ String[] testResult = new String[allCalInfo.size() - endIndex];
+
+ for (int i = 0; i < allCalInfo.size(); i++) {
+ EntCalInfo each = allCalInfo.get(i);
+ double[] data = new double[]{
+ each.getEstDate(),
+ each.getEntType(),
+ each.getRegCap(),
+ each.getIndustryCo(),
+ each.getRegState()
+ };
+ if (i < endIndex) {
+ trainData[i] = data;
+ trainResult[i] = each.getRiskLevel();
+ } else {
+ testData[i - endIndex] = data;
+ testResult[i - endIndex] = each.getRiskLevel();
+ }
+ }
+
+ KNNModel knnModel = new KNNModel<>(5, trainData, trainResult, true);
+ System.out.println(knnModel.test(testData, testResult));
}
}
diff --git a/src/main/java/io/lroyia/bean/EntInfo.java b/src/main/java/io/lroyia/bean/EntInfo.java
index c2783af..4f0b3c5 100644
--- a/src/main/java/io/lroyia/bean/EntInfo.java
+++ b/src/main/java/io/lroyia/bean/EntInfo.java
@@ -77,6 +77,7 @@ public class EntInfo implements Serializable {
result.setPripid(pripid);
result.setEntName(entName);
result.setIndustryPhy(industryPhy);
+ result.setRiskLevel(riskLevel);
if (estDate != null) {
LocalDate now = LocalDate.now();
result.setEstDate(now.getYear() - estDate.getYear());
@@ -87,8 +88,10 @@ public class EntInfo implements Serializable {
if (regCap != null) {
result.setRegCap(regCap.doubleValue());
}
- if (StringUtils.isNotBlank(industryCo)) {
+ if (StringUtils.isNotBlank(industryCo) && StringUtils.isNumeric(industryCo)) {
result.setIndustryCo(Double.parseDouble(industryCo));
+ } else {
+ result.setIndustryCo(0);
}
if (StringUtils.isNotBlank(regState)) {
result.setRegState(Double.parseDouble(regState));
diff --git a/src/main/java/io/lroyia/model/knn/KDTreeNode.java b/src/main/java/io/lroyia/model/knn/KDTreeNode.java
index 9d0afac..c7fea64 100644
--- a/src/main/java/io/lroyia/model/knn/KDTreeNode.java
+++ b/src/main/java/io/lroyia/model/knn/KDTreeNode.java
@@ -2,7 +2,8 @@ package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData;
import lombok.AllArgsConstructor;
-import lombok.Data;
+import lombok.Getter;
+import lombok.Setter;
/**
* kd树节点
@@ -10,7 +11,8 @@ import lombok.Data;
* @author lroyia
* @since 2023/10/28 17:29
**/
-@Data
+@Getter
+@Setter
@AllArgsConstructor
public class KDTreeNode {
diff --git a/src/main/java/io/lroyia/model/knn/KNNModel.java b/src/main/java/io/lroyia/model/knn/KNNModel.java
index 970bc58..4a35bca 100644
--- a/src/main/java/io/lroyia/model/knn/KNNModel.java
+++ b/src/main/java/io/lroyia/model/knn/KNNModel.java
@@ -55,13 +55,13 @@ public class KNNModel implements ClassifyModel {
this.trainResult = trainResult;
double[][] nTrainData = new double[trainData.length][];
this.maxArray = new double[trainData.length];
+ this.minArray = new double[trainData.length];
for (int i = 0; i < trainData.length; i++) {
DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]);
this.maxArray[i] = maxMin.getMax();
this.minArray[i] = maxMin.getMin();
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i]));
nTrainData[i] = toOneResult.getResult();
-
}
this.trainData = nTrainData;
setUseKdTree(useKdTree);
diff --git a/src/main/java/io/lroyia/util/DataUtil.java b/src/main/java/io/lroyia/util/DataUtil.java
index fc0a798..cd1c7cc 100644
--- a/src/main/java/io/lroyia/util/DataUtil.java
+++ b/src/main/java/io/lroyia/util/DataUtil.java
@@ -51,10 +51,7 @@ public abstract class DataUtil {
String[] dateArr = estDateStr.split(" ")[0].split("-");
atom.setEstDate(LocalDate.of(Integer.parseInt(dateArr[0]), Integer.parseInt(dateArr[1]), Integer.parseInt(dateArr[2])));
}
- String entType = each.get("SUBENTTYPE");
- if (StringUtils.isBlank(entType)) {
- entType = each.get("ENTTYPE");
- }
+ String entType = each.get("ENTTYPE");
atom.setEntType(entType);
String regCap = each.get("REGCAP");
if (StringUtils.isNotBlank(regCap)) {
@@ -63,6 +60,7 @@ public abstract class DataUtil {
atom.setIndustryPhy(each.get("INDUSTRYPHY"));
atom.setIndustryCo(each.get("INDUSTRYCO"));
atom.setRegState(each.get("ENTSTATE"));
+ atom.setRiskLevel(each.get("RISKLEVEL"));
}
return result;
} catch (IOException e) {