Skip to content

Commit 20b1dd2

Browse files
authored
Add new ensemble methods crate: linfa-ensemble (#392)
* Rebase 'ensemble_learner_pr' of github.com:hadeaninc/linfa (thanks @jk1015) * Linting * Review: replace aggregate_predictions * Refactor params and algorithm in separate files * Add documentation * Test accuracy of random forest and add sklearn doc ref * Use bagging terminology * Add ensemble and use lexycographic order in sub-crates list * Typos * Adjust test tolerance -- Co-authored-by: James Knight [email protected] Co-authored-by: James Kay [email protected]
1 parent 335924e commit 20b1dd2

File tree

10 files changed

+390
-27
lines changed

10 files changed

+390
-27
lines changed

README.md

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,23 @@ Where does `linfa` stand right now? [Are we learning yet?](http://www.arewelearn
3030

3131
| Name | Purpose | Status | Category | Notes |
3232
| :--- | :--- | :---| :--- | :---|
33+
| [bayes](algorithms/linfa-bayes/) | Naive Bayes | Tested | Supervised learning | Contains Bernouilli, Gaussian and Multinomial Naive Bayes |
3334
| [clustering](algorithms/linfa-clustering/) | Data clustering | Tested / Benchmarked | Unsupervised learning | Clustering of unlabeled data; contains K-Means, Gaussian-Mixture-Model, DBSCAN and OPTICS |
34-
| [kernel](algorithms/linfa-kernel/) | Kernel methods for data transformation | Tested | Pre-processing | Maps feature vector into higher-dimensional space|
35-
| [linear](algorithms/linfa-linear/) | Linear regression | Tested | Partial fit | Contains Ordinary Least Squares (OLS), Generalized Linear Models (GLM) |
35+
| [ensemble](algorithms/linfa-ensemble/) | Ensemble methods | Tested | Supervised learning | Contains bagging |
3636
| [elasticnet](algorithms/linfa-elasticnet/) | Elastic Net | Tested | Supervised learning | Linear regression with elastic net constraints |
37-
| [logistic](algorithms/linfa-logistic/) | Logistic regression | Tested | Partial fit | Builds two-class logistic regression models
38-
| [reduction](algorithms/linfa-reduction/) | Dimensionality reduction | Tested | Pre-processing | Diffusion mapping, Principal Component Analysis (PCA), Random projections |
39-
| [trees](algorithms/linfa-trees/) | Decision trees | Tested / Benchmarked | Supervised learning | Linear decision trees
40-
| [svm](algorithms/linfa-svm/) | Support Vector Machines | Tested | Supervised learning | Classification or regression analysis of labeled datasets |
37+
| [ftrl](algorithms/linfa-ftrl/) | Follow The Regularized Leader - proximal | Tested / Benchmarked | Partial fit | Contains L1 and L2 regularization. Possible incremental update |
4138
| [hierarchical](algorithms/linfa-hierarchical/) | Agglomerative hierarchical clustering | Tested | Unsupervised learning | Cluster and build hierarchy of clusters |
42-
| [bayes](algorithms/linfa-bayes/) | Naive Bayes | Tested | Supervised learning | Contains Gaussian Naive Bayes |
4339
| [ica](algorithms/linfa-ica/) | Independent component analysis | Tested | Unsupervised learning | Contains FastICA implementation |
44-
| [pls](algorithms/linfa-pls/) | Partial Least Squares | Tested | Supervised learning | Contains PLS estimators for dimensionality reduction and regression |
45-
| [tsne](algorithms/linfa-tsne/) | Dimensionality reduction| Tested | Unsupervised learning | Contains exact solution and Barnes-Hut approximation t-SNE |
46-
| [preprocessing](algorithms/linfa-preprocessing/) |Normalization & Vectorization| Tested / Benchmarked | Pre-processing | Contains data normalization/whitening and count vectorization/tf-idf |
40+
| [kernel](algorithms/linfa-kernel/) | Kernel methods for data transformation | Tested | Pre-processing | Maps feature vector into higher-dimensional space |
41+
| [linear](algorithms/linfa-linear/) | Linear regression | Tested | Partial fit | Contains Ordinary Least Squares (OLS), Generalized Linear Models (GLM) |
42+
| [logistic](algorithms/linfa-logistic/) | Logistic regression | Tested | Partial fit | Builds two-class logistic regression models |
4743
| [nn](algorithms/linfa-nn/) | Nearest Neighbours & Distances | Tested / Benchmarked | Pre-processing | Spatial index structures and distance functions |
48-
| [ftrl](algorithms/linfa-ftrl/) | Follow The Regularized Leader - proximal | Tested / Benchmarked | Partial fit | Contains L1 and L2 regularization. Possible incremental update |
44+
| [pls](algorithms/linfa-pls/) | Partial Least Squares | Tested | Supervised learning | Contains PLS estimators for dimensionality reduction and regression |
45+
| [preprocessing](algorithms/linfa-preprocessing/) | Normalization & Vectorization| Tested / Benchmarked | Pre-processing | Contains data normalization/whitening and count vectorization/tf-idf |
46+
| [reduction](algorithms/linfa-reduction/) | Dimensionality reduction | Tested | Pre-processing | Diffusion mapping, Principal Component Analysis (PCA), Random projections |
47+
| [svm](algorithms/linfa-svm/) | Support Vector Machines | Tested | Supervised learning | Classification or regression analysis of labeled datasets |
48+
| [trees](algorithms/linfa-trees/) | Decision trees | Tested / Benchmarked | Supervised learning | Linear decision trees |
49+
| [tsne](algorithms/linfa-tsne/) | Dimensionality reduction | Tested | Unsupervised learning | Contains exact solution and Barnes-Hut approximation t-SNE |
4950

5051
We believe that only a significant community effort can nurture, build, and sustain a machine learning ecosystem in Rust - there is no other way forward.
5152

algorithms/linfa-ensemble/Cargo.toml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
[package]
2+
name = "linfa-ensemble"
3+
version = "0.7.0"
4+
edition = "2018"
5+
authors = [
6+
"James Knight <[email protected]>",
7+
"James Kay <[email protected]>",
8+
]
9+
description = "A general method for creating ensemble classifiers"
10+
license = "MIT/Apache-2.0"
11+
12+
repository = "https://github.com/rust-ml/linfa"
13+
readme = "README.md"
14+
15+
keywords = ["machine-learning", "linfa", "ensemble"]
16+
categories = ["algorithms", "mathematics", "science"]
17+
18+
[features]
19+
default = []
20+
serde = ["serde_crate", "ndarray/serde"]
21+
22+
[dependencies.serde_crate]
23+
package = "serde"
24+
optional = true
25+
version = "1.0"
26+
default-features = false
27+
features = ["std", "derive"]
28+
29+
[dependencies]
30+
ndarray = { version = "0.15", features = ["rayon", "approx"] }
31+
ndarray-rand = "0.14"
32+
rand = "0.8.5"
33+
34+
linfa = { version = "0.7.1", path = "../.." }
35+
linfa-trees = { version = "0.7.1", path = "../linfa-trees" }
36+
37+
[dev-dependencies]
38+
linfa-datasets = { version = "0.7.1", path = "../../datasets/", features = [
39+
"iris",
40+
] }

algorithms/linfa-ensemble/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Ensemble Learning
2+
3+
`linfa-ensemble` provides pure Rust implementations of Ensemble Learning algorithms for the Linfa toolkit.
4+
5+
## The Big Picture
6+
7+
`linfa-ensemble` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`.
8+
9+
## Current state
10+
11+
`linfa-ensemble` currently provides an implementation of bootstrap aggregation (bagging) for other classifiers provided in linfa.
12+
13+
## Examples
14+
15+
You can find examples in the `examples/` directory. To run an bootstrap aggregation for ensemble of decision trees (a Random Forest) use:
16+
17+
```bash
18+
$ cargo run --example randomforest_iris --release
19+
```
20+
21+
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
2+
use linfa_ensemble::EnsembleLearnerParams;
3+
use linfa_trees::DecisionTree;
4+
use ndarray_rand::rand::SeedableRng;
5+
use rand::rngs::SmallRng;
6+
7+
fn main() {
8+
// Number of models in the ensemble
9+
let ensemble_size = 100;
10+
// Proportion of training data given to each model
11+
let bootstrap_proportion = 0.7;
12+
13+
// Load dataset
14+
let mut rng = SmallRng::seed_from_u64(42);
15+
let (train, test) = linfa_datasets::iris()
16+
.shuffle(&mut rng)
17+
.split_with_ratio(0.8);
18+
19+
// Train ensemble learner model
20+
let model = EnsembleLearnerParams::new(DecisionTree::params())
21+
.ensemble_size(ensemble_size)
22+
.bootstrap_proportion(bootstrap_proportion)
23+
.fit(&train)
24+
.unwrap();
25+
26+
// Return highest ranking predictions
27+
let final_predictions_ensemble = model.predict(&test);
28+
println!("Final Predictions: \n{:?}", final_predictions_ensemble);
29+
30+
let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap();
31+
32+
println!("{:?}", cm);
33+
println!("Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {},\n Bootstrap Proportion: {}",
34+
100.0 * cm.accuracy(), ensemble_size, bootstrap_proportion);
35+
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
use crate::EnsembleLearnerValidParams;
2+
use linfa::{
3+
dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
4+
error::Error,
5+
traits::*,
6+
DatasetBase,
7+
};
8+
use ndarray::{Array2, Axis, Zip};
9+
use rand::Rng;
10+
use std::{cmp::Eq, collections::HashMap, hash::Hash};
11+
12+
pub struct EnsembleLearner<M> {
13+
pub models: Vec<M>,
14+
}
15+
16+
impl<M> EnsembleLearner<M> {
17+
// Generates prediction iterator returning predictions from each model
18+
pub fn generate_predictions<'b, R: Records, T>(
19+
&'b self,
20+
x: &'b R,
21+
) -> impl Iterator<Item = T> + 'b
22+
where
23+
M: Predict<&'b R, T>,
24+
{
25+
self.models.iter().map(move |m| m.predict(x))
26+
}
27+
}
28+
29+
impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
30+
where
31+
M: PredictInplace<Array2<F>, T>,
32+
<T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug,
33+
T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
34+
{
35+
fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
36+
let y_array = y.as_targets();
37+
assert_eq!(
38+
x.nrows(),
39+
y_array.len_of(Axis(0)),
40+
"The number of data points must match the number of outputs."
41+
);
42+
43+
let predictions = self.generate_predictions(x);
44+
45+
// prediction map has same shape as y_array, but the elements are maps
46+
let mut prediction_maps = y_array.map(|_| HashMap::new());
47+
48+
for prediction in predictions {
49+
let p_arr = prediction.as_targets();
50+
assert_eq!(p_arr.shape(), y_array.shape());
51+
// Insert each prediction value into the corresponding map
52+
Zip::from(&mut prediction_maps)
53+
.and(&p_arr)
54+
.for_each(|map, val| *map.entry(*val).or_insert(0) += 1);
55+
}
56+
57+
// For each prediction, pick the result with the highest number of votes
58+
let agg_preds = prediction_maps.map(|map| map.iter().max_by_key(|(_, v)| **v).unwrap().0);
59+
let mut y_array = y.as_targets_mut();
60+
for (y, pred) in y_array.iter_mut().zip(agg_preds.iter()) {
61+
*y = **pred
62+
}
63+
}
64+
65+
fn default_target(&self, x: &Array2<F>) -> T {
66+
self.models[0].default_target(x)
67+
}
68+
}
69+
70+
impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
71+
for EnsembleLearnerValidParams<P, R>
72+
where
73+
D: Clone,
74+
T: FromTargetArrayOwned,
75+
T::Elem: Copy + Eq + Hash,
76+
T::Owned: AsTargets,
77+
{
78+
type Object = EnsembleLearner<P::Object>;
79+
80+
fn fit(
81+
&self,
82+
dataset: &DatasetBase<Array2<D>, T>,
83+
) -> core::result::Result<Self::Object, Error> {
84+
let mut models = Vec::new();
85+
let mut rng = self.rng.clone();
86+
87+
let dataset_size =
88+
((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;
89+
90+
let iter = dataset.bootstrap_samples(dataset_size, &mut rng);
91+
92+
for train in iter {
93+
let model = self.model_params.fit(&train).unwrap();
94+
models.push(model);
95+
96+
if models.len() == self.ensemble_size {
97+
break;
98+
}
99+
}
100+
101+
Ok(EnsembleLearner { models })
102+
}
103+
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use linfa::{
2+
error::{Error, Result},
3+
ParamGuard,
4+
};
5+
use rand::rngs::ThreadRng;
6+
use rand::Rng;
7+
8+
#[derive(Clone, Copy, Debug, PartialEq)]
9+
pub struct EnsembleLearnerValidParams<P, R> {
10+
/// The number of models in the ensemble
11+
pub ensemble_size: usize,
12+
/// The proportion of the total number of training samples that should be given to each model for training
13+
pub bootstrap_proportion: f64,
14+
/// The model parameters for the base model
15+
pub model_params: P,
16+
pub rng: R,
17+
}
18+
19+
#[derive(Clone, Copy, Debug, PartialEq)]
20+
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);
21+
22+
impl<P> EnsembleLearnerParams<P, ThreadRng> {
23+
pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
24+
Self::new_fixed_rng(model_params, rand::thread_rng())
25+
}
26+
}
27+
28+
impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
29+
pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
30+
Self(EnsembleLearnerValidParams {
31+
ensemble_size: 1,
32+
bootstrap_proportion: 1.0,
33+
model_params,
34+
rng,
35+
})
36+
}
37+
38+
pub fn ensemble_size(mut self, size: usize) -> Self {
39+
self.0.ensemble_size = size;
40+
self
41+
}
42+
43+
pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
44+
self.0.bootstrap_proportion = proportion;
45+
self
46+
}
47+
}
48+
49+
impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
50+
type Checked = EnsembleLearnerValidParams<P, R>;
51+
type Error = Error;
52+
53+
fn check_ref(&self) -> Result<&Self::Checked> {
54+
if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 {
55+
Err(Error::Parameters(format!(
56+
"Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
57+
self.0.bootstrap_proportion
58+
)))
59+
} else if self.0.ensemble_size < 1 {
60+
Err(Error::Parameters(format!(
61+
"Ensemble size should be less than one, but was {}",
62+
self.0.ensemble_size
63+
)))
64+
} else {
65+
Ok(&self.0)
66+
}
67+
}
68+
69+
fn check(self) -> Result<Self::Checked> {
70+
self.check_ref()?;
71+
Ok(self.0)
72+
}
73+
}

algorithms/linfa-ensemble/src/lib.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//! # Ensemble Learning Algorithms
2+
//!
3+
//! Ensemble methods combine the predictions of several base estimators built with a given
4+
//! learning algorithm in order to improve generalizability / robustness over a single estimator.
5+
//!
6+
//! ## Bootstrap Aggregation (aka Bagging)
7+
//!
8+
//! A typical example of ensemble method is Bootstrapo AGgregation, which combines the predictions of
9+
//! several decision trees (see `linfa-trees`) trained on different samples subset of the training dataset.
10+
//!
11+
//! ## Reference
12+
//!
13+
//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
14+
//!
15+
//! ## Example
16+
//!
17+
//! This example shows how to train a bagging model using 100 decision trees,
18+
//! each trained on 70% of the training data (bootstrap sampling).
19+
//!
20+
//! ```no_run
21+
//! use linfa::prelude::{Fit, Predict};
22+
//! use linfa_ensemble::EnsembleLearnerParams;
23+
//! use linfa_trees::DecisionTree;
24+
//! use ndarray_rand::rand::SeedableRng;
25+
//! use rand::rngs::SmallRng;
26+
//!
27+
//! // Load Iris dataset
28+
//! let mut rng = SmallRng::seed_from_u64(42);
29+
//! let (train, test) = linfa_datasets::iris()
30+
//! .shuffle(&mut rng)
31+
//! .split_with_ratio(0.8);
32+
//!
33+
//! // Train the model on the iris dataset
34+
//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
35+
//! .ensemble_size(100)
36+
//! .bootstrap_proportion(0.7)
37+
//! .fit(&train)
38+
//! .unwrap();
39+
//!
40+
//! // Make predictions on the test set
41+
//! let predictions = bagging_model.predict(&test);
42+
//! ```
43+
//!
44+
mod algorithm;
45+
mod hyperparams;
46+
47+
pub use algorithm::*;
48+
pub use hyperparams::*;
49+
50+
#[cfg(test)]
51+
mod tests {
52+
use super::*;
53+
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
54+
use linfa_trees::DecisionTree;
55+
use ndarray_rand::rand::SeedableRng;
56+
use rand::rngs::SmallRng;
57+
58+
#[test]
59+
fn test_ensemble_learner_accuracy_on_iris_dataset() {
60+
let mut rng = SmallRng::seed_from_u64(42);
61+
let (train, test) = linfa_datasets::iris()
62+
.shuffle(&mut rng)
63+
.split_with_ratio(0.8);
64+
65+
let model = EnsembleLearnerParams::new(DecisionTree::params())
66+
.ensemble_size(100)
67+
.bootstrap_proportion(0.7)
68+
.fit(&train)
69+
.unwrap();
70+
71+
let predictions = model.predict(&test);
72+
73+
let cm = predictions.confusion_matrix(&test).unwrap();
74+
let acc = cm.accuracy();
75+
assert!(acc >= 0.9, "Expected accuracy to be above 90%, got {}", acc);
76+
}
77+
}

0 commit comments

Comments
 (0)