补充程序入口部分

This commit is contained in:
黎润豪 2023-10-30 11:09:12 +08:00
parent 7602982178
commit a17158f073
6 changed files with 45 additions and 9 deletions

View File

@ -5,7 +5,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>io.lroyia</groupId>
<artifactId>market-ent-classift-demo</artifactId>
<artifactId>market-ent-classify-demo</artifactId>
<version>1.0.0-SNAPSHOT</version>
<packaging>jar</packaging>
<name>${project.artifactId}</name>

View File

@ -1,5 +1,11 @@
package io.lroyia;
import io.lroyia.bean.EntCalInfo;
import io.lroyia.model.knn.KNNModel;
import io.lroyia.util.DataUtil;
import java.util.List;
/**
* 程序入口
*
@ -9,6 +15,33 @@ package io.lroyia;
public class KnnRun {
public static void main(String[] args) {
List<EntCalInfo> allCalInfo = DataUtil.getAllCalInfo();
int endIndex = allCalInfo.size() / 3 * 2;
double[][] trainData = new double[endIndex][];
String[] trainResult = new String[endIndex];
double[][] testData = new double[allCalInfo.size() - endIndex][];
String[] testResult = new String[allCalInfo.size() - endIndex];
for (int i = 0; i < allCalInfo.size(); i++) {
EntCalInfo each = allCalInfo.get(i);
double[] data = new double[]{
each.getEstDate(),
each.getEntType(),
each.getRegCap(),
each.getIndustryCo(),
each.getRegState()
};
if (i < endIndex) {
trainData[i] = data;
trainResult[i] = each.getRiskLevel();
} else {
testData[i - endIndex] = data;
testResult[i - endIndex] = each.getRiskLevel();
}
}
KNNModel<String> knnModel = new KNNModel<>(5, trainData, trainResult, true);
System.out.println(knnModel.test(testData, testResult));
}
}

View File

@ -77,6 +77,7 @@ public class EntInfo implements Serializable {
result.setPripid(pripid);
result.setEntName(entName);
result.setIndustryPhy(industryPhy);
result.setRiskLevel(riskLevel);
if (estDate != null) {
LocalDate now = LocalDate.now();
result.setEstDate(now.getYear() - estDate.getYear());
@ -87,8 +88,10 @@ public class EntInfo implements Serializable {
if (regCap != null) {
result.setRegCap(regCap.doubleValue());
}
if (StringUtils.isNotBlank(industryCo)) {
if (StringUtils.isNotBlank(industryCo) && StringUtils.isNumeric(industryCo)) {
result.setIndustryCo(Double.parseDouble(industryCo));
} else {
result.setIndustryCo(0);
}
if (StringUtils.isNotBlank(regState)) {
result.setRegState(Double.parseDouble(regState));

View File

@ -2,7 +2,8 @@ package io.lroyia.model.knn;
import io.lroyia.bean.ClassifyModelData;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
/**
* kd树节点
@ -10,7 +11,8 @@ import lombok.Data;
* @author lroyia
* @since 2023/10/28 17:29
**/
@Data
@Getter
@Setter
@AllArgsConstructor
public class KDTreeNode<ResultT> {

View File

@ -55,13 +55,13 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
this.trainResult = trainResult;
double[][] nTrainData = new double[trainData.length][];
this.maxArray = new double[trainData.length];
this.minArray = new double[trainData.length];
for (int i = 0; i < trainData.length; i++) {
DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]);
this.maxArray[i] = maxMin.getMax();
this.minArray[i] = maxMin.getMin();
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i]));
nTrainData[i] = toOneResult.getResult();
}
this.trainData = nTrainData;
setUseKdTree(useKdTree);

View File

@ -51,10 +51,7 @@ public abstract class DataUtil {
String[] dateArr = estDateStr.split(" ")[0].split("-");
atom.setEstDate(LocalDate.of(Integer.parseInt(dateArr[0]), Integer.parseInt(dateArr[1]), Integer.parseInt(dateArr[2])));
}
String entType = each.get("SUBENTTYPE");
if (StringUtils.isBlank(entType)) {
entType = each.get("ENTTYPE");
}
String entType = each.get("ENTTYPE");
atom.setEntType(entType);
String regCap = each.get("REGCAP");
if (StringUtils.isNotBlank(regCap)) {
@ -63,6 +60,7 @@ public abstract class DataUtil {
atom.setIndustryPhy(each.get("INDUSTRYPHY"));
atom.setIndustryCo(each.get("INDUSTRYCO"));
atom.setRegState(each.get("ENTSTATE"));
atom.setRiskLevel(each.get("RISKLEVEL"));
}
return result;
} catch (IOException e) {