Skip to content

Commit 297e868

Browse files
committed
cleanup for next release
1 parent 399664f commit 297e868

4 files changed

Lines changed: 87 additions & 32 deletions

File tree

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>rocks.vilaverde</groupId>
88
<artifactId>scikit-learn-2-java</artifactId>
9-
<version>1.0.2-SNAPSHOT</version>
9+
<version>1.0.2</version>
1010

1111
<name>${project.groupId}:${project.artifactId}</name>
1212
<description>A sklearn exported_text models parser for executing in the Java runtime.</description>

src/main/java/rocks/vilaverde/classifier/dt/AbstractDecisionTreeVisitor.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ public void visit(ChoiceNode object) {
2828
public void visit(DecisionNode object) {
2929
visitBase(object);
3030

31-
for (TreeNode child : object.getChildren()) {
32-
child.accept(this);
33-
}
31+
object.getLeft().accept(this);
32+
object.getRight().accept(this);
3433
}
3534
}

src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,76 @@
55

66
/**
77
* Represents a decision in the DecisionTreeClassifier. The decision will have
8-
* a left and right hand choice to be evaluated. Choices may have nested DecisionNodes.
8+
* a left and right hand {@link ChoiceNode} to be evaluated.
9+
* A {@link ChoiceNode} may have nested {@link DecisionNode} or {@link EndNode}.
910
*/
1011
class DecisionNode extends TreeNode {
1112

1213
private final String featureName;
1314

14-
private final List<TreeNode> children = new ArrayList<>(2);
15-
16-
public List<TreeNode> getChildren() {
17-
return children;
18-
}
15+
private ChoiceNode left;
16+
private ChoiceNode right;
1917

18+
/**
19+
* Factory method to create a {@link DecisionNode}.
20+
* @param feature the name of the feature.
21+
* @return DecisionNode
22+
*/
2023
public static DecisionNode create(String feature) {
2124
return new DecisionNode(feature);
2225
}
2326

27+
/**
28+
* Private Constructor.
29+
* @param featureName
30+
*/
2431
private DecisionNode(String featureName) {
25-
this.featureName = featureName;
32+
this.featureName = featureName.intern();
33+
}
34+
35+
/**
36+
* Getter for the left.
37+
* @return ChoiceNode
38+
*/
39+
public ChoiceNode getLeft() {
40+
return left;
41+
}
42+
43+
/**
44+
* Left hand side of the decision.
45+
* @param left the left {@link ChoiceNode}
46+
*/
47+
public void setLeft(ChoiceNode left) {
48+
this.left = left;
49+
}
50+
51+
/**
52+
* Getter for the right.
53+
* @return ChoiceNode
54+
*/
55+
public ChoiceNode getRight() {
56+
return right;
57+
}
58+
59+
/**
60+
* Right hand side of a decision.
61+
* @param right the right {@link ChoiceNode}
62+
*/
63+
public void setRight(ChoiceNode right) {
64+
this.right = right;
2665
}
2766

67+
/**
68+
* @return true when the left and right choice are set on this decision.
69+
*/
70+
public boolean isComplete() {
71+
return getLeft() != null && getRight() != null;
72+
}
73+
74+
/**
75+
* Getter for the feature used in this decision node.
76+
* @return String
77+
*/
2878
public String getFeatureName() {
2979
return featureName;
3080
}
@@ -36,6 +86,15 @@ public void accept(AbstractDecisionTreeVisitor visitor) {
3686

3787
@Override
3888
public String toString() {
39-
return String.format("%s, %d operations", getFeatureName(), getChildren().size());
89+
int count = 0;
90+
if (getLeft() != null) {
91+
count++;
92+
}
93+
94+
if (getRight() != null) {
95+
count++;
96+
}
97+
98+
return String.format("%s, %d operations", getFeatureName(), count);
4099
}
41100
}

src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import rocks.vilaverde.classifier.Operator;
44
import rocks.vilaverde.classifier.Prediction;
55

6+
import java.awt.*;
67
import java.io.BufferedReader;
78
import java.io.Reader;
89
import java.util.Map;
@@ -36,7 +37,6 @@ public static <T> DecisionTreeClassifier<T> parse(Reader reader, PredictionFacto
3637

3738
private final PredictionFactory<T> predictionFactory;
3839
private DecisionNode root;
39-
4040
private Set<String> featureNames;
4141

4242
/**
@@ -81,28 +81,19 @@ public Prediction<T> getClassification(Map<String, Double> features) {
8181
if (currentNode != null) {
8282
DecisionNode decisionNode = ((DecisionNode) currentNode);
8383

84-
ChoiceNode selection = null;
85-
for (TreeNode child : decisionNode.getChildren()) {
86-
87-
Double featureValue = features.get(decisionNode.getFeatureName());
88-
if (featureValue == null) {
89-
featureValue = Double.NaN;
90-
}
91-
92-
// find the path to traverse by evaluating all choices
93-
94-
if (((ChoiceNode) child).eval(featureValue)) {
95-
selection = (ChoiceNode) child;
96-
break;
97-
}
84+
Double featureValue = features.get(decisionNode.getFeatureName());
85+
if (featureValue == null) {
86+
featureValue = Double.NaN;
9887
}
9988

100-
if (selection != null) {
101-
currentNode = selection.getChild();
89+
if (decisionNode.getLeft().eval(featureValue)) {
90+
currentNode = decisionNode.getLeft().getChild();
91+
} else if (decisionNode.getRight().eval(featureValue)) {
92+
currentNode = decisionNode.getRight().getChild();
10293
} else {
10394
// if I get here something is wrong since none of the branches evaluated to true
10495
throw new RuntimeException(String.format("no branches evaluated to true for feature '%s'",
105-
decisionNode.getFeatureName()));
96+
decisionNode.getFeatureName()));
10697
}
10798
}
10899
}
@@ -174,7 +165,7 @@ private void processChildNode(Stack<TreeNode> stack, String line) throws Excepti
174165
// if the current choice has been popped check if the decision node
175166
// has 2 operations, and if so pop that one as well and any other
176167
// completed decision nodes.
177-
while (stack.size() > 1 && ((DecisionNode)stack.peek()).getChildren().size() == 2 ) {
168+
while (stack.size() > 1 && ((DecisionNode)stack.peek()).isComplete() ) {
178169
stack.pop();
179170
}
180171
}
@@ -215,7 +206,13 @@ private void processDecisionNode(Stack<TreeNode> stack, String line) {
215206
}
216207

217208
ChoiceNode choice = ChoiceNode.create(op, value);
218-
decisionNode.getChildren().add(choice);
209+
210+
if (decisionNode.getLeft() == null) {
211+
decisionNode.setLeft(choice);
212+
} else {
213+
decisionNode.setRight(choice);
214+
}
215+
219216
stack.push(choice);
220217
}
221218

0 commit comments

Comments
 (0)