Predict_proba functionality to Random Forest Classifier#360
skywardfire1 wants to merge 9 commits into smartcorelib:development from
Conversation
Codecov Report ❌ Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
##           development    #360      +/- ##
===============================================
- Coverage        45.59%   44.69%   -0.90%
===============================================
  Files               93       95       +2
  Lines             8034     8054      +20
===============================================
- Hits              3663     3600      -63
- Misses            4371     4454      +83
applied auto formatting.
wow. thank you. will take a look asap
some additional info. This is what it looks like in my project:

labels - True: [1, 1, 4, 0, 0, 5, 3, 0, 0, 5, 0, 0, 0, 0, 0, 4, 4, 1, 2, 3, 5, 1, 4, 0, 3, 1, 0, 0, 3, 3, 0, 0, 0, 4, 5, 1, 1, 0, 0, 1, 5, 2, 4, 4, 0, 0, 1, 1, 3, 4, 0, 0, 4, 2, 2, 3, 4, 5, 5, 0, 0, 5, 0, 0, 0, 4, 4, 1, 5]
labels - Predicted: [5, 1, 4, 0, 0, 5, 2, 0, 0, 5, 0, 0, 4, 0, 0, 4, 4, 5, 2, 3, 5, 5, 4, 0, 3, 3, 0, 0, 3, 3, 0, 0, 0, 4, 5, 1, 1, 0, 0, 1, 5, 2, 4, 1, 0, 0, 1, 5, 3, 4, 0, 0, 4, 4, 4, 3, 4, 5, 5, 0, 0, 5, 0, 0, 0, 4, 4, 1, 1]

Predicted probabilities (first 5 samples):
[[0.0, 0.2, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0],
[0.0, 0.9, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0],
[0.1, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0],
[0.9, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]

And the [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] result is funny to me, since the answer is correct, but so shy
this looks OK. it would be nice to have a test using the iris dataset like this one: so the results can be checked against known results.
maybe a little better error handling, like: The rest looks good.
…est uses Iris dataset, and consists of 4 checks. The 2nd test consists of 2 checks.
force-pushed from fc969fe to 8f7b17a
I revisited the tests. As said in the commit comment, there are now 2 tests. The first one uses the Iris dataset from the beginning of the file and performs 4 checks. Everything builds and works perfectly on my side, and clippy and fmt --all show no issues, so I have no idea why the builds fail. As for error handling: I could add error checks, but it seems useless, since there are no operations that could possibly return an error unless the user breaks the API.
still, the checks don't look good. I suggest we throw this PR away; I'll make another one soon after, with the same functionality.
no problem. proceed as you see fit
Checklist
This PR adds:
- Improved `predict_proba` behavior. Previously, probability predictions from Decision Tree returned one-hot encoded vectors (1.0 for the majority class, 0.0 for others), which did not reflect the actual class distributions in leaf nodes. While we could still use those in Random Forest, this approach would not provide calibrated probability estimates.
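To illustrate the difference, here is a minimal standalone Rust sketch (not smartcore's actual code; the function names are made up for illustration) contrasting the old one-hot output with a normalized leaf class distribution:

```rust
// Hypothetical sketch of the two behaviors at a single leaf node.
// `counts` stands in for the class histogram collected during training.

// Old behavior: 1.0 for the majority class, 0.0 elsewhere.
fn one_hot(counts: &[usize]) -> Vec<f64> {
    let majority = counts
        .iter()
        .enumerate()
        .max_by_key(|(_, c)| **c)
        .map(|(i, _)| i)
        .unwrap();
    (0..counts.len())
        .map(|i| if i == majority { 1.0 } else { 0.0 })
        .collect()
}

// New behavior: normalize the stored class histogram into probabilities.
fn leaf_distribution(counts: &[usize]) -> Vec<f64> {
    let total: usize = counts.iter().sum();
    counts.iter().map(|&c| c as f64 / total as f64).collect()
}

fn main() {
    let counts = [6usize, 2, 2]; // class histogram in one leaf
    assert_eq!(one_hot(&counts), vec![1.0, 0.0, 0.0]);
    assert_eq!(leaf_distribution(&counts), vec![0.6, 0.2, 0.2]);
}
```

The one-hot variant throws away the minority counts, which is why the old output was always a vector of 0.0s and a single 1.0 regardless of how mixed the leaf actually was.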
Changes:
- A `class_distribution: Vec<usize>` field to store the class histogram in each node. This data was already being computed during training but was not persisted.
- A `predict_proba_for_row_real()` method that returns proper probability distributions based on leaf class counts. The original `predict_proba()` method remains unchanged for backward compatibility. Hopefully, we will obsolete it one day.
- A Random Forest `predict_proba()` method that averages probability distributions from all trees (scikit-learn style), rather than averaging hard class predictions.
- A `debug_assert_eq!`, just to feel better.

Backward Compatibility:
No breaking changes. All existing APIs remain intact.
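As a standalone illustration of the forest-level averaging described in the changes above (assumed shapes and a hypothetical function name, not smartcore's internals): each tree contributes a per-class probability vector for a sample, and the forest averages them element-wise.

```rust
// Hypothetical sketch: average per-tree probability distributions for one
// sample, scikit-learn style, instead of averaging hard class predictions.
// `per_tree[t][k]` is tree t's probability for class k.
fn average_tree_probas(per_tree: &[Vec<f64>]) -> Vec<f64> {
    let n_trees = per_tree.len() as f64;
    let n_classes = per_tree[0].len();
    let mut avg = vec![0.0; n_classes];
    for dist in per_tree {
        for (a, p) in avg.iter_mut().zip(dist) {
            *a += *p / n_trees;
        }
    }
    avg
}

fn main() {
    // Distributions for one sample from three trees.
    let per_tree = vec![
        vec![0.5, 0.5, 0.0],
        vec![1.0, 0.0, 0.0],
        vec![0.0, 0.5, 0.5],
    ];
    let avg = average_tree_probas(&per_tree);
    // Each average is the mean of the per-tree entries, and the result
    // still sums to 1 because every input distribution does.
    assert!((avg[0] - 0.5).abs() < 1e-9);
    assert!((avg[1] - 1.0 / 3.0).abs() < 1e-9);
    assert!((avg[2] - 1.0 / 6.0).abs() < 1e-9);
    assert!((avg.iter().sum::<f64>() - 1.0).abs() < 1e-9);
}
```

Averaging distributions (rather than hard votes) is what lets the forest output graded probabilities like the [0.0, 0.2, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0] row shown earlier in the thread.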
Note:
The current `predict_proba` function returns `Vec<Vec<f64>>`, not `DenseMatrix<f64>`, since I didn't find any examples of what the default behavior or standard for this is.