diff --git a/pom.xml b/pom.xml index 97d8dda..8c94eb4 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ 4.0.0 io.lroyia - market-ent-classift-demo + market-ent-classify-demo 1.0.0-SNAPSHOT jar ${project.artifactId} diff --git a/src/main/java/io/lroyia/KnnRun.java b/src/main/java/io/lroyia/KnnRun.java index 7151d84..64a871c 100644 --- a/src/main/java/io/lroyia/KnnRun.java +++ b/src/main/java/io/lroyia/KnnRun.java @@ -1,5 +1,11 @@ package io.lroyia; +import io.lroyia.bean.EntCalInfo; +import io.lroyia.model.knn.KNNModel; +import io.lroyia.util.DataUtil; + +import java.util.List; + /** * 程序入口 * @@ -9,6 +15,33 @@ package io.lroyia; public class KnnRun { public static void main(String[] args) { + List allCalInfo = DataUtil.getAllCalInfo(); + int endIndex = allCalInfo.size() / 3 * 2; + double[][] trainData = new double[endIndex][]; + String[] trainResult = new String[endIndex]; + double[][] testData = new double[allCalInfo.size() - endIndex][]; + String[] testResult = new String[allCalInfo.size() - endIndex]; + + for (int i = 0; i < allCalInfo.size(); i++) { + EntCalInfo each = allCalInfo.get(i); + double[] data = new double[]{ + each.getEstDate(), + each.getEntType(), + each.getRegCap(), + each.getIndustryCo(), + each.getRegState() + }; + if (i < endIndex) { + trainData[i] = data; + trainResult[i] = each.getRiskLevel(); + } else { + testData[i - endIndex] = data; + testResult[i - endIndex] = each.getRiskLevel(); + } + } + + KNNModel knnModel = new KNNModel<>(5, trainData, trainResult, true); + System.out.println(knnModel.test(testData, testResult)); } } diff --git a/src/main/java/io/lroyia/bean/EntInfo.java b/src/main/java/io/lroyia/bean/EntInfo.java index c2783af..4f0b3c5 100644 --- a/src/main/java/io/lroyia/bean/EntInfo.java +++ b/src/main/java/io/lroyia/bean/EntInfo.java @@ -77,6 +77,7 @@ public class EntInfo implements Serializable { result.setPripid(pripid); result.setEntName(entName); result.setIndustryPhy(industryPhy); + result.setRiskLevel(riskLevel); if (estDate != null) { LocalDate now = LocalDate.now(); result.setEstDate(now.getYear() - estDate.getYear()); @@ -87,8 +88,10 @@ public class EntInfo implements Serializable { if (regCap != null) { result.setRegCap(regCap.doubleValue()); } - if (StringUtils.isNotBlank(industryCo)) { + if (StringUtils.isNotBlank(industryCo) && StringUtils.isNumeric(industryCo)) { result.setIndustryCo(Double.parseDouble(industryCo)); + } else { + result.setIndustryCo(0); } if (StringUtils.isNotBlank(regState)) { result.setRegState(Double.parseDouble(regState)); diff --git a/src/main/java/io/lroyia/model/knn/KDTreeNode.java b/src/main/java/io/lroyia/model/knn/KDTreeNode.java index 9d0afac..c7fea64 100644 --- a/src/main/java/io/lroyia/model/knn/KDTreeNode.java +++ b/src/main/java/io/lroyia/model/knn/KDTreeNode.java @@ -2,7 +2,8 @@ package io.lroyia.model.knn; import io.lroyia.bean.ClassifyModelData; import lombok.AllArgsConstructor; -import lombok.Data; +import lombok.Getter; +import lombok.Setter; /** * kd树节点 @@ -10,7 +11,8 @@ import lombok.Data; * @author lroyia * @since 2023/10/28 17:29 **/ -@Data +@Getter +@Setter @AllArgsConstructor public class KDTreeNode { diff --git a/src/main/java/io/lroyia/model/knn/KNNModel.java b/src/main/java/io/lroyia/model/knn/KNNModel.java index 970bc58..4a35bca 100644 --- a/src/main/java/io/lroyia/model/knn/KNNModel.java +++ b/src/main/java/io/lroyia/model/knn/KNNModel.java @@ -55,13 +55,13 @@ public class KNNModel implements ClassifyModel { this.trainResult = trainResult; double[][] nTrainData = new double[trainData.length][]; this.maxArray = new double[trainData.length]; + this.minArray = new double[trainData.length]; for (int i = 0; i < trainData.length; i++) { DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]); this.maxArray[i] = maxMin.getMax(); this.minArray[i] = maxMin.getMin(); DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i])); nTrainData[i] = toOneResult.getResult(); - } this.trainData = nTrainData; setUseKdTree(useKdTree); diff --git a/src/main/java/io/lroyia/util/DataUtil.java b/src/main/java/io/lroyia/util/DataUtil.java index fc0a798..cd1c7cc 100644 --- a/src/main/java/io/lroyia/util/DataUtil.java +++ b/src/main/java/io/lroyia/util/DataUtil.java @@ -51,10 +51,7 @@ public abstract class DataUtil { String[] dateArr = estDateStr.split(" ")[0].split("-"); atom.setEstDate(LocalDate.of(Integer.parseInt(dateArr[0]), Integer.parseInt(dateArr[1]), Integer.parseInt(dateArr[2]))); } - String entType = each.get("SUBENTTYPE"); - if (StringUtils.isBlank(entType)) { - entType = each.get("ENTTYPE"); - } + String entType = each.get("ENTTYPE"); atom.setEntType(entType); String regCap = each.get("REGCAP"); if (StringUtils.isNotBlank(regCap)) { @@ -63,6 +60,7 @@ public abstract class DataUtil { atom.setIndustryPhy(each.get("INDUSTRYPHY")); atom.setIndustryCo(each.get("INDUSTRYCO")); atom.setRegState(each.get("ENTSTATE")); + atom.setRiskLevel(each.get("RISKLEVEL")); } return result; } catch (IOException e) {