Skip to content
230 changes: 230 additions & 0 deletions src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,95 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
}
samples
}

/// Predict class probabilities for a single input sample.
///
/// This method averages the probability estimates from all trees in the forest.
/// Each tree returns a probability distribution based on the class distribution
/// in its leaf node (scikit-learn style), and these distributions are averaged
/// across all trees to produce the final probability estimate.
///
/// # Arguments
///
/// * `x` - The input matrix containing all samples.
/// * `row` - The index of the row in `x` for which to predict probabilities.
///
/// # Returns
///
/// A vector of probabilities, one for each class. The sum of probabilities equals 1.0.
/// Each probability represents the average fraction of training samples of that class
/// across all trees that reached the same leaf node for this input.
///
/// # Panics
///
/// Panics if the forest has not been fitted (`trees`/`classes` are `None`).
/// Public entry points must verify fitted state before calling this helper.
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
    // Resolve the fitted state once up front instead of unwrapping repeatedly.
    let trees = self
        .trees
        .as_ref()
        .expect("forest must be fitted before predicting probabilities");
    let k = self
        .classes
        .as_ref()
        .expect("classes are set during fit")
        .len();

    let mut probs = vec![0.0; k];

    // Accumulate each tree's leaf-node class distribution element-wise.
    // `zip` guarantees component-wise addition without index bounds checks.
    for tree in trees {
        let tree_probs = tree.predict_proba_for_row_real(x, row);
        for (p, tp) in probs.iter_mut().zip(tree_probs.iter()) {
            *p += *tp;
        }
    }

    // Average the accumulated distributions over the ensemble size.
    let n_trees = trees.len() as f64;
    for p in &mut probs {
        *p /= n_trees;
    }

    probs
}

/// Predict class probabilities for the input samples.
///
/// This method returns probability estimates for each sample in the input matrix.
/// For each sample, probabilities are computed by averaging the predictions from
/// all trees in the forest. Each tree contributes a probability distribution based
/// on the class distribution in its leaf node.
///
/// This is the scikit-learn style `predict_proba` behavior, providing calibrated
/// probability estimates rather than just class predictions.
///
/// # Arguments
///
/// * `x` - The input samples as a matrix where each row is a sample and each column
/// is a feature.
///
/// # Returns
///
/// A `Result` containing a `Vec<Vec<f64>>` where each inner vector corresponds to
/// a sample and contains probabilities for each class. The sum of probabilities
/// for each sample equals 1.0.
///
/// # Note
///
/// Return type is `Vec<Vec<f64>>` for minimal API changes. The tree classifier
/// returns `DenseMatrix<f64>` for the same method.
///
/// # Errors
///
/// Returns an error if the forest has not been fitted (trees are None).
pub fn predict_proba(&self, x: &X) -> Result<Vec<Vec<f64>>, Failed> {
let (n, _) = x.shape();

let mut result = Vec::with_capacity(n);

for i in 0..n {
result.push(self.predict_proba_for_row(x, i));
}

Ok(result)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -806,4 +895,145 @@ mod tests {

assert_eq!(forest, deserialized_forest);
}

#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_predict_proba_iris() {
    // 20-sample slice of the iris data: first 8 labeled class 0, remaining 12 class 1.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ])
    .unwrap();
    let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];

    let forest = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters {
            criterion: SplitCriterion::Gini,
            max_depth: Option::None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 10,
            m: Option::None,
            keep_samples: false,
            seed: 87,
        },
    )
    .unwrap();

    let probs = forest.predict_proba(&x).unwrap();

    // Shape: one distribution per sample, one entry per class.
    assert_eq!(probs.len(), 20);
    assert_eq!(probs[0].len(), 2);

    // Every per-sample distribution must be normalized.
    for (row, dist) in probs.iter().enumerate() {
        let row_sum: f64 = dist.iter().sum();
        assert!(
            (row_sum - 1.0).abs() < 1e-6,
            "Row {} probabilities should sum to 1, got {}",
            row,
            row_sum
        );
    }

    // Class-0 samples (indices 0..8) should favor class 0.
    for (i, dist) in probs.iter().enumerate().take(8) {
        assert!(
            dist[0] > dist[1],
            "Sample {} should have higher prob for class 0",
            i
        );
    }

    // Class-1 samples (indices 8..20) should favor class 1.
    for (i, dist) in probs.iter().enumerate().skip(8) {
        assert!(
            dist[1] > dist[0],
            "Sample {} should have higher prob for class 1",
            i
        );
    }
}

#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_predict_proba_iris_mixed_leaves() {
    // Duplicate feature rows with conflicting labels force impure (mixed) leaves.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[5.1, 3.5, 1.4, 0.2], // Same features
        &[5.1, 3.5, 1.4, 0.2], // Same features
        &[7.0, 3.2, 4.7, 1.4],
        &[7.0, 3.2, 4.7, 1.4], // Same features
    ])
    .unwrap();
    let y = vec![0, 0, 1, 1, 1]; // Mixed classes in same feature region

    let forest = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters {
            n_trees: 5,
            seed: 42,
            ..Default::default()
        },
    )
    .unwrap();

    let probs = forest.predict_proba(&x).unwrap();

    // Every distribution must be normalized with entries in [0, 1].
    for (row, dist) in probs.iter().enumerate() {
        let total: f64 = dist.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-6,
            "Probabilities for row {} should sum to 1.0",
            row
        );
        for &p in dist {
            assert!(p >= 0.0 && p <= 1.0, "Probability out of range");
        }
    }

    // The first three samples share features but carry both labels, so the
    // averaged leaf distributions must give each class non-zero mass.
    for (i, dist) in probs.iter().enumerate().take(3) {
        assert!(
            dist[0] > 0.0,
            "Sample {} should have non-zero prob for class 0",
            i
        );
        assert!(
            dist[1] > 0.0,
            "Sample {} should have non-zero prob for class 1",
            i
        );
    }
}
}
Loading
Loading