diff --git a/.github/workflows/checking.yml b/.github/workflows/checking.yml index 04243116b..001be4382 100644 --- a/.github/workflows/checking.yml +++ b/.github/workflows/checking.yml @@ -10,7 +10,7 @@ jobs: fail-fast: false matrix: toolchain: - - 1.81.0 + - 1.82.0 - stable - nightly os: diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2eefebb5f..45b633773 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -10,7 +10,7 @@ jobs: fail-fast: false matrix: toolchain: - - 1.81.0 + - 1.82.0 - stable os: - ubuntu-latest @@ -35,7 +35,7 @@ jobs: fail-fast: false matrix: toolchain: - - 1.81.0 + - 1.82.0 - stable os: - ubuntu-latest diff --git a/algorithms/linfa-bayes/README.md b/algorithms/linfa-bayes/README.md index 2fc57679d..ba11e116a 100644 --- a/algorithms/linfa-bayes/README.md +++ b/algorithms/linfa-bayes/README.md @@ -11,7 +11,8 @@ `linfa-bayes` currently provides an implementation of the following methods: - Gaussian Naive Bayes ([`GaussianNb`]) -- Multinomial Naive Nayes ([`MultinomialNb`])) +- Multinomial Naive Bayes ([`MultinomialNb`]) +- Bernoulli Naive Bayes ([`BernoulliNb`]) ## Examples @@ -95,3 +96,43 @@ println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); # Result::Ok(()) ``` + +To run the Bernoulli Naive Bayes example, use: + +```bash +$ cargo run --example winequality_bernoulli --release +``` + +
+ +Show source code + + +```rust, no_run +use linfa::metrics::ToConfusionMatrix; +use linfa::traits::{Fit, Predict}; +use linfa_bayes::{BernoulliNb, Result}; + +// Read in the dataset and convert targets to binary data +let (train, valid) = linfa_datasets::winequality() + .map_targets(|x| if *x > 6 { "good" } else { "bad" }) + .split_with_ratio(0.9); + +// Train the model +let model = BernoulliNb::params().fit(&train)?; + +// Predict the validation dataset +let pred = model.predict(&valid); + +// Construct confusion matrix +let cm = pred.confusion_matrix(&valid)?; +// classes | bad | good +// bad | 142 | 0 +// good | 17 | 0 + +// accuracy 0.8930818, MCC NaN +println!("{:?}", cm); +println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); +# Result::Ok(()) +``` +
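The README example above relies on `BernoulliNb`'s default hyper-parameters (`alpha = 1`, `binarize = Some(0.0)`), so every strictly positive feature value is mapped to `1` before fitting. For continuous features such as the wine-quality measurements it can be worth choosing the threshold explicitly. The sketch below is illustrative only — the cut-off value `6.0` is an arbitrary choice, not part of this patch — and uses the `alpha` and `binarize` builder methods introduced in this PR:

```rust
use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::{BernoulliNb, Result};

fn main() -> Result<()> {
    // Same setup as the README example: binary "good"/"bad" wine-quality targets
    let (train, valid) = linfa_datasets::winequality()
        .map_targets(|x| if *x > 6 { "good" } else { "bad" })
        .split_with_ratio(0.9);

    // Pick the smoothing parameter and a (hypothetical) binarization threshold
    // explicitly instead of relying on the defaults
    let model = BernoulliNb::params()
        .alpha(1.0)
        .binarize(Some(6.0))
        .fit(&train)?;

    // Evaluate on the held-out split
    let pred = model.predict(&valid);
    let cm = pred.confusion_matrix(&valid)?;
    println!("accuracy {}", cm.accuracy());

    Ok(())
}
```

Passing `binarize(None)` instead skips the thresholding step entirely and assumes the records are already binary, which is how the unit tests in `bernoulli_nb.rs` exercise the model.
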
diff --git a/algorithms/linfa-bayes/examples/winequality_bernouilli.rs b/algorithms/linfa-bayes/examples/winequality_bernouilli.rs new file mode 100644 index 000000000..53790726c --- /dev/null +++ b/algorithms/linfa-bayes/examples/winequality_bernouilli.rs @@ -0,0 +1,28 @@ +use linfa::metrics::ToConfusionMatrix; +use linfa::traits::{Fit, Predict}; +use linfa_bayes::{BernoulliNb, Result}; + +fn main() -> Result<()> { + // Read in the dataset and convert targets to binary data + let (train, valid) = linfa_datasets::winequality() + .map_targets(|x| if *x > 6 { "good" } else { "bad" }) + .split_with_ratio(0.9); + + // Train the model + let model = BernoulliNb::params().fit(&train)?; + + // Predict the validation dataset + let pred = model.predict(&valid); + + // Construct confusion matrix + let cm = pred.confusion_matrix(&valid)?; + // classes | bad | good + // bad | 142 | 0 + // good | 17 | 0 + + // accuracy 0.8930818, MCC + println!("{:?}", cm); + println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); + + Ok(()) +} diff --git a/algorithms/linfa-bayes/src/base_nb.rs b/algorithms/linfa-bayes/src/base_nb.rs index d54fdc44b..f0b87ce2f 100644 --- a/algorithms/linfa-bayes/src/base_nb.rs +++ b/algorithms/linfa-bayes/src/base_nb.rs @@ -1,4 +1,4 @@ -use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2}; +use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix2, Zip}; use ndarray_stats::QuantileExt; use std::collections::HashMap; @@ -8,13 +8,17 @@ use linfa::traits::FitWith; use linfa::{Float, Label}; // Trait computing predictions for fitted Naive Bayes models -pub(crate) trait NaiveBayes<'a, F, L> +pub trait NaiveBayes<'a, F, L> where F: Float, L: Label + Ord, { + /// Compute the unnormalized posterior log probabilities. + /// The result is returned as an HashMap indexing log probabilities for each samples (eg x rows) by classes + /// (eg jll\[class\] -> (n_samples,) array) fn joint_log_likelihood(&self, x: ArrayView2) -> HashMap<&L, Array1>; + #[doc(hidden)] fn predict_inplace>(&self, x: &ArrayBase, y: &mut Array1) { assert_eq!( x.nrows(), @@ -45,6 +49,40 @@ where classes[i].clone() }); } + + /// Compute log-probability estimates for each sample wrt classes. + /// The columns corresponds to classes in sorted order returned as the second output. + fn predict_log_proba(&self, x: ArrayView2) -> (Array2, Vec<&L>) { + let log_likelihood = self.joint_log_likelihood(x); + + let mut classes = log_likelihood.keys().cloned().collect::>(); + classes.sort(); + + let n_samples = x.nrows(); + let n_classes = log_likelihood.len(); + let mut log_prob_mat = Array2::::zeros((n_samples, n_classes)); + + Zip::from(log_prob_mat.columns_mut()) + .and(&classes) + .for_each(|mut jll, &class| jll.assign(log_likelihood.get(class).unwrap())); + + let log_prob_x = log_prob_mat + .mapv(|x| x.exp()) + .sum_axis(Axis(1)) + .mapv(|x| x.ln()) + .into_shape((n_samples, 1)) + .unwrap(); + + (log_prob_mat - log_prob_x, classes) + } + + /// Compute probability estimates for each sample wrt classes. + /// The columns corresponds to classes in sorted order returned as the second output. 
+ fn predict_proba(&self, x: ArrayView2) -> (Array2, Vec<&L>) { + let (log_prob_mat, classes) = self.predict_log_proba(x); + + (log_prob_mat.mapv(|v| v.exp()), classes) + } } // Common functionality for hyper-parameter sets of Naive Bayes models ready for estimation @@ -68,27 +106,3 @@ where self.fit_with(model_none, dataset) } } - -// Returns a subset of x corresponding to the class specified by `ycondition` -pub fn filter( - x: ArrayView2, - y: ArrayView1, - ycondition: &L, -) -> Array2 { - // We identify the row numbers corresponding to the class we are interested in - let index = y - .into_iter() - .enumerate() - .filter(|(_, y)| (*ycondition == **y)) - .map(|(i, _)| i) - .collect::>(); - - // We subset x to only records corresponding to the class represented in `ycondition` - let mut xsubset = Array2::zeros((index.len(), x.ncols())); - index - .into_iter() - .enumerate() - .for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..]))); - - xsubset -} diff --git a/algorithms/linfa-bayes/src/bernoulli_nb.rs b/algorithms/linfa-bayes/src/bernoulli_nb.rs new file mode 100644 index 000000000..423abf0d3 --- /dev/null +++ b/algorithms/linfa-bayes/src/bernoulli_nb.rs @@ -0,0 +1,284 @@ +use linfa::dataset::{AsSingleTargets, DatasetBase, Labels}; +use linfa::traits::{Fit, FitWith, PredictInplace}; +use linfa::{Float, Label}; +use ndarray::{Array1, ArrayBase, ArrayView2, CowArray, Data, Ix2}; +use std::collections::HashMap; +use std::hash::Hash; + +use crate::base_nb::{NaiveBayes, NaiveBayesValidParams}; +use crate::error::{NaiveBayesError, Result}; +use crate::hyperparams::{BernoulliNbParams, BernoulliNbValidParams}; +use crate::{filter, ClassHistogram}; + +impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for BernoulliNbValidParams +where + F: Float, + L: Label + 'a, + D: Data, + T: AsSingleTargets + Labels, +{ +} + +impl Fit, T, NaiveBayesError> for BernoulliNbValidParams +where + F: Float, + L: Label + Ord, + D: Data, + T: AsSingleTargets + Labels, +{ + type Object = BernoulliNb; + + // Thin wrapper around the corresponding method of NaiveBayesValidParams + fn fit(&self, dataset: &DatasetBase, T>) -> Result { + NaiveBayesValidParams::fit(self, dataset, None) + } +} + +impl<'a, F, L, D, T> FitWith<'a, ArrayBase, T, NaiveBayesError> + for BernoulliNbValidParams +where + F: Float, + L: Label + 'a, + D: Data, + T: AsSingleTargets + Labels, +{ + type ObjectIn = Option>; + type ObjectOut = BernoulliNb; + + fn fit_with( + &self, + model_in: Self::ObjectIn, + dataset: &DatasetBase, T>, + ) -> Result { + let x = dataset.records(); + let y = dataset.as_single_targets(); + + let mut model = match model_in { + Some(temp) => temp, + None => BernoulliNb { + class_info: HashMap::new(), + binarize: self.binarize(), + }, + }; + + // Binarize data if the threshold is set + let xbin = model.binarize(x).to_owned(); + + // Calculate feature log probabilities + let yunique = dataset.labels(); + for class in yunique { + // We filter for records that correspond to the current class + let xclass = filter(xbin.view(), y.view(), &class); + + // We compute the feature log probabilities and feature counts on + // the slice corresponding to the current class + model + .class_info + .entry(class) + .or_insert_with(ClassHistogram::default) + .update_with_smoothing(xclass.view(), self.alpha(), true); + } + + // Update the priors + let class_count_sum = model + .class_info + .values() + .map(|x| x.class_count) + .sum::(); + + for info in model.class_info.values_mut() { + info.prior = F::cast(info.class_count) / 
F::cast(class_count_sum); + } + Ok(model) + } +} + +impl PredictInplace, Array1> for BernoulliNb +where + D: Data, +{ + // Thin wrapper around the corresponding method of NaiveBayes + fn predict_inplace(&self, x: &ArrayBase, y: &mut Array1) { + // Binarize data if the threshold is set + let xbin = self.binarize(x); + NaiveBayes::predict_inplace(self, &xbin, y); + } + + fn default_target(&self, x: &ArrayBase) -> Array1 { + Array1::default(x.nrows()) + } +} + +/// Fitted Bernoulli Naive Bayes classifier. +/// +/// See [BernoulliNbParams] for more information on the hyper-parameters. +/// +/// # Model assumptions +/// +/// The family of Naive Bayes classifiers assume independence between variables. They do not model +/// moments between variables and lack therefore in modelling capability. The advantage is a linear +/// fitting time with maximum-likelihood training in a closed form. +/// +/// # Model usage example +/// +/// The example below creates a set of hyperparameters, and then uses it to fit +/// a Bernoulli Naive Bayes classifier on provided data. +/// +/// ```rust +/// use linfa_bayes::{BernoulliNbParams, BernoulliNbValidParams, Result}; +/// use linfa::prelude::*; +/// use ndarray::array; +/// +/// let x = array![ +/// [-2., -1.], +/// [-1., -1.], +/// [-1., -2.], +/// [1., 1.], +/// [1., 2.], +/// [2., 1.] +/// ]; +/// let y = array![1, 1, 1, 2, 2, 2]; +/// let ds = DatasetView::new(x.view(), y.view()); +/// +/// // create a new parameter set with smoothing parameter equals `1` +/// let unchecked_params = BernoulliNbParams::new() +/// .alpha(1.0); +/// +/// // fit model with unchecked parameter set +/// let model = unchecked_params.fit(&ds)?; +/// +/// // transform into a verified parameter set +/// let checked_params = unchecked_params.check()?; +/// +/// // update model with the verified parameters, this only returns +/// // errors originating from the fitting process +/// let model = checked_params.fit_with(Some(model), &ds)?; +/// # Result::Ok(()) +/// ``` +#[derive(Debug, Clone, PartialEq)] +pub struct BernoulliNb { + class_info: HashMap>, + binarize: Option, +} + +impl BernoulliNb { + /// Construct a new set of hyperparameters + pub fn params() -> BernoulliNbParams { + BernoulliNbParams::new() + } + + // Binarize data if the threshold is set + fn binarize<'a, D>(&'a self, x: &'a ArrayBase) -> CowArray<'a, F, Ix2> + where + D: Data, + { + if let Some(thr) = self.binarize { + let xbin = x.map(|v| if v > &thr { F::one() } else { F::zero() }); + CowArray::from(xbin) + } else { + CowArray::from(x) + } + } +} + +impl NaiveBayes<'_, F, L> for BernoulliNb +where + F: Float, + L: Label + Ord, +{ + // Compute unnormalized posterior log probability + fn joint_log_likelihood(&self, x: ArrayView2) -> HashMap<&L, Array1> { + let mut joint_log_likelihood = HashMap::new(); + for (class, info) in self.class_info.iter() { + // Combine feature log probabilities, their negatives, and class priors to + // get log-likelihood for each class + let neg_prob = info.feature_log_prob.map(|lp| (F::one() - lp.exp()).ln()); + let feature_log_prob = &info.feature_log_prob - &neg_prob; + let jll = x.dot(&feature_log_prob); + joint_log_likelihood.insert(class, jll + info.prior.ln() + neg_prob.sum()); + } + + joint_log_likelihood + } +} + +#[cfg(test)] +mod tests { + use super::{BernoulliNb, NaiveBayes, Result}; + use linfa::{ + traits::{Fit, Predict}, + DatasetView, + }; + + use crate::{BernoulliNbParams, BernoulliNbValidParams}; + use approx::assert_abs_diff_eq; + use ndarray::array; + use 
std::collections::HashMap; + + #[test] + fn autotraits() { + fn has_autotraits() {} + has_autotraits::>(); + has_autotraits::>(); + has_autotraits::>(); + } + + #[test] + fn test_bernoulli_nb() -> Result<()> { + let x = array![[1., 0.], [0., 0.], [1., 1.], [0., 1.]]; + let y = array![1, 1, 2, 2]; + let data = DatasetView::new(x.view(), y.view()); + + let params = BernoulliNb::params().binarize(None); + let fitted_clf = params.fit(&data)?; + assert!(&fitted_clf.binarize.is_none()); + + let pred = fitted_clf.predict(&x); + assert_abs_diff_eq!(pred, y); + + let jll = fitted_clf.joint_log_likelihood(x.view()); + let mut expected = HashMap::new(); + expected.insert( + &1usize, + (array![0.1875f64, 0.1875, 0.0625, 0.0625]).map(|v| v.ln()), + ); + + expected.insert( + &2usize, + (array![0.0625f64, 0.0625, 0.1875, 0.1875,]).map(|v| v.ln()), + ); + + for (key, value) in jll.iter() { + assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6); + } + + Ok(()) + } + + #[test] + fn test_text_class() -> Result<()> { + // From https://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html#tab:nbtoy + let train = array![ + // C, B, S, M, T, J + [2., 1., 0., 0., 0., 0.0f64], + [2., 0., 1., 0., 0., 0.], + [1., 0., 0., 1., 0., 0.], + [1., 0., 0., 0., 1., 1.], + ]; + let y = array![1, 1, 1, 2]; + let test = array![[3., 0., 0., 0., 1., 1.0f64]]; + + let data = DatasetView::new(train.view(), y.view()); + let fitted_clf = BernoulliNb::params().fit(&data)?; + let pred = fitted_clf.predict(&test); + + assert_abs_diff_eq!(pred, array![2]); + + // See: https://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html + let jll = fitted_clf.joint_log_likelihood(fitted_clf.binarize(&test).view()); + assert_abs_diff_eq!(jll.get(&1).unwrap()[0].exp(), 0.005, epsilon = 1e-3); + assert_abs_diff_eq!(jll.get(&2).unwrap()[0].exp(), 0.022, epsilon = 1e-3); + + Ok(()) + } +} diff --git a/algorithms/linfa-bayes/src/gaussian_nb.rs b/algorithms/linfa-bayes/src/gaussian_nb.rs index b89b55d3a..3dda882b3 100644 --- a/algorithms/linfa-bayes/src/gaussian_nb.rs +++ b/algorithms/linfa-bayes/src/gaussian_nb.rs @@ -6,8 +6,9 @@ use ndarray_stats::QuantileExt; use std::collections::HashMap; use std::hash::Hash; -use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams}; +use crate::base_nb::{NaiveBayes, NaiveBayesValidParams}; use crate::error::{NaiveBayesError, Result}; +use crate::filter; use crate::hyperparams::{GaussianNbParams, GaussianNbValidParams}; #[cfg(feature = "serde")] @@ -33,8 +34,7 @@ where // Thin wrapper around the corresponding method of NaiveBayesValidParams fn fit(&self, dataset: &DatasetBase, T>) -> Result { - let model = NaiveBayesValidParams::fit(self, dataset, None)?; - Ok(model.unwrap()) + NaiveBayesValidParams::fit(self, dataset, None) } } @@ -47,7 +47,7 @@ where T: AsSingleTargets + Labels, { type ObjectIn = Option>; - type ObjectOut = Option>; + type ObjectOut = GaussianNb; fn fit_with( &self, @@ -115,7 +115,7 @@ where info.prior = F::cast(info.class_count) / F::cast(class_count_sum); } - Ok(Some(model)) + Ok(model) } } @@ -295,7 +295,7 @@ mod tests { use super::{GaussianNb, NaiveBayes, Result}; use linfa::{ traits::{Fit, FitWith, Predict}, - DatasetView, + DatasetView, Error, }; use crate::gaussian_nb::GaussianClassInfo; @@ -334,6 +334,7 @@ mod tests { let jll = fitted_clf.joint_log_likelihood(x.view()); + // expected values from GaussianNB scikit-learn 1.6.1 let mut expected = HashMap::new(); expected.insert( &1usize, @@ -360,6 +361,27 @@ mod tests { 
assert_eq!(jll, expected); + let expected_proba = array![ + [1.00000000e+00, 2.31952358e-16], + [1.00000000e+00, 3.77513536e-11], + [1.00000000e+00, 2.31952358e-16], + [3.77513536e-11, 1.00000000e+00], + [2.31952358e-16, 1.00000000e+00], + [2.31952358e-16, 1.00000000e+00] + ]; + + let (y_pred_proba, classes) = fitted_clf.predict_proba(x.view()); + assert_eq!(classes, vec![&1usize, &2]); + assert_abs_diff_eq!(expected_proba, y_pred_proba, epsilon = 1e-10); + + let (y_pred_log_proba, classes) = fitted_clf.predict_log_proba(x.view()); + assert_eq!(classes, vec![&1usize, &2]); + assert_abs_diff_eq!( + y_pred_proba.mapv(f64::ln), + y_pred_log_proba, + epsilon = 1e-10 + ); + Ok(()) } @@ -381,8 +403,8 @@ mod tests { .axis_chunks_iter(Axis(0), 2) .zip(y.axis_chunks_iter(Axis(0), 2)) .map(|(a, b)| DatasetView::new(a, b)) - .fold(None, |current, d| clf.fit_with(current, &d).unwrap()) - .unwrap(); + .try_fold(None, |current, d| clf.fit_with(current, &d).map(Some))? + .ok_or(Error::NotEnoughSamples)?; let pred = model.predict(&x); diff --git a/algorithms/linfa-bayes/src/hyperparams.rs b/algorithms/linfa-bayes/src/hyperparams.rs index eda86665e..33fe5f550 100644 --- a/algorithms/linfa-bayes/src/hyperparams.rs +++ b/algorithms/linfa-bayes/src/hyperparams.rs @@ -188,3 +188,104 @@ impl ParamGuard for MultinomialNbParams { Ok(self.0) } } + +/// A verified hyper-parameter set ready for the estimation of a [Bernoulli Naive Bayes model](crate::bernoulli_nb::BernoulliNb). +/// +/// See [`BernoulliNb`](crate::bernoulli_nb::BernoulliNb) for information on the model and [`BernoulliNbParams`](crate::hyperparams::BernoulliNbParams) for information on hyperparameters. +#[derive(Debug, Clone, PartialEq)] +pub struct BernoulliNbValidParams { + // Required for calculation stability + alpha: F, + // Threshold for binarization + binarize: Option, + // Phantom data for label type + label: PhantomData, +} + +impl BernoulliNbValidParams { + /// Get the variance smoothing + pub fn alpha(&self) -> F { + self.alpha + } + /// Get the binarization threshold + pub fn binarize(&self) -> Option { + self.binarize + } +} + +/// A hyper-parameter set during construction for a [Bernoulli Naive Bayes model](crate::bernoulli_nb::BernoulliNb). +/// +/// The parameter set can be verified into a +/// [`BernoulliNbValidParams`](crate::hyperparams::BernoulliNbValidParams) by calling +/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with +/// [Fit::fit](linfa::traits::Fit::fit) or +/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitly verifies the parameter set +/// prior to the model estimation and forwards any error. +/// +/// See [`BernoulliNb`](crate::bernoulli_nb::BernoulliNb) for information on the model. +/// +/// # Parameters +/// | Name | Default | Purpose | Range | +/// | :--- | :--- | :---| :--- | +/// | [alpha](Self::alpha) | `1` | Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing) | `[0, inf)` | +/// | [binarize](Self::binarize) | `0.0` | Threshold for binarization (mapping to booleans) of sample features. If `None`, input is presumed to already consist of binary vectors. | `(-inf, inf)` | +/// +/// # Errors +/// +/// The following errors can come from invalid hyper-parameters: +/// +/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing +/// parameter is negative. 
+/// +#[derive(Debug, Clone, PartialEq)] +pub struct BernoulliNbParams(BernoulliNbValidParams); + +impl Default for BernoulliNbParams { + fn default() -> Self { + Self::new() + } +} + +impl BernoulliNbParams { + /// Create new [BernoulliNbParams] set with default values for its parameters + pub fn new() -> Self { + Self(BernoulliNbValidParams { + alpha: F::one(), + binarize: Some(F::zero()), + label: PhantomData, + }) + } + + /// Specifies the portion of the largest variance of all the features that + /// is added to the variance for calculation stability + pub fn alpha(mut self, alpha: F) -> Self { + self.0.alpha = alpha; + self + } + + /// Set the binarization threshold + pub fn binarize(mut self, threshold: Option) -> Self { + self.0.binarize = threshold; + self + } +} + +impl ParamGuard for BernoulliNbParams { + type Checked = BernoulliNbValidParams; + type Error = NaiveBayesError; + + fn check_ref(&self) -> Result<&Self::Checked, Self::Error> { + if self.0.alpha.is_negative() { + Err(NaiveBayesError::InvalidSmoothing( + self.0.alpha.to_f64().unwrap(), + )) + } else { + Ok(&self.0) + } + } + + fn check(self) -> Result { + self.check_ref()?; + Ok(self.0) + } +} diff --git a/algorithms/linfa-bayes/src/lib.rs b/algorithms/linfa-bayes/src/lib.rs index eb74535ab..273a1e955 100644 --- a/algorithms/linfa-bayes/src/lib.rs +++ b/algorithms/linfa-bayes/src/lib.rs @@ -1,13 +1,104 @@ #![doc = include_str!("../README.md")] mod base_nb; +mod bernoulli_nb; mod error; mod gaussian_nb; mod hyperparams; mod multinomial_nb; +pub use base_nb::NaiveBayes; +pub use bernoulli_nb::BernoulliNb; pub use error::{NaiveBayesError, Result}; pub use gaussian_nb::GaussianNb; +pub use hyperparams::{BernoulliNbParams, BernoulliNbValidParams}; pub use hyperparams::{GaussianNbParams, GaussianNbValidParams}; pub use hyperparams::{MultinomialNbParams, MultinomialNbValidParams}; pub use multinomial_nb::MultinomialNb; + +use linfa::{Float, Label}; +use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis}; + +#[cfg(feature = "serde")] +use serde_crate::{Deserialize, Serialize}; + +/// Histogram of class occurrences for multinomial and binomial parameter estimation +#[derive(Debug, Default, Clone, PartialEq)] +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +pub(crate) struct ClassHistogram { + class_count: usize, + prior: F, + feature_count: Array1, + feature_log_prob: Array1, +} + +impl ClassHistogram { + // Update log probabilities of features given class + fn update_with_smoothing(&mut self, x_new: ArrayView2, alpha: F, total_count: bool) { + // If incoming data is empty no updates required + if x_new.nrows() == 0 { + return; + } + + // unpack old class information + let ClassHistogram { + class_count, + feature_count, + feature_log_prob, + .. 
+ } = self; + + // count new feature occurrences + let feature_count_new: Array1 = x_new.sum_axis(Axis(0)); + + // if previous batch was empty, we send the new feature count calculated + if *class_count > 0 { + *feature_count = feature_count_new + feature_count.view(); + } else { + *feature_count = feature_count_new; + } + + // apply smoothing to feature counts + let feature_count_smoothed = feature_count.mapv(|x| x + alpha); + + // compute total count (smoothed) + let count = if total_count { + F::cast(x_new.nrows()) + alpha * F::cast(2) + } else { + feature_count_smoothed.sum() + }; + + // compute log probabilities of each feature + *feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - count.ln()); + // update class count + *class_count += x_new.nrows(); + } +} + +/// Returns a subset of x corresponding to the class specified by `ycondition` +pub(crate) fn filter( + x: ArrayView2, + y: ArrayView1, + ycondition: &L, +) -> Array2 { + // We identify the row numbers corresponding to the class we are interested in + let index = y + .into_iter() + .enumerate() + .filter(|&(_, y)| (*ycondition == *y)) + .map(|(i, _)| i) + .collect::>(); + + // We subset x to only records corresponding to the class represented in `ycondition` + let mut xsubset = Array2::zeros((index.len(), x.ncols())); + index + .into_iter() + .enumerate() + .for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..]))); + + xsubset +} diff --git a/algorithms/linfa-bayes/src/multinomial_nb.rs b/algorithms/linfa-bayes/src/multinomial_nb.rs index 1fc852d26..c06733e0d 100644 --- a/algorithms/linfa-bayes/src/multinomial_nb.rs +++ b/algorithms/linfa-bayes/src/multinomial_nb.rs @@ -1,13 +1,14 @@ use linfa::dataset::{AsSingleTargets, DatasetBase, Labels}; use linfa::traits::{Fit, FitWith, PredictInplace}; use linfa::{Float, Label}; -use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2}; +use ndarray::{Array1, ArrayBase, ArrayView2, Data, Ix2}; use std::collections::HashMap; use std::hash::Hash; -use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams}; +use crate::base_nb::{NaiveBayes, NaiveBayesValidParams}; use crate::error::{NaiveBayesError, Result}; use crate::hyperparams::{MultinomialNbParams, MultinomialNbValidParams}; +use crate::{filter, ClassHistogram}; #[cfg(feature = "serde")] use serde_crate::{Deserialize, Serialize}; @@ -31,8 +32,7 @@ where type Object = MultinomialNb; // Thin wrapper around the corresponding method of NaiveBayesValidParams fn fit(&self, dataset: &DatasetBase, T>) -> Result { - let model = NaiveBayesValidParams::fit(self, dataset, None)?; - Ok(model.unwrap()) + NaiveBayesValidParams::fit(self, dataset, None) } } @@ -45,7 +45,7 @@ where T: AsSingleTargets + Labels, { type ObjectIn = Option>; - type ObjectOut = Option>; + type ObjectOut = MultinomialNb; fn fit_with( &self, @@ -65,34 +65,31 @@ where let yunique = dataset.labels(); for class in yunique { - // We filter for records that correspond to the current class + // filter dataset for current class let xclass = filter(x.view(), y.view(), &class); - // We count the number of occurences of the class - let nclass = xclass.nrows(); - // We compute the feature log probabilities and feature counts on the slice corresponding to the current class - let class_info = model + // compute feature log probabilities and counts + model .class_info - .entry(class) - .or_insert_with(MultinomialClassInfo::default); - let (feature_log_prob, feature_count) = - self.update_feature_log_prob(class_info, xclass.view()); - // We now update the total 
counts of each feature, feature log probabilities, and class count - class_info.feature_log_prob = feature_log_prob; - class_info.feature_count = feature_count; - class_info.class_count += nclass; + .entry(class.clone()) + .or_insert_with(ClassHistogram::default) + .update_with_smoothing(xclass.view(), self.alpha(), false); + + dbg!(&model.class_info.get(&class)); } - // We update the priors + // update priors let class_count_sum = model .class_info .values() .map(|x| x.class_count) .sum::(); + for info in model.class_info.values_mut() { info.prior = F::cast(info.class_count) / F::cast(class_count_sum); } - Ok(Some(model)) + + Ok(model) } } @@ -110,49 +107,6 @@ where } } -impl MultinomialNbValidParams -where - F: Float, -{ - // Update log probabilities of features given class - fn update_feature_log_prob( - &self, - info_old: &MultinomialClassInfo, - x_new: ArrayView2, - ) -> (Array1, Array1) { - // Deconstruct old state - let (count_old, feature_log_prob_old, feature_count_old) = ( - &info_old.class_count, - &info_old.feature_log_prob, - &info_old.feature_count, - ); - - // If incoming data is empty no updates required - if x_new.nrows() == 0 { - return ( - feature_log_prob_old.to_owned(), - feature_count_old.to_owned(), - ); - } - - let feature_count_new = x_new.sum_axis(Axis(0)); - - // If previous batch was empty, we send the new feature count calculated - let feature_count = if count_old > &0 { - feature_count_old + feature_count_new - } else { - feature_count_new - }; - // Apply smoothing to feature counts - let feature_count_smoothed = feature_count.clone() + self.alpha(); - // Compute total count over all (smoothed) features - let count = feature_count_smoothed.sum(); - // Compute log probabilities of each feature - let feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - F::cast(count).ln()); - (feature_log_prob.to_owned(), feature_count.to_owned()) - } -} - /// Fitted Multinomial Naive Bayes classifier. /// /// See [MultinomialNbParams] for more information on the hyper-parameters. 
@@ -206,20 +160,7 @@ where )] #[derive(Debug, Clone, PartialEq)] pub struct MultinomialNb { - class_info: HashMap>, -} - -#[cfg_attr( - feature = "serde", - derive(Serialize, Deserialize), - serde(crate = "serde_crate") -)] -#[derive(Debug, Default, Clone, PartialEq)] -struct MultinomialClassInfo { - class_count: usize, - prior: F, - feature_count: Array1, - feature_log_prob: Array1, + class_info: HashMap>, } impl MultinomialNb { @@ -253,10 +194,9 @@ mod tests { use super::{MultinomialNb, NaiveBayes, Result}; use linfa::{ traits::{Fit, FitWith, Predict}, - DatasetView, + Dataset, DatasetView, Error, }; - use crate::multinomial_nb::MultinomialClassInfo; use crate::{MultinomialNbParams, MultinomialNbValidParams}; use approx::assert_abs_diff_eq; use ndarray::{array, Axis}; @@ -266,23 +206,23 @@ mod tests { fn autotraits() { fn has_autotraits() {} has_autotraits::>(); - has_autotraits::>(); has_autotraits::>(); has_autotraits::>(); } #[test] fn test_multinomial_nb() -> Result<()> { - let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]]; - let y = array![1, 1, 1, 2, 2, 2]; + let ds = Dataset::new( + array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]], + array![1, 1, 1, 2, 2, 2], + ); - let data = DatasetView::new(x.view(), y.view()); - let fitted_clf = MultinomialNb::params().fit(&data)?; - let pred = fitted_clf.predict(&x); + let fitted_clf = MultinomialNb::params().fit(&ds)?; + let pred = fitted_clf.predict(ds.records()); - assert_abs_diff_eq!(pred, y); + assert_abs_diff_eq!(pred, ds.targets()); - let jll = fitted_clf.joint_log_likelihood(x.view()); + let jll = fitted_clf.joint_log_likelihood(ds.records().view()); let mut expected = HashMap::new(); // Computed with sklearn.naive_bayes.MultinomialNB expected.insert( @@ -327,8 +267,8 @@ mod tests { .axis_chunks_iter(Axis(0), 2) .zip(y.axis_chunks_iter(Axis(0), 2)) .map(|(a, b)| DatasetView::new(a, b)) - .fold(None, |current, d| clf.fit_with(current, &d).unwrap()) - .unwrap(); + .try_fold(None, |current, d| clf.fit_with(current, &d).map(Some))? + .ok_or(Error::NotEnoughSamples)?; let pred = model.predict(&x); diff --git a/algorithms/linfa-clustering/src/k_means/algorithm.rs b/algorithms/linfa-clustering/src/k_means/algorithm.rs index 846aea8d9..b537af64f 100644 --- a/algorithms/linfa-clustering/src/k_means/algorithm.rs +++ b/algorithms/linfa-clustering/src/k_means/algorithm.rs @@ -46,7 +46,7 @@ use serde_crate::{Deserialize, Serialize}; /// There are three steps in the standard algorithm: /// - initialisation step: select initial centroids using one of our provided algorithms. /// - assignment step: assign each observation to the nearest cluster -/// (minimum distance between the observation and the cluster's centroid); +/// (minimum distance between the observation and the cluster's centroid); /// - update step: recompute the centroid of each cluster. /// /// The initialisation step is a one-off, done at the very beginning. diff --git a/algorithms/linfa-clustering/src/lib.rs b/algorithms/linfa-clustering/src/lib.rs index a9aa72858..7418bc64a 100644 --- a/algorithms/linfa-clustering/src/lib.rs +++ b/algorithms/linfa-clustering/src/lib.rs @@ -16,7 +16,7 @@ //! * [K-Means](KMeans) //! * [DBSCAN](Dbscan) //! * [Approximated DBSCAN](AppxDbscan) (Currently an alias for DBSCAN, due to its superior -//! performance) +//! performance) //! * [Gaussian-Mixture-Model](GaussianMixtureModel) //! * [OPTICS](OpticsAnalysis) //! 
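Both `MultinomialNb` and `BernoulliNb` now share their per-class bookkeeping through the crate-private `ClassHistogram::update_with_smoothing` helper added to `lib.rs` above. The standalone function below is only a sketch of the formula that helper applies to a single class (the helper itself is not public and additionally accumulates counts across batches; the name and signature here are hypothetical): the Bernoulli branch divides the smoothed feature counts by `n_samples + 2 * alpha`, while the multinomial branch divides by the total smoothed count over all features.

```rust
use ndarray::{array, Array1, Array2, Axis};

// Illustrative re-implementation of the smoothing step performed by
// `ClassHistogram::update_with_smoothing` for one class of samples.
fn smoothed_feature_log_prob(x_class: &Array2<f64>, alpha: f64, bernoulli: bool) -> Array1<f64> {
    // Per-feature occurrence counts within the class
    let feature_count = x_class.sum_axis(Axis(0));
    // Additive (Laplace/Lidstone) smoothing
    let smoothed = feature_count.mapv(|c| c + alpha);
    // The denominator depends on the event model
    let denom = if bernoulli {
        x_class.nrows() as f64 + 2.0 * alpha
    } else {
        smoothed.sum()
    };
    smoothed.mapv(|c| c.ln() - denom.ln())
}

fn main() {
    // Two binary samples of a single class, three features
    let x = array![[1., 0., 1.], [1., 1., 0.]];
    // Bernoulli estimate with alpha = 1: ln((count + 1) / (n_samples + 2))
    // -> [ln(3/4), ln(2/4), ln(2/4)]
    println!("{}", smoothed_feature_log_prob(&x, 1.0, true));
}
```
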
diff --git a/algorithms/linfa-kernel/src/lib.rs b/algorithms/linfa-kernel/src/lib.rs index b2361ab8c..a2bb37997 100644 --- a/algorithms/linfa-kernel/src/lib.rs +++ b/algorithms/linfa-kernel/src/lib.rs @@ -398,7 +398,7 @@ impl /// /// A new dataset with: /// - records: a kernel build from `x.records()` according to the parameters on which - /// this method is called + /// this method is called /// - targets: same as `x.targets()` /// /// ## Panics @@ -425,7 +425,7 @@ impl<'a, F: Float, L: 'a, T: AsTargets + FromTargetArray<'a>, N: Neare /// /// A new dataset with: /// - records: a kernel build from `x.records()` according to the parameters on which - /// this method is called + /// this method is called /// - targets: same as `x.targets()` /// /// ## Panics @@ -460,7 +460,7 @@ impl< /// /// A new dataset with: /// - records: a kernel build from `x.records()` according to the parameters on which - /// this method is called + /// this method is called /// - targets: a slice of `x.targets()` /// /// ## Panics diff --git a/algorithms/linfa-linear/src/ols.rs b/algorithms/linfa-linear/src/ols.rs index 25968719d..3f915c0cc 100644 --- a/algorithms/linfa-linear/src/ols.rs +++ b/algorithms/linfa-linear/src/ols.rs @@ -243,9 +243,9 @@ mod tests { /// We can't fit a line through three points in general /// - in this case we should find the solution that minimizes - /// the squares. Fitting a line with intercept through the - /// points (0, 0), (1, 0), (2, 2) has the least-squares solution - /// f(x) = -1./3. + x + /// the squares. Fitting a line with intercept through the + /// points (0, 0), (1, 0), (2, 2) has the least-squares solution + /// f(x) = -1./3. + x #[test] fn fits_least_squares_line_through_three_dots() { let lin_reg = LinearRegression::new(); diff --git a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs index e051e1722..1d0522b55 100644 --- a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs +++ b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs @@ -51,11 +51,11 @@ impl SerdeRegex { /// they will still be used by the [CountVectorizer](crate::CountVectorizer) to transform any text to be examined. /// /// * `split_regex`: the regex espression used to split decuments into tokens. Defaults to r"\\b\\w\\w+\\b", which selects "words", using whitespaces and -/// punctuation symbols as separators. +/// punctuation symbols as separators. /// * `convert_to_lowercase`: if true, all documents used for fitting will be converted to lowercase. Defaults to `true`. /// * `n_gram_range`: if set to `(1,1)` single tokens will be candidate vocabulary entries, if `(2,2)` then adjacent token pairs will be considered, -/// if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of token depends on the -/// regex used fpr splitting the documents. The default value is `(1,1)`. +/// if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of token depends on the +/// regex used fpr splitting the documents. The default value is `(1,1)`. /// * `normalize`: if true, all charachters in the documents used for fitting will be normalized according to unicode's NFKD normalization. Defaults to `true`. /// * `document_frequency`: specifies the minimum and maximum (relative) document frequencies that each vocabulary entry must satisfy. Defaults to `(0., 1.)` (i.e. 
0% minimum and 100% maximum) /// * `stopwords`: optional list of entries to be excluded from the generated vocabulary. Defaults to `None` diff --git a/algorithms/linfa-trees/benches/decision_tree.rs b/algorithms/linfa-trees/benches/decision_tree.rs index fdaeeb0c9..075149402 100644 --- a/algorithms/linfa-trees/benches/decision_tree.rs +++ b/algorithms/linfa-trees/benches/decision_tree.rs @@ -39,6 +39,7 @@ fn decision_tree_bench(c: &mut Criterion) { Array2::random_using((n_classes, n_features), Uniform::new(-30., 30.), &mut rng); let train_x = generate_blobs(¢roids, *n, &mut rng); + #[allow(clippy::manual_repeat_n)] let train_y: Array1 = (0..n_classes) .flat_map(|x| std::iter::repeat(x).take(*n).collect::>()) .collect::>(); diff --git a/algorithms/linfa-trees/src/decision_trees/algorithm.rs b/algorithms/linfa-trees/src/decision_trees/algorithm.rs index 5c47fdcee..4ce54ebba 100644 --- a/algorithms/linfa-trees/src/decision_trees/algorithm.rs +++ b/algorithms/linfa-trees/src/decision_trees/algorithm.rs @@ -431,7 +431,7 @@ impl TreeNode { /// ### Structure /// A decision tree structure is a binary tree where: /// * Each internal node specifies a decision, represented by a choice of a feature and a "split value" such that all observations for which -/// `feature <= split_value` is true fall in the left subtree, while the others fall in the right subtree. +/// `feature <= split_value` is true fall in the left subtree, while the others fall in the right subtree. /// /// * leaf nodes make predictions, and their prediction is the most popular label in the node /// diff --git a/src/correlation.rs b/src/correlation.rs index 4962e787b..31a01c1a7 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -128,7 +128,7 @@ impl PearsonCorrelation { /// /// * `dataset`: Data for the correlation analysis /// * `num_iter`: optionally number of iterations of the p-value test, if none then no p-value - /// are calculate + /// are calculate /// /// # Example /// diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs index 0d019174b..a43944057 100644 --- a/src/dataset/impl_dataset.rs +++ b/src/dataset/impl_dataset.rs @@ -712,8 +712,8 @@ where /// - `k`: the number of folds to apply to the dataset /// - `params`: the desired parameters for the fittable algorithm at hand /// - `fit_closure`: a closure of the type `(params, training_data) -> fitted_model` - /// that will be used to produce the trained model for each fold. The training data given in input - /// won't outlive the closure. + /// that will be used to produce the trained model for each fold. The training data given in input + /// won't outlive the closure. /// /// ## Returns /// @@ -826,9 +826,9 @@ where /// - `k`: the number of folds to apply /// - `parameters`: a list of models to compare /// - `eval`: closure used to evaluate the performance of each trained model. This closure is - /// called on the model output and validation targets of each fold and outputs the performance - /// score for each target. For single-target dataset the signature is `(Array1, Array1) -> - /// Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`. + /// called on the model output and validation targets of each fold and outputs the performance + /// score for each target. For single-target dataset the signature is `(Array1, Array1) -> + /// Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`. 
/// /// ### Returns /// diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 21d50aeff..2f275a553 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -161,7 +161,7 @@ impl Deref for Pr { /// # Fields /// /// * `records`: a two-dimensional matrix with dimensionality (nsamples, nfeatures), in case of -/// kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse +/// kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse /// * `targets`: a two-/one-dimension matrix with dimensionality (nsamples, ntargets) /// * `weights`: optional weights for each sample with dimensionality (nsamples) /// * `feature_names`: optional descriptive feature names with dimensionality (nfeatures) @@ -171,7 +171,7 @@ impl Deref for Pr { /// /// * `R: Records`: generic over feature matrices or kernel matrices /// * `T`: generic over any `ndarray` matrix which can be used as targets. The `AsTargets` trait -/// bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs` +/// bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs` #[derive(Debug, Clone, PartialEq)] pub struct DatasetBase where