Skip to content

[Feature] Add RandomForestClassifier to linfa-trees #390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion algorithms/linfa-reduction/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [
] }
approx = { version = "0.4" }
mnist = { version = "0.6.0", features = ["download"] }
linfa-trees = { version = "0.7.1", path = "../linfa-trees" }
linfa-trees = { version = "0.8.0", path = "../linfa-trees" }
9 changes: 7 additions & 2 deletions algorithms/linfa-trees/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
[package]
name = "linfa-trees"
version = "0.7.1"
version = "0.8.0"
edition = "2018"
authors = ["Moss Ebeling <[email protected]>"]
authors = [
"Moss Ebeling <[email protected]>",
"Abhinav Shukla <[email protected]>",
]
description = "A collection of tree-based algorithms"
license = "MIT OR Apache-2.0"

Expand All @@ -28,6 +31,8 @@ ndarray = { version = "0.15" , features = ["rayon", "approx"]}
ndarray-rand = "0.14"

linfa = { version = "0.7.1", path = "../.." }
rand = "0.8"


[dev-dependencies]
rand = { version = "0.8", features = ["small_rng"] }
Expand Down
65 changes: 54 additions & 11 deletions algorithms/linfa-trees/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,66 @@ Decision Trees (DTs) are a non-parametric supervised learning method used for cl

## Current state

`linfa-trees` currently provides an implementation of single tree fitting
`linfa-trees` currently provides an implementation of single tree fitting.

## Examples
## Random Forest Classifier

There is an example in the `examples/` directory showing how to use decision trees. To run, use:
An ensemble of decision trees trained on **bootstrapped** subsets of the data **and** **feature-subsampled** per tree. Predictions are made by majority voting across all trees, which typically improves generalization and robustness over a single tree.

**Key features:**
- Configurable number of trees (`n_trees`)
- Optional maximum depth (`max_depth`)
- Fraction of features sampled per tree (`feature_subsample`)
- Reproducible results via RNG seed (`seed`)
- Implements `Fit` and `Predict` traits for seamless integration

### Random Forest Example

```rust
use linfa::prelude::*;
use linfa_datasets::iris;
use linfa_trees::RandomForestParams;
use rand::thread_rng;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// Load and split the Iris dataset
let (train, valid) = iris()
.shuffle(&mut thread_rng())
.split_with_ratio(0.8);

// Train a Random Forest with 100 trees, depth 10, and 80% feature sampling
let model = RandomForestParams::new(100)
.max_depth(Some(10))
.feature_subsample(0.8)
.seed(42)
.fit(&train)?;

// Predict and evaluate
let preds = model.predict(valid.records.clone());
let cm = preds.confusion_matrix(&valid)?;
println!("Accuracy: {:.2}", cm.accuracy());
Ok(())
}
```

Run this example with:

```bash

cargo run --release --example iris_random_forest
```
Examples
There is an example in the examples/ directory showing how to use decision trees. To run, use:

```bash
$ cargo run --release --example decision_tree
cargo run --release --example decision_tree
```

This generates the following tree:
This generates the following tree:

<p align="center">
<img src="./iris-decisiontree.svg">
</p>
<p align="center"> <img src="./iris-decisiontree.svg" alt="Iris decision tree"> </p>
License
Dual‐licensed to be compatible with the Rust project.

## License
Dual-licensed to be compatible with the Rust project.
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.

Licensed under the Apache License, Version 2.0 <http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <http://opensource.org/licenses/MIT>, at your option. This file may not be copied, modified, or distributed except according to those terms.
35 changes: 35 additions & 0 deletions algorithms/linfa-trees/examples/iris_random_forest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// File: examples/iris_random_forest.rs

use linfa::prelude::*;
use linfa_datasets::iris;
use linfa_trees::{DecisionTree, RandomForestParams};
use rand::thread_rng;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create an RNG for reproducible shuffling
let mut rng = thread_rng();

// 1. Load, shuffle, and split the Iris dataset (80% train, 20% valid)
let (train, valid) = iris().shuffle(&mut rng).split_with_ratio(0.8);

// 2. Single‐tree baseline
let dt_model = DecisionTree::params()
.max_depth(None) // no depth limit
.fit(&train)?;
let dt_preds = dt_model.predict(valid.records.clone());
let dt_cm = dt_preds.confusion_matrix(&valid)?;
println!("Single‐tree accuracy: {:.2}", dt_cm.accuracy());

// 3. Random Forest
let rf_model = RandomForestParams::new(50)
.max_depth(Some(5))
.feature_subsample(0.7)
.seed(42) // fix RNG seed for reproducibility
.fit(&train)?;
let rf_preds = rf_model.predict(valid.records.clone());
let rf_cm = rf_preds.confusion_matrix(&valid)?;
println!("Random‐forest accuracy: {:.2}", rf_cm.accuracy());

