补充程序入口部分
This commit is contained in:
parent
7602982178
commit
a17158f073
2
pom.xml
2
pom.xml
|
|
@ -5,7 +5,7 @@
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
<groupId>io.lroyia</groupId>
|
<groupId>io.lroyia</groupId>
|
||||||
<artifactId>market-ent-classift-demo</artifactId>
|
<artifactId>market-ent-classify-demo</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
<packaging>jar</packaging>
|
<packaging>jar</packaging>
|
||||||
<name>${project.artifactId}</name>
|
<name>${project.artifactId}</name>
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,11 @@
|
||||||
package io.lroyia;
|
package io.lroyia;
|
||||||
|
|
||||||
|
import io.lroyia.bean.EntCalInfo;
|
||||||
|
import io.lroyia.model.knn.KNNModel;
|
||||||
|
import io.lroyia.util.DataUtil;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 程序入口
|
* 程序入口
|
||||||
*
|
*
|
||||||
|
|
@ -9,6 +15,33 @@ package io.lroyia;
|
||||||
public class KnnRun {
|
public class KnnRun {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
|
List<EntCalInfo> allCalInfo = DataUtil.getAllCalInfo();
|
||||||
|
int endIndex = allCalInfo.size() / 3 * 2;
|
||||||
|
double[][] trainData = new double[endIndex][];
|
||||||
|
String[] trainResult = new String[endIndex];
|
||||||
|
|
||||||
|
double[][] testData = new double[allCalInfo.size() - endIndex][];
|
||||||
|
String[] testResult = new String[allCalInfo.size() - endIndex];
|
||||||
|
|
||||||
|
for (int i = 0; i < allCalInfo.size(); i++) {
|
||||||
|
EntCalInfo each = allCalInfo.get(i);
|
||||||
|
double[] data = new double[]{
|
||||||
|
each.getEstDate(),
|
||||||
|
each.getEntType(),
|
||||||
|
each.getRegCap(),
|
||||||
|
each.getIndustryCo(),
|
||||||
|
each.getRegState()
|
||||||
|
};
|
||||||
|
if (i < endIndex) {
|
||||||
|
trainData[i] = data;
|
||||||
|
trainResult[i] = each.getRiskLevel();
|
||||||
|
} else {
|
||||||
|
testData[i - endIndex] = data;
|
||||||
|
testResult[i - endIndex] = each.getRiskLevel();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
KNNModel<String> knnModel = new KNNModel<>(5, trainData, trainResult, true);
|
||||||
|
System.out.println(knnModel.test(testData, testResult));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@ public class EntInfo implements Serializable {
|
||||||
result.setPripid(pripid);
|
result.setPripid(pripid);
|
||||||
result.setEntName(entName);
|
result.setEntName(entName);
|
||||||
result.setIndustryPhy(industryPhy);
|
result.setIndustryPhy(industryPhy);
|
||||||
|
result.setRiskLevel(riskLevel);
|
||||||
if (estDate != null) {
|
if (estDate != null) {
|
||||||
LocalDate now = LocalDate.now();
|
LocalDate now = LocalDate.now();
|
||||||
result.setEstDate(now.getYear() - estDate.getYear());
|
result.setEstDate(now.getYear() - estDate.getYear());
|
||||||
|
|
@ -87,8 +88,10 @@ public class EntInfo implements Serializable {
|
||||||
if (regCap != null) {
|
if (regCap != null) {
|
||||||
result.setRegCap(regCap.doubleValue());
|
result.setRegCap(regCap.doubleValue());
|
||||||
}
|
}
|
||||||
if (StringUtils.isNotBlank(industryCo)) {
|
if (StringUtils.isNotBlank(industryCo) && StringUtils.isNumeric(industryCo)) {
|
||||||
result.setIndustryCo(Double.parseDouble(industryCo));
|
result.setIndustryCo(Double.parseDouble(industryCo));
|
||||||
|
} else {
|
||||||
|
result.setIndustryCo(0);
|
||||||
}
|
}
|
||||||
if (StringUtils.isNotBlank(regState)) {
|
if (StringUtils.isNotBlank(regState)) {
|
||||||
result.setRegState(Double.parseDouble(regState));
|
result.setRegState(Double.parseDouble(regState));
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ package io.lroyia.model.knn;
|
||||||
|
|
||||||
import io.lroyia.bean.ClassifyModelData;
|
import io.lroyia.bean.ClassifyModelData;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* kd树节点
|
* kd树节点
|
||||||
|
|
@ -10,7 +11,8 @@ import lombok.Data;
|
||||||
* @author lroyia
|
* @author lroyia
|
||||||
* @since 2023/10/28 17:29
|
* @since 2023/10/28 17:29
|
||||||
**/
|
**/
|
||||||
@Data
|
@Getter
|
||||||
|
@Setter
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class KDTreeNode<ResultT> {
|
public class KDTreeNode<ResultT> {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,13 +55,13 @@ public class KNNModel<ResultT> implements ClassifyModel<ResultT> {
|
||||||
this.trainResult = trainResult;
|
this.trainResult = trainResult;
|
||||||
double[][] nTrainData = new double[trainData.length][];
|
double[][] nTrainData = new double[trainData.length][];
|
||||||
this.maxArray = new double[trainData.length];
|
this.maxArray = new double[trainData.length];
|
||||||
|
this.minArray = new double[trainData.length];
|
||||||
for (int i = 0; i < trainData.length; i++) {
|
for (int i = 0; i < trainData.length; i++) {
|
||||||
DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]);
|
DataUtil.MaxMin maxMin = DataUtil.getMaxMin(trainData[i]);
|
||||||
this.maxArray[i] = maxMin.getMax();
|
this.maxArray[i] = maxMin.getMax();
|
||||||
this.minArray[i] = maxMin.getMin();
|
this.minArray[i] = maxMin.getMin();
|
||||||
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i]));
|
DataUtil.ToOneResult toOneResult = DataUtil.columnToOne(DataUtil.standardData(trainData[i]));
|
||||||
nTrainData[i] = toOneResult.getResult();
|
nTrainData[i] = toOneResult.getResult();
|
||||||
|
|
||||||
}
|
}
|
||||||
this.trainData = nTrainData;
|
this.trainData = nTrainData;
|
||||||
setUseKdTree(useKdTree);
|
setUseKdTree(useKdTree);
|
||||||
|
|
|
||||||
|
|
@ -51,10 +51,7 @@ public abstract class DataUtil {
|
||||||
String[] dateArr = estDateStr.split(" ")[0].split("-");
|
String[] dateArr = estDateStr.split(" ")[0].split("-");
|
||||||
atom.setEstDate(LocalDate.of(Integer.parseInt(dateArr[0]), Integer.parseInt(dateArr[1]), Integer.parseInt(dateArr[2])));
|
atom.setEstDate(LocalDate.of(Integer.parseInt(dateArr[0]), Integer.parseInt(dateArr[1]), Integer.parseInt(dateArr[2])));
|
||||||
}
|
}
|
||||||
String entType = each.get("SUBENTTYPE");
|
String entType = each.get("ENTTYPE");
|
||||||
if (StringUtils.isBlank(entType)) {
|
|
||||||
entType = each.get("ENTTYPE");
|
|
||||||
}
|
|
||||||
atom.setEntType(entType);
|
atom.setEntType(entType);
|
||||||
String regCap = each.get("REGCAP");
|
String regCap = each.get("REGCAP");
|
||||||
if (StringUtils.isNotBlank(regCap)) {
|
if (StringUtils.isNotBlank(regCap)) {
|
||||||
|
|
@ -63,6 +60,7 @@ public abstract class DataUtil {
|
||||||
atom.setIndustryPhy(each.get("INDUSTRYPHY"));
|
atom.setIndustryPhy(each.get("INDUSTRYPHY"));
|
||||||
atom.setIndustryCo(each.get("INDUSTRYCO"));
|
atom.setIndustryCo(each.get("INDUSTRYCO"));
|
||||||
atom.setRegState(each.get("ENTSTATE"));
|
atom.setRegState(each.get("ENTSTATE"));
|
||||||
|
atom.setRiskLevel(each.get("RISKLEVEL"));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue