diff --git a/algorithms/linfa-reduction/Cargo.toml b/algorithms/linfa-reduction/Cargo.toml
index c1de0c2f9..6143a8a60 100644
--- a/algorithms/linfa-reduction/Cargo.toml
+++ b/algorithms/linfa-reduction/Cargo.toml
@@ -55,4 +55,4 @@ linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [
 ] }
 approx = { version = "0.4" }
 mnist = { version = "0.6.0", features = ["download"] }
-linfa-trees = { version = "0.7.1", path = "../linfa-trees" }
+linfa-trees = { version = "0.8.0", path = "../linfa-trees" }
diff --git a/algorithms/linfa-trees/Cargo.toml b/algorithms/linfa-trees/Cargo.toml
index a270ec6af..91f1efd1f 100644
--- a/algorithms/linfa-trees/Cargo.toml
+++ b/algorithms/linfa-trees/Cargo.toml
@@ -1,8 +1,11 @@
 [package]
 name = "linfa-trees"
-version = "0.7.1"
+version = "0.8.0"
 edition = "2018"
-authors = ["Moss Ebeling "]
+authors = [
+    "Moss Ebeling ",
+    "Abhinav Shukla ",
+]
 description = "A collection of tree-based algorithms"
 license = "MIT OR Apache-2.0"
@@ -28,6 +31,8 @@ ndarray = { version = "0.15" , features = ["rayon", "approx"]}
 ndarray-rand = "0.14"
 linfa = { version = "0.7.1", path = "../.." }
 
+rand = "0.8"
+
 [dev-dependencies]
 rand = { version = "0.8", features = ["small_rng"] }
diff --git a/algorithms/linfa-trees/README.md b/algorithms/linfa-trees/README.md
index a25f2ed5e..fc185a49d 100644
--- a/algorithms/linfa-trees/README.md
+++ b/algorithms/linfa-trees/README.md
@@ -10,23 +10,66 @@ Decision Trees (DTs) are a non-parametric supervised learning method used for cl
 
 ## Current state
 
-`linfa-trees` currently provides an implementation of single tree fitting
+`linfa-trees` currently provides an implementation of single tree fitting.
+
+## Random Forest Classifier
+
+An ensemble of decision trees, each trained on a **bootstrapped** subset of the data using a **random subset of the features**. Predictions are made by majority voting across all trees, which typically improves generalization and robustness over a single tree.
+
+**Key features:**
+- Configurable number of trees (`n_trees`)
+- Optional maximum depth (`max_depth`)
+- Fraction of features sampled per tree (`feature_subsample`)
+- Reproducible results via an RNG seed (`seed`)
+- Implements the `Fit` and `Predict` traits for seamless integration
+
+### Random Forest Example
+
+```rust
+use linfa::prelude::*;
+use linfa_datasets::iris;
+use linfa_trees::RandomForestParams;
+use rand::thread_rng;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Load, shuffle, and split the Iris dataset
+    let (train, valid) = iris()
+        .shuffle(&mut thread_rng())
+        .split_with_ratio(0.8);
+
+    // Train a random forest with 100 trees, depth 10, and 80% feature sampling
+    let model = RandomForestParams::new(100)
+        .max_depth(Some(10))
+        .feature_subsample(0.8)
+        .seed(42)
+        .fit(&train)?;
+
+    // Predict and evaluate
+    let preds = model.predict(valid.records.clone());
+    let cm = preds.confusion_matrix(&valid)?;
+    println!("Accuracy: {:.2}", cm.accuracy());
+    Ok(())
+}
+```
+
+Run this example with:
+
+```bash
+cargo run --release --example iris_random_forest
+```
 
 ## Examples
 
 There is an example in the `examples/` directory showing how to use decision trees. To run, use:
 
 ```bash
-$ cargo run --release --example decision_tree
+cargo run --release --example decision_tree
 ```
 
 This generates the following tree:
 
 <p align="center">
-    <img src="./iris-decisiontree.svg">
+    <img src="./iris-decisiontree.svg" alt="Iris decision tree">
 </p>
 
 ## License
 
 Dual-licensed to be compatible with the Rust project.
 
 Licensed under the Apache License, Version 2.0 <http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <http://opensource.org/licenses/MIT>, at your option. This file may not be copied, modified, or distributed except according to those terms.
diff --git a/algorithms/linfa-trees/examples/iris_random_forest.rs b/algorithms/linfa-trees/examples/iris_random_forest.rs
new file mode 100644
index 000000000..c645de30f
--- /dev/null
+++ b/algorithms/linfa-trees/examples/iris_random_forest.rs
@@ -0,0 +1,35 @@
+// File: examples/iris_random_forest.rs
+
+use linfa::prelude::*;
+use linfa_datasets::iris;
+use linfa_trees::{DecisionTree, RandomForestParams};
+use rand::thread_rng;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Create an RNG for shuffling (thread_rng is not seedable, so the split
+    // itself varies between runs; only the forest's own RNG seed is fixed)
+    let mut rng = thread_rng();
+
+    // 1. Load, shuffle, and split the Iris dataset (80% train, 20% valid)
+    let (train, valid) = iris().shuffle(&mut rng).split_with_ratio(0.8);
+
+    // 2. Single-tree baseline
+    let dt_model = DecisionTree::params()
+        .max_depth(None) // no depth limit
+        .fit(&train)?;
+    let dt_preds = dt_model.predict(valid.records.clone());
+    let dt_cm = dt_preds.confusion_matrix(&valid)?;
+    println!("Single-tree accuracy: {:.2}", dt_cm.accuracy());
+
+    // 3. Random forest
+    let rf_model = RandomForestParams::new(50)
+        .max_depth(Some(5))
+        .feature_subsample(0.7)
+        .seed(42) // fix the forest's RNG seed for reproducibility
+        .fit(&train)?;
+    let rf_preds = rf_model.predict(valid.records.clone());
+    let rf_cm = rf_preds.confusion_matrix(&valid)?;
+    println!("Random-forest accuracy: {:.2}", rf_cm.accuracy());
+
+    Ok(())
+}
diff --git a/algorithms/linfa-trees/src/decision_trees/mod.rs b/algorithms/linfa-trees/src/decision_trees/mod.rs
index 1192ff146..93b7b221c 100644
--- a/algorithms/linfa-trees/src/decision_trees/mod.rs
+++ b/algorithms/linfa-trees/src/decision_trees/mod.rs
@@ -7,3 +7,4 @@ pub use algorithm::*;
 pub use hyperparams::*;
 pub use iter::*;
 pub use tikz::*;
+pub mod random_forest;
diff --git a/algorithms/linfa-trees/src/decision_trees/random_forest.rs b/algorithms/linfa-trees/src/decision_trees/random_forest.rs
new file mode 100644
index 000000000..1c95d203f
--- /dev/null
+++ b/algorithms/linfa-trees/src/decision_trees/random_forest.rs
@@ -0,0 +1,168 @@
+//! Random Forest Classifier
+//!
+//! An ensemble of decision trees trained on bootstrapped, feature-subsampled slices of the data.
+
+use linfa::prelude::*;
+use linfa::{error::Error, Float, ParamGuard};
+use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
+use rand::{rngs::StdRng, seq::index::sample, Rng, SeedableRng};
+use std::marker::PhantomData;
+
+use super::algorithm::DecisionTree;
+
+/// A fitted random forest: one decision tree per bootstrap sample, each
+/// trained on its own random subset of the features.
+#[derive(Debug, Clone)]
+pub struct RandomForestClassifier<F: Float> {
+    trees: Vec<DecisionTree<F, usize>>,
+    feature_indices: Vec<Vec<usize>>,
+    _phantom: PhantomData<F>,
+}
+
+/// Unchecked hyperparameters (builder) for a random forest.
+#[derive(Debug, Clone)]
+pub struct RandomForestParams<F: Float> {
+    inner: RandomForestValidParams<F>,
+}
+
+/// Validated hyperparameters for a random forest.
+#[derive(Debug, Clone)]
+pub struct RandomForestValidParams<F: Float> {
+    pub n_trees: usize,
+    pub max_depth: Option<usize>,
+    pub feature_subsample: f32,
+    pub seed: u64,
+    _phantom: PhantomData<F>,
+}
+
+impl<F: Float> RandomForestParams<F> {
+    /// Create a new random forest with the specified number of trees.
+    pub fn new(n_trees: usize) -> Self {
+        Self {
+            inner: RandomForestValidParams {
+                n_trees,
+                max_depth: None,
+                feature_subsample: 1.0,
+                seed: 42,
+                _phantom: PhantomData,
+            },
+        }
+    }
+
+    /// Set the maximum depth of each tree.
+    pub fn max_depth(mut self, depth: Option<usize>) -> Self {
+        self.inner.max_depth = depth;
+        self
+    }
+
+    /// Fraction of features to sample per tree, in (0, 1].
+    pub fn feature_subsample(mut self, ratio: f32) -> Self {
+        self.inner.feature_subsample = ratio;
+        self
+    }
+
+    /// RNG seed for reproducibility.
+    pub fn seed(mut self, seed: u64) -> Self {
+        self.inner.seed = seed;
+        self
+    }
+}
+
+impl<F: Float> ParamGuard for RandomForestParams<F> {
+    type Checked = RandomForestValidParams<F>;
+    type Error = Error;
+
+    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
+        if self.inner.n_trees == 0 {
+            return Err(Error::Parameters("n_trees must be > 0".into()));
+        }
+        // A subsample ratio of 0 would leave every tree with no features,
+        // so 0.0 is rejected along with values above 1.0
+        if self.inner.feature_subsample <= 0.0 || self.inner.feature_subsample > 1.0 {
+            return Err(Error::Parameters(
+                "feature_subsample must be in (0, 1]".into(),
+            ));
+        }
+        Ok(&self.inner)
+    }
+
+    fn check(self) -> Result<Self::Checked, Self::Error> {
+        self.check_ref()?;
+        Ok(self.inner)
+    }
+}
+
+/// Bootstrap-sample the dataset with replacement.
+fn bootstrap<F: Float>(
+    dataset: &DatasetBase<Array2<F>, Array1<usize>>,
+    rng: &mut impl Rng,
+) -> DatasetBase<Array2<F>, Array1<usize>> {
+    let n = dataset.nsamples();
+    // Draw n row indices with replacement
+    let indices: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
+    let rec = dataset.records.select(Axis(0), &indices);
+    let tgt = dataset.targets.select(Axis(0), &indices);
+    Dataset::new(rec, tgt)
+}
+
+impl<F: Float> Fit<Array2<F>, Array1<usize>, Error> for RandomForestValidParams<F> {
+    type Object = RandomForestClassifier<F>;
+
+    fn fit(
+        &self,
+        dataset: &DatasetBase<Array2<F>, Array1<usize>>,
+    ) -> Result<Self::Object, Error> {
+        let mut rng = StdRng::seed_from_u64(self.seed);
+        let mut trees = Vec::with_capacity(self.n_trees);
+        let mut feats_list = Vec::with_capacity(self.n_trees);
+
+        let n_features = dataset.records.ncols();
+        let n_sub = ((n_features as f32) * self.feature_subsample).ceil() as usize;
+
+        for _ in 0..self.n_trees {
+            // 1) Bootstrap rows
+            let sample_set = bootstrap(dataset, &mut rng);
+
+            // 2) Random feature subset
+            let feats = sample(&mut rng, n_features, n_sub)
+                .into_iter()
+                .collect::<Vec<_>>();
+            feats_list.push(feats.clone());
+
+            // 3) Train a tree on that feature slice
+            let sub_rec = sample_set.records.select(Axis(1), &feats);
+            let sub_ds = Dataset::new(sub_rec, sample_set.targets.clone());
+            let tree = DecisionTree::params()
+                .max_depth(self.max_depth)
+                .fit(&sub_ds)?;
+            trees.push(tree);
+        }
+
+        Ok(RandomForestClassifier {
+            trees,
+            feature_indices: feats_list,
+            _phantom: PhantomData,
+        })
+    }
+}
+
+// Implementing PredictInplace (rather than Predict directly) hooks into
+// linfa's blanket impls, which provide `predict` for owned and borrowed
+// records without conflicting with them.
+impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<usize>>
+    for RandomForestClassifier<F>
+{
+    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<usize>) {
+        let n = x.nrows();
+        assert_eq!(n, y.len(), "The number of samples must match the number of targets");
+
+        // Collect each tree's predictions on its own feature slice
+        let all_preds: Vec<Array1<usize>> = self
+            .trees
+            .iter()
+            .zip(&self.feature_indices)
+            .map(|(tree, feats)| tree.predict(&x.select(Axis(1), feats)))
+            .collect();
+
+        // Size the vote table from the largest class label actually predicted,
+        // rather than assuming a fixed number of classes
+        let n_classes = all_preds
+            .iter()
+            .flat_map(|p| p.iter())
+            .max()
+            .map(|&c| c + 1)
+            .unwrap_or(1);
+        let mut votes = vec![vec![0usize; n]; n_classes];
+        for preds in &all_preds {
+            for (i, &c) in preds.iter().enumerate() {
+                votes[c][i] += 1;
+            }
+        }
+
+        // Majority vote per sample
+        for (i, y_i) in y.iter_mut().enumerate() {
+            *y_i = votes
+                .iter()
+                .enumerate()
+                .max_by_key(|(_, v)| v[i])
+                .map(|(lbl, _)| lbl)
+                .unwrap_or(0);
+        }
+    }
+
+    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
+        Array1::zeros(x.nrows())
+    }
+}
diff --git a/algorithms/linfa-trees/src/lib.rs b/algorithms/linfa-trees/src/lib.rs
index 3440e9004..474406de2 100644
--- a/algorithms/linfa-trees/src/lib.rs
+++ b/algorithms/linfa-trees/src/lib.rs
@@ -1,7 +1,7 @@
 //!
 //! # Decision tree learning
-//! `linfa-trees` aims to provide pure rust implementations
-//! of decison trees learning algorithms.
+//! `linfa-trees` aims to provide pure Rust implementations
+//! of decision tree learning algorithms.
 //!
 //! # The big picture
 //!
@@ -19,5 +19,11 @@ mod decision_trees;
 
+// Re-export all core decision tree functionality
 pub use decision_trees::*;
+
+// Explicitly export the Random Forest classifier API
+pub use decision_trees::random_forest::{RandomForestClassifier, RandomForestParams};
+
+// Re-export the common Result alias for convenience
 pub use linfa::error::Result;
diff --git a/algorithms/linfa-trees/tests/random_forest.rs b/algorithms/linfa-trees/tests/random_forest.rs
new file mode 100644
index 000000000..921893cfc
--- /dev/null
+++ b/algorithms/linfa-trees/tests/random_forest.rs
@@ -0,0 +1,33 @@
+// linfa-trees/tests/random_forest.rs
+
+use linfa::prelude::*;
+use linfa_datasets::iris;
+use linfa_trees::RandomForestParams;
+use rand::rngs::StdRng;
+use rand::SeedableRng;
+
+#[test]
+fn iris_random_forest_high_accuracy() {
+    // reproducible split
+    let mut rng = StdRng::seed_from_u64(42);
+    let (train, valid) = iris().shuffle(&mut rng).split_with_ratio(0.8);
+
+    let model = RandomForestParams::new(100)
+        .max_depth(Some(10))
+        .feature_subsample(0.8)
+        .seed(42)
+        .fit(&train)
+        .expect("Training failed");
+
+    let preds = model.predict(valid.records.clone());
+    let cm = preds
+        .confusion_matrix(&valid)
+        .expect("Failed to compute confusion matrix");
+
+    let accuracy = cm.accuracy();
+    assert!(
+        accuracy >= 0.9,
+        "Expected ≥90% accuracy on Iris, got {:.2}",
+        accuracy
+    );
+}
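
Beyond the Iris example and integration test above, the new API can be smoke-tested without the dataset downloader. The following is a minimal sketch, assuming the `RandomForestParams` builder and prediction surface exactly as introduced in this diff; the two-cluster toy data is made up for illustration:

```rust
use linfa::prelude::*;
use linfa::Dataset;
use linfa_trees::RandomForestParams;
use ndarray::{array, Array1};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Two well-separated classes in a two-feature space (toy data)
    let records = array![
        [0.0, 0.1], [0.1, 0.0], [0.2, 0.2], [0.1, 0.2],
        [5.0, 5.1], [5.1, 5.0], [5.2, 5.2], [5.1, 5.2],
    ];
    let targets = Array1::from(vec![0usize, 0, 0, 0, 1, 1, 1, 1]);
    let ds = Dataset::new(records.clone(), targets);

    // A small forest suffices here: 10 trees, all features per tree
    let model = RandomForestParams::new(10)
        .feature_subsample(1.0)
        .seed(7)
        .fit(&ds)?;

    // Predicting the training points should recover the two clusters
    let preds = model.predict(records);
    println!("{:?}", preds);
    Ok(())
}
```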
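The `ParamGuard` impl is what stands between the builder and a fitted model: `fit` goes through `check`, so invalid settings surface as `Error::Parameters` before any training happens. A small sketch of the rejection paths, again assuming only the API from this diff:

```rust
use linfa::ParamGuard;
use linfa_trees::RandomForestParams;

fn main() {
    // A forest with zero trees is rejected
    assert!(RandomForestParams::<f64>::new(0).check().is_err());

    // feature_subsample outside (0, 1] is rejected as well
    assert!(RandomForestParams::<f64>::new(10)
        .feature_subsample(1.5)
        .check()
        .is_err());

    // A sane configuration passes validation
    assert!(RandomForestParams::<f64>::new(10)
        .feature_subsample(0.5)
        .check()
        .is_ok());
}
```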