// 4. Exit cleanly
Ok(())
}
1 change: 1 addition & 0 deletions algorithms/linfa-trees/src/decision_trees/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ pub use algorithm::*;
pub use hyperparams::*;
pub use iter::*;
pub use tikz::*;
pub mod random_forest;
168 changes: 168 additions & 0 deletions algorithms/linfa-trees/src/decision_trees/random_forest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
//! Random Forest Classifier
//!
//! An ensemble of decision trees trained on bootstrapped, feature‐subsampled slices of the data.

use linfa::prelude::*;
use linfa::{error::Error, Float, ParamGuard};
use ndarray::{Array1, Array2, Axis};
use rand::{rngs::StdRng, seq::index::sample, Rng, SeedableRng};
use std::marker::PhantomData;

use super::algorithm::DecisionTree;

#[derive(Debug, Clone)]
pub struct RandomForestClassifier<F: Float> {
trees: Vec<DecisionTree<F, usize>>,
feature_indices: Vec<Vec<usize>>,
_phantom: PhantomData<F>,
}

#[derive(Debug, Clone)]
pub struct RandomForestParams<F: Float> {
inner: RandomForestValidParams<F>,
}

#[derive(Debug, Clone)]
pub struct RandomForestValidParams<F: Float> {
pub n_trees: usize,
pub max_depth: Option<usize>,
pub feature_subsample: f32,
pub seed: u64,
_phantom: PhantomData<F>,
}

impl<F: Float> RandomForestParams<F> {
/// Create a new random forest with the specified number of trees.
pub fn new(n_trees: usize) -> Self {

Check warning on line 36 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L36

Added line #L36 was not covered by tests
Self {
inner: RandomForestValidParams {

Check warning on line 38 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L38

Added line #L38 was not covered by tests
n_trees,
max_depth: None,
feature_subsample: 1.0,
seed: 42,
_phantom: PhantomData,
},
}
}

/// Set the maximum depth of each tree.
pub fn max_depth(mut self, depth: Option<usize>) -> Self {
self.inner.max_depth = depth;
self

Check warning on line 51 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L49-L51

Added lines #L49 - L51 were not covered by tests
}

/// Fraction of features to sample per tree (0.0–1.0).
pub fn feature_subsample(mut self, ratio: f32) -> Self {
self.inner.feature_subsample = ratio;
self

Check warning on line 57 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L55-L57

Added lines #L55 - L57 were not covered by tests
}

/// RNG seed for reproducibility.
pub fn seed(mut self, seed: u64) -> Self {
self.inner.seed = seed;

Check warning on line 62 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L61-L62

Added lines #L61 - L62 were not covered by tests
self
}
}

impl<F: Float> ParamGuard for RandomForestParams<F> {
type Checked = RandomForestValidParams<F>;
type Error = Error;

fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {

Check warning on line 71 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L71

Added line #L71 was not covered by tests
if self.inner.n_trees == 0 {
return Err(Error::Parameters("n_trees must be > 0".into()));

Check warning on line 73 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L73

Added line #L73 was not covered by tests
}
if !(0.0..=1.0).contains(&self.inner.feature_subsample) {
return Err(Error::Parameters(
"feature_subsample must be in [0, 1]".into(),

Check warning on line 77 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L75-L77

Added lines #L75 - L77 were not covered by tests
));
}
Ok(&self.inner)

Check warning on line 80 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L80

Added line #L80 was not covered by tests
}

fn check(self) -> Result<Self::Checked, Self::Error> {
self.check_ref()?;
Ok(self.inner)

Check warning on line 85 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L83-L85

Added lines #L83 - L85 were not covered by tests
}
}

/// Bootstrap‐sample the dataset with replacement.
fn bootstrap<F: Float>(
dataset: &DatasetBase<Array2<F>, Array1<usize>>,
rng: &mut impl Rng,
) -> DatasetBase<Array2<F>, Array1<usize>> {
let n = dataset.nsamples();

Check warning on line 94 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L94

Added line #L94 was not covered by tests
let indices: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
let rec = dataset.records.select(Axis(0), &indices);
let tgt = dataset.targets.select(Axis(0), &indices);
Dataset::new(rec, tgt)

Check warning on line 98 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L98

Added line #L98 was not covered by tests
}

impl<F: Float + Send + Sync> Fit<Array2<F>, Array1<usize>, Error> for RandomForestValidParams<F> {
type Object = RandomForestClassifier<F>;

fn fit(&self, dataset: &DatasetBase<Array2<F>, Array1<usize>>) -> Result<Self::Object, Error> {
let mut rng = StdRng::seed_from_u64(self.seed);
let mut trees = Vec::with_capacity(self.n_trees);
let mut feats_list = Vec::with_capacity(self.n_trees);

Check warning on line 107 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L107

Added line #L107 was not covered by tests

let n_features = dataset.records.ncols();

Check warning on line 109 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L109

Added line #L109 was not covered by tests
let n_sub = ((n_features as f32) * self.feature_subsample).ceil() as usize;

for _ in 0..self.n_trees {

Check warning on line 112 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L112

Added line #L112 was not covered by tests
// 1) Bootstrap rows
let sample_set = bootstrap(dataset, &mut rng);

// 2) Random feature subset
let feats = sample(&mut rng, n_features, n_sub)
.into_iter()
.collect::<Vec<_>>();
feats_list.push(feats.clone());

Check warning on line 120 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L120

Added line #L120 was not covered by tests

// 3) Train a tree on that feature slice
let sub_rec = sample_set.records.select(Axis(1), &feats);
let sub_ds = Dataset::new(sub_rec, sample_set.targets.clone());
let tree = DecisionTree::params()
.max_depth(self.max_depth)
.fit(&sub_ds)?;
trees.push(tree);

Check warning on line 128 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L126-L128

Added lines #L126 - L128 were not covered by tests
}

Ok(RandomForestClassifier {
trees,
feature_indices: feats_list,
_phantom: PhantomData,

Check warning on line 134 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L134

Added line #L134 was not covered by tests
})
}
}

impl<F: Float> Predict<Array2<F>, Array1<usize>> for RandomForestClassifier<F> {
fn predict(&self, x: Array2<F>) -> Array1<usize> {
let n = x.nrows();

Check warning on line 141 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L141

Added line #L141 was not covered by tests
// adjust 100 to the expected number of classes if known
let mut votes = vec![vec![0; n]; 100];

Check warning on line 143 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L143

Added line #L143 was not covered by tests

for (tree, feats) in self.trees.iter().zip(&self.feature_indices) {

Check warning on line 145 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L145

Added line #L145 was not covered by tests
// Slice test data to the features this tree saw
let sub_x = x.select(Axis(1), feats);
let preds: Array1<usize> = tree.predict(&sub_x);
for (i, &c) in preds.iter().enumerate() {
votes[c][i] += 1;
}
}

// Majority vote per sample
Array1::from(
(0..n)
.map(|i| {
votes
.iter()
.enumerate()

Check warning on line 160 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L156-L160

Added lines #L156 - L160 were not covered by tests
.max_by_key(|(_, v)| v[i])
.map(|(lbl, _)| lbl)
.unwrap_or(0)

Check warning on line 163 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L162-L163

Added lines #L162 - L163 were not covered by tests
})
.collect::<Vec<_>>(),

Check warning on line 165 in algorithms/linfa-trees/src/decision_trees/random_forest.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/random_forest.rs#L165

Added line #L165 was not covered by tests
)
}
}
10 changes: 8 additions & 2 deletions algorithms/linfa-trees/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//!
//! # Decision tree learning
//! `linfa-trees` aims to provide pure rust implementations
//! of decison trees learning algorithms.
//! `linfa-trees` aims to provide pure Rust implementations
//! of decision tree learning algorithms.
//!
//! # The big picture
//!
Expand All @@ -19,5 +19,11 @@

mod decision_trees;

// Re-export all core decision tree functionality
pub use decision_trees::*;

// Explicitly export the Random Forest classifier API
pub use decision_trees::random_forest::{RandomForestClassifier, RandomForestParams};

// Re-export the common Result alias for convenience
pub use linfa::error::Result;
33 changes: 33 additions & 0 deletions algorithms/linfa-trees/tests/random_forest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// linfa-trees/tests/random_forest.rs

use linfa::prelude::*;
use linfa_datasets::iris;
use linfa_trees::RandomForestParams;
use rand::rngs::StdRng;
use rand::SeedableRng;

#[test]
fn iris_random_forest_high_accuracy() {
// reproducible split
let mut rng = StdRng::seed_from_u64(42);
let (train, valid) = iris().shuffle(&mut rng).split_with_ratio(0.8);

let model = RandomForestParams::new(100)
.max_depth(Some(10))
.feature_subsample(0.8)
.seed(42)
.fit(&train)
.expect("Training failed");

let preds = model.predict(valid.records.clone());
let cm = preds
.confusion_matrix(&valid)
.expect("Failed to compute confusion matrix");

let accuracy = cm.accuracy();
assert!(
accuracy >= 0.9,
"Expected ≥90% accuracy on Iris, got {:.2}",
accuracy
);
}
Loading