补充程序入口部分

This commit is contained in:
黎润豪 2023-10-30 11:09:12 +08:00
parent 7602982178
commit a17158f073
6 changed files with 45 additions and 9 deletions

View File

@ -5,7 +5,7 @@
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>io.lroyia</groupId> <groupId>io.lroyia</groupId>
<artifactId>market-ent-classift-demo</artifactId> <artifactId>market-ent-classify-demo</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
<packaging>jar</packaging> <packaging>jar</packaging>
<name>${project.artifactId}</name> <name>${project.artifactId}</name>

View File

@ -1,5 +1,11 @@
package io.lroyia; package io.lroyia;
import io.lroyia.bean.EntCalInfo;
import io.lroyia.model.knn.KNNModel;
import io.lroyia.util.DataUtil;
import java.util.List;
/** /**
* 程序入口 * 程序入口
* *
@ -9,6 +15,33 @@ package io.lroyia;
public class KnnRun { public class KnnRun {
public static void main(String[] args) { public static void main(String[] args) {
List<EntCalInfo> allCalInfo = DataUtil.getAllCalInfo();
int endIndex = allCalInfo.size() / 3 * 2;
double[][] trainData = new double[endIndex][];
String[] trainResult = new String[endIndex];
double[][] testData = new double[allCalInfo.size() - endIndex][];
String[] testResult = new String[allCalInfo.size() - endIndex];
for (int i = 0; i < allCalInfo.size(); i++) {
EntCalInfo each = allCalInfo.get(i);
double[] data = new double[]{
each.getEstDate(),
each.getEntType(),
each.getRegCap(),
each.getIndustryCo(),
each.getRegState()
};
if (i < endIndex) {
trainData[i] = data;
trainResult[i] = each.getRiskLevel();
} else {
testData[i - endIndex] = data;
testResult[i - endIndex] = each.getRiskLevel();
}
}
KNNModel<String> knnModel = new KNNModel<>(5, trainData, trainResult, true);
System.out.println(knnModel.test(testData, testResult));
} }
} }

View File

@ -77,6 +77,7 @@ public class EntInfo implements Serializable {
result.setPripid(pripid); result.setPripid(pripid);
result.setEntName(entName); result.setEntName(entName);
result.setIndustryPhy(industryPhy); result.setIndustryPhy(industryPhy);
result.setRiskLevel(riskLevel);
if (estDate != null) { if (estDate != null) {
LocalDate now = LocalDate.now(); LocalDate now = LocalDate.now();
result.setEstDate(now.getYear() - estDate.getYear()); result.setEstDate(now.getYear() - estDate.getYear());
@ -87,8 +88,10 @@ public class EntInfo implements Serializable {
if (regCap != null) { if (regCap != null) {
result.setRegCap(regCap.doubleValue()); result.setRegCap(regCap.doubleValue());
} }
if (StringUtils.isNotBlank(industryCo)) { if (StringUtils.isNotBlank(industryCo) && StringUtils.isNumeric(industryCo)) {
result.setIndustryCo(Double.parseDouble(industryCo)); result.setIndustryCo(Double.parseDouble(industryCo));
} else {
result.setIndustryCo(0);
} }
if (StringUtils.isNotBlank(regState)) { if (StringUtils.isNotBlank(regState)) {
result.setRegState(Double.parseDouble(regState)); result.setRegState(Double.parseDouble(regState));

View File

@ -2,7 +2,8 @@ package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData; import io.lroyia.bean.ClassifyModelData;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Getter;
import lombok.Setter;
/** /**
* kd树节点 * kd树节点
@ -10,7 +11,8 @@ import lombok.Data;
* @author lroyia * @author lroyia
* @since 2023/10/28 17:29 * @since 2023/10/28 17:29
**/ **/
@Data @Getter
@Setter
@AllArgsConstructor @AllArgsConstructor
public class KDTreeNode<ResultT> { public class KDTreeNode<ResultT> {

View File

@ -55,13 +55,13 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
this.trainResult = trainResult; this.trainResult = trainResult;
double[][] nTrainData = new double[trainData.length][]; double[][] nTrainData = new double[trainData.length][];
this.maxArray = new double[trainData.length]; this.maxArray = new double[trainData.length];
this.minArray = new double[trainData.length];
for (int i = 0; i < trainData.length; i++) { for (int i = 0; i < trainData.length; i++) {
DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]); DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]);
this.maxArray[i] = maxMin.getMax(); this.maxArray[i] = maxMin.getMax();
this.minArray[i] = maxMin.getMin(); this.minArray[i] = maxMin.getMin();
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i])); DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i]));
nTrainData[i] = toOneResult.getResult(); nTrainData[i] = toOneResult.getResult();
} }
this.trainData = nTrainData; this.trainData = nTrainData;
setUseKdTree(useKdTree); setUseKdTree(useKdTree);

View File

@ -51,10 +51,7 @@ public abstract class DataUtil {
String[] dateArr = estDateStr.split(" ")[0].split("-"); String[] dateArr = estDateStr.split(" ")[0].split("-");
atom.setEstDate(LocalDate.of(Integer.parseInt(dateArr[0]), Integer.parseInt(dateArr[1]), Integer.parseInt(dateArr[2]))); atom.setEstDate(LocalDate.of(Integer.parseInt(dateArr[0]), Integer.parseInt(dateArr[1]), Integer.parseInt(dateArr[2])));
} }
String entType = each.get("SUBENTTYPE"); String entType = each.get("ENTTYPE");
if (StringUtils.isBlank(entType)) {
entType = each.get("ENTTYPE");
}
atom.setEntType(entType); atom.setEntType(entType);
String regCap = each.get("REGCAP"); String regCap = each.get("REGCAP");
if (StringUtils.isNotBlank(regCap)) { if (StringUtils.isNotBlank(regCap)) {
@ -63,6 +60,7 @@ public abstract class DataUtil {
atom.setIndustryPhy(each.get("INDUSTRYPHY")); atom.setIndustryPhy(each.get("INDUSTRYPHY"));
atom.setIndustryCo(each.get("INDUSTRYCO")); atom.setIndustryCo(each.get("INDUSTRYCO"));
atom.setRegState(each.get("ENTSTATE")); atom.setRegState(each.get("ENTSTATE"));
atom.setRiskLevel(each.get("RISKLEVEL"));
} }
return result; return result;
} catch (IOException e) { } catch (IOException e) {