Skip to content

Predict_proba functionality to Random Forest Classifier#360

Closed
skywardfire1 wants to merge 9 commits intosmartcorelib:developmentfrom
skywardfire1:development
Closed

Predict_proba functionality to Random Forest Classifier#360
skywardfire1 wants to merge 9 commits intosmartcorelib:developmentfrom
skywardfire1:development

Conversation

@skywardfire1
Copy link
Contributor

@skywardfire1 skywardfire1 commented Mar 16, 2026

Checklist

  • [yes] My branch is up-to-date with development branch.
  • [yes] Everything works and tested on latest stable Rust.
  • [yes] Coverage and Linting have been applied

This PR adds:

  1. Proper probability estimation to Decision Tree Classifier, matching scikit-learn's predict_proba behavior.
  2. Predict_proba functionality to Random Forest Classifier, also in scikit-learn style, which was the goal.

Previously, probability predictions from Decision Tree returned one-hot encoded vectors (1.0 for the majority class, 0.0 for others), which did not reflect actual class distributions in leaf nodes.
While we could still use it in Random Forest, this approach would not provide calibrated probability estimates.

Changes:

  1. Node structure extended: Added class_distribution: Vec<usize> field to store the class histogram in each node. This data was already being computed during training but was not persisted.
  2. DecisionTreeClassifier: Added predict_proba_for_row_real() method that returns proper probability distributions based on leaf class counts. The original predict_proba() method remains unchanged for backward compatibility. Hope, we will obsolete it one day.
  3. RandomForestClassifier: Added public predict_proba() method that averages probability distributions from all trees (scikit-learn style), rather than averaging hard class predictions.
  4. Testing: Added 5 new tests covering:
    • Probability distributions summing to 1.0
    • Correct class ordering in predictions
    • Mixed-class leaf handling
    • Forest-level probability averaging
  5. Debug assertions: Added 3 debug_assert_eq!, just to feel better.

Backward Compatibility:
No breaking changes. All existing APIs remain intact.

Note:
Current predict_proba function returns Vec<Vec<f64>>, not DenseMatrix<f64>, since I didn't find any examples on what is the default behavior or standard for this.

@skywardfire1 skywardfire1 requested a review from Mec-iS as a code owner March 16, 2026 11:57
@codecov
Copy link

codecov bot commented Mar 16, 2026

Codecov Report

❌ Patch coverage is 56.36364% with 24 lines in your changes missing coverage. Please review.
✅ Project coverage is 44.69%. Comparing base (70d8a0f) to head (fc969fe).
⚠️ Report is 10 commits behind head on development.

Files with missing lines Patch % Lines
src/tree/decision_tree_classifier.rs 62.16% 14 Missing ⚠️
src/ensemble/random_forest_classifier.rs 44.44% 10 Missing ⚠️
Additional details and impacted files
@@               Coverage Diff               @@
##           development     #360      +/-   ##
===============================================
- Coverage        45.59%   44.69%   -0.90%     
===============================================
  Files               93       95       +2     
  Lines             8034     8054      +20     
===============================================
- Hits              3663     3600      -63     
- Misses            4371     4454      +83     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@skywardfire1
Copy link
Contributor Author

applied auto formatting.

@Mec-iS
Copy link
Collaborator

Mec-iS commented Mar 17, 2026

wow. thank you. will take a look asap

@skywardfire1
Copy link
Contributor Author

some additional info. This is how it looks like in my project

labels - True:      [1, 1, 4, 0, 0, 5, 3, 0, 0, 5, 0, 0, 0, 0, 0, 4, 4, 1, 2, 3, 5, 1, 4, 0, 3, 1, 0, 0, 3, 3, 0, 0, 0, 4, 5, 1, 1, 0, 0, 1, 5, 2, 4, 4, 0, 0, 1, 1, 3, 4, 0, 0, 4, 2, 2, 3, 4, 5, 5, 0, 0, 5, 0, 0, 0, 4, 4, 1, 5]
labels - Predicted: [5, 1, 4, 0, 0, 5, 2, 0, 0, 5, 0, 0, 4, 0, 0, 4, 4, 5, 2, 3, 5, 5, 4, 0, 3, 3, 0, 0, 3, 3, 0, 0, 0, 4, 5, 1, 1, 0, 0, 1, 5, 2, 4, 1, 0, 0, 1, 5, 3, 4, 0, 0, 4, 4, 4, 3, 4, 5, 5, 0, 0, 5, 0, 0, 0, 4, 4, 1, 1]

Primary probabilities (first 5 samples):
    [[0.0, 0.2, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0],
     [0.0, 0.9, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0],
     [0.1, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0],
     [0.9, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
     [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]] 

And [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] result is funny to me since the answer is correct, but so shy

@Mec-iS
Copy link
Collaborator

Mec-iS commented Mar 19, 2026

this looks OK.

it would be nice to have a test using the iris dataset like this one:

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn fit_predict_iris_oob() { 
        ...

so the results can be checked on known results.

@Mec-iS
Copy link
Collaborator

Mec-iS commented Mar 19, 2026

maybe a little better error handling like:

  pub fn predict_proba(&self, x: &X) -> Result<Vec<Vec<f64>>, Failed> {                                                 
      let (n, _) = x.shape();                                                                                           
      let mut result = Vec::with_capacity(n);                                                                           
      for i in 0..n {                                                                                                   
          result.push(self.predict_proba_for_row(x, i));                                                                
      }                                                                                                                 
      Ok(result)                                                                                                        
  }  

Rest looks good.

Copy link
Collaborator

@Mec-iS Mec-iS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check out my comments

…est uses Iris dataset, and consists of 4 checks. The 2nd test consists of 2 checks.
@skywardfire1
Copy link
Contributor Author

I revisited the tests. As said in commit comment, now there are 2 tests. The first one uses Iris dataset from the beginning of the file, and performs 4 checks.

Everything builds and works perfectly at my side, clippy and fmt --all shows no issues, so I have no idea why builds fail.

What about error handling. I can physically add error checks but it seems useless since there are no operations could possibly return an error if the user doesn't break the API.

@skywardfire1
Copy link
Contributor Author

still, checks don't look good. I offer to throw this PR away, I'll make another one soon after, with same functionality.

@Mec-iS
Copy link
Collaborator

Mec-iS commented Mar 20, 2026

no problem. proceed as you find right

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants