diff --git a/.github/workflows/checking.yml b/.github/workflows/checking.yml
index 04243116b..001be4382 100644
--- a/.github/workflows/checking.yml
+++ b/.github/workflows/checking.yml
@@ -10,7 +10,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- - 1.81.0
+ - 1.82.0
- stable
- nightly
os:
diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index 2eefebb5f..45b633773 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -10,7 +10,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- - 1.81.0
+ - 1.82.0
- stable
os:
- ubuntu-latest
@@ -35,7 +35,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- - 1.81.0
+ - 1.82.0
- stable
os:
- ubuntu-latest
diff --git a/algorithms/linfa-bayes/README.md b/algorithms/linfa-bayes/README.md
index 2fc57679d..ba11e116a 100644
--- a/algorithms/linfa-bayes/README.md
+++ b/algorithms/linfa-bayes/README.md
@@ -11,7 +11,8 @@
`linfa-bayes` currently provides an implementation of the following methods:
- Gaussian Naive Bayes ([`GaussianNb`])
-- Multinomial Naive Nayes ([`MultinomialNb`]))
+- Multinomial Naive Bayes ([`MultinomialNb`])
+- Bernoulli Naive Bayes ([`BernoulliNb`])
## Examples
@@ -95,3 +96,43 @@ println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
# Result::Ok(())
```
+
+To run the Bernoulli Naive Bayes example, use:
+
+```bash
+$ cargo run --example winequality_bernoulli --release
+```
+
+<details>
+<summary style="cursor: pointer; display:list-item;">
+Show source code
+</summary>
+
+```rust, no_run
+use linfa::metrics::ToConfusionMatrix;
+use linfa::traits::{Fit, Predict};
+use linfa_bayes::{BernoulliNb, Result};
+
+// Read in the dataset and convert targets to binary data
+let (train, valid) = linfa_datasets::winequality()
+ .map_targets(|x| if *x > 6 { "good" } else { "bad" })
+ .split_with_ratio(0.9);
+
+// Train the model
+let model = BernoulliNb::params().fit(&train)?;
+
+// Predict the validation dataset
+let pred = model.predict(&valid);
+
+// Construct confusion matrix
+let cm = pred.confusion_matrix(&valid)?;
+// classes | bad | good
+// bad | 142 | 0
+// good | 17 | 0
+
+// accuracy 0.8930818, MCC NaN
+println!("{:?}", cm);
+println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
+# Result::Ok(())
+```
+
+</details>
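+
+The fitted model can also produce per-class probability estimates through the
+[`NaiveBayes`] trait. A minimal sketch building on the example above:
+
+```rust, no_run
+# use linfa::traits::Fit;
+# use linfa_bayes::{BernoulliNb, NaiveBayes, Result};
+# let (train, valid) = linfa_datasets::winequality()
+#     .map_targets(|x| if *x > 6 { "good" } else { "bad" })
+#     .split_with_ratio(0.9);
+# let model = BernoulliNb::params().fit(&train)?;
+// One row per validation sample, one column per class (classes in sorted order)
+let (proba, classes) = model.predict_proba(valid.records().view());
+println!("classes: {:?}", classes);
+println!("probabilities of the first sample: {}", proba.row(0));
+# Result::Ok(())
+```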
diff --git a/algorithms/linfa-bayes/examples/winequality_bernoulli.rs b/algorithms/linfa-bayes/examples/winequality_bernoulli.rs
new file mode 100644
index 000000000..53790726c
--- /dev/null
+++ b/algorithms/linfa-bayes/examples/winequality_bernoulli.rs
@@ -0,0 +1,28 @@
+use linfa::metrics::ToConfusionMatrix;
+use linfa::traits::{Fit, Predict};
+use linfa_bayes::{BernoulliNb, Result};
+
+fn main() -> Result<()> {
+ // Read in the dataset and convert targets to binary data
+ let (train, valid) = linfa_datasets::winequality()
+ .map_targets(|x| if *x > 6 { "good" } else { "bad" })
+ .split_with_ratio(0.9);
+
+ // Train the model
+ let model = BernoulliNb::params().fit(&train)?;
+
+ // Predict the validation dataset
+ let pred = model.predict(&valid);
+
+ // Construct confusion matrix
+ let cm = pred.confusion_matrix(&valid)?;
+ // classes | bad | good
+ // bad | 142 | 0
+ // good | 17 | 0
+
+    // accuracy 0.8930818, MCC NaN
+ println!("{:?}", cm);
+ println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
+
+ Ok(())
+}
diff --git a/algorithms/linfa-bayes/src/base_nb.rs b/algorithms/linfa-bayes/src/base_nb.rs
index d54fdc44b..f0b87ce2f 100644
--- a/algorithms/linfa-bayes/src/base_nb.rs
+++ b/algorithms/linfa-bayes/src/base_nb.rs
@@ -1,4 +1,4 @@
-use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
+use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix2, Zip};
use ndarray_stats::QuantileExt;
use std::collections::HashMap;
@@ -8,13 +8,17 @@ use linfa::traits::FitWith;
use linfa::{Float, Label};
// Trait computing predictions for fitted Naive Bayes models
-pub(crate) trait NaiveBayes<'a, F, L>
+pub trait NaiveBayes<'a, F, L>
where
F: Float,
L: Label + Ord,
{
+ /// Compute the unnormalized posterior log probabilities.
+    /// The result is returned as a HashMap indexing the log probabilities for each sample (i.e. the rows of x)
+    /// by class (e.g. jll\[class\] -> (n_samples,) array)
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>>;
+ #[doc(hidden)]
    fn predict_inplace<D: Data<Elem = F>>(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
assert_eq!(
x.nrows(),
@@ -45,6 +49,40 @@ where
classes[i].clone()
});
}
+
+ /// Compute log-probability estimates for each sample wrt classes.
+    /// The columns correspond to the classes in sorted order, which are returned as the second output.
+    fn predict_log_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
+ let log_likelihood = self.joint_log_likelihood(x);
+
+        let mut classes = log_likelihood.keys().cloned().collect::<Vec<_>>();
+ classes.sort();
+
+ let n_samples = x.nrows();
+ let n_classes = log_likelihood.len();
+        let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes));
+
+ Zip::from(log_prob_mat.columns_mut())
+ .and(&classes)
+ .for_each(|mut jll, &class| jll.assign(log_likelihood.get(class).unwrap()));
+
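+        // Log of the evidence for each sample, ln Σ_c exp(jll_c); subtracting it below
+        // turns the joint log-likelihoods into normalized log-posteriors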
+ let log_prob_x = log_prob_mat
+ .mapv(|x| x.exp())
+ .sum_axis(Axis(1))
+ .mapv(|x| x.ln())
+ .into_shape((n_samples, 1))
+ .unwrap();
+
+ (log_prob_mat - log_prob_x, classes)
+ }
+
+ /// Compute probability estimates for each sample wrt classes.
+    /// The columns correspond to the classes in sorted order, which are returned as the second output.
+    fn predict_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
+ let (log_prob_mat, classes) = self.predict_log_proba(x);
+
+ (log_prob_mat.mapv(|v| v.exp()), classes)
+ }
}
// Common functionality for hyper-parameter sets of Naive Bayes models ready for estimation
@@ -68,27 +106,3 @@ where
self.fit_with(model_none, dataset)
}
}
-
-// Returns a subset of x corresponding to the class specified by `ycondition`
-pub fn filter(
- x: ArrayView2,
- y: ArrayView1,
- ycondition: &L,
-) -> Array2 {
- // We identify the row numbers corresponding to the class we are interested in
- let index = y
- .into_iter()
- .enumerate()
- .filter(|(_, y)| (*ycondition == **y))
- .map(|(i, _)| i)
-        .collect::<Vec<_>>();
-
- // We subset x to only records corresponding to the class represented in `ycondition`
- let mut xsubset = Array2::zeros((index.len(), x.ncols()));
- index
- .into_iter()
- .enumerate()
- .for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..])));
-
- xsubset
-}
diff --git a/algorithms/linfa-bayes/src/bernoulli_nb.rs b/algorithms/linfa-bayes/src/bernoulli_nb.rs
new file mode 100644
index 000000000..423abf0d3
--- /dev/null
+++ b/algorithms/linfa-bayes/src/bernoulli_nb.rs
@@ -0,0 +1,284 @@
+use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
+use linfa::traits::{Fit, FitWith, PredictInplace};
+use linfa::{Float, Label};
+use ndarray::{Array1, ArrayBase, ArrayView2, CowArray, Data, Ix2};
+use std::collections::HashMap;
+use std::hash::Hash;
+
+use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
+use crate::error::{NaiveBayesError, Result};
+use crate::hyperparams::{BernoulliNbParams, BernoulliNbValidParams};
+use crate::{filter, ClassHistogram};
+
+impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for BernoulliNbValidParams<F, L>
+where
+    F: Float,
+    L: Label + 'a,
+    D: Data<Elem = F>,
+    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
+{
+}
+
+impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for BernoulliNbValidParams<F, L>
+where
+    F: Float,
+    L: Label + Ord,
+    D: Data<Elem = F>,
+    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
+{
+    type Object = BernoulliNb<F, L>;
+
+    // Thin wrapper around the corresponding method of NaiveBayesValidParams
+    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
+ NaiveBayesValidParams::fit(self, dataset, None)
+ }
+}
+
+impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
+    for BernoulliNbValidParams<F, L>
+where
+    F: Float,
+    L: Label + 'a,
+    D: Data<Elem = F>,
+    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
+{
+    type ObjectIn = Option<BernoulliNb<F, L>>;
+    type ObjectOut = BernoulliNb<F, L>;
+
+    fn fit_with(
+        &self,
+        model_in: Self::ObjectIn,
+        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
+    ) -> Result<Self::ObjectOut> {
+ let x = dataset.records();
+ let y = dataset.as_single_targets();
+
+ let mut model = match model_in {
+ Some(temp) => temp,
+ None => BernoulliNb {
+ class_info: HashMap::new(),
+ binarize: self.binarize(),
+ },
+ };
+
+ // Binarize data if the threshold is set
+ let xbin = model.binarize(x).to_owned();
+
+ // Calculate feature log probabilities
+ let yunique = dataset.labels();
+ for class in yunique {
+ // We filter for records that correspond to the current class
+ let xclass = filter(xbin.view(), y.view(), &class);
+
+ // We compute the feature log probabilities and feature counts on
+ // the slice corresponding to the current class
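+            // `total_count = true` makes the histogram normalise by the number of samples
+            // in the class (plus smoothing), i.e. the Bernoulli estimate of P(feature = 1 | class)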
+ model
+ .class_info
+ .entry(class)
+ .or_insert_with(ClassHistogram::default)
+ .update_with_smoothing(xclass.view(), self.alpha(), true);
+ }
+
+ // Update the priors
+ let class_count_sum = model
+ .class_info
+ .values()
+ .map(|x| x.class_count)
+            .sum::<usize>();
+
+ for info in model.class_info.values_mut() {
+ info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
+ }
+ Ok(model)
+ }
+}
+
+impl<F: Float, L: Label + Ord, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for BernoulliNb<F, L>
+where
+    D: Data<Elem = F>,
+{
+ // Thin wrapper around the corresponding method of NaiveBayes
+    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
+ // Binarize data if the threshold is set
+ let xbin = self.binarize(x);
+ NaiveBayes::predict_inplace(self, &xbin, y);
+ }
+
+    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
+ Array1::default(x.nrows())
+ }
+}
+
+/// Fitted Bernoulli Naive Bayes classifier.
+///
+/// See [BernoulliNbParams] for more information on the hyper-parameters.
+///
+/// # Model assumptions
+///
+/// The family of Naive Bayes classifiers assume independence between variables. They do not model
+/// moments between variables and lack therefore in modelling capability. The advantage is a linear
+/// fitting time with maximum-likelihood training in a closed form.
+///
+/// # Model usage example
+///
+/// The example below creates a set of hyperparameters, and then uses it to fit
+/// a Bernoulli Naive Bayes classifier on provided data.
+///
+/// ```rust
+/// use linfa_bayes::{BernoulliNbParams, BernoulliNbValidParams, Result};
+/// use linfa::prelude::*;
+/// use ndarray::array;
+///
+/// let x = array![
+/// [-2., -1.],
+/// [-1., -1.],
+/// [-1., -2.],
+/// [1., 1.],
+/// [1., 2.],
+/// [2., 1.]
+/// ];
+/// let y = array![1, 1, 1, 2, 2, 2];
+/// let ds = DatasetView::new(x.view(), y.view());
+///
+/// // create a new parameter set with smoothing parameter equals `1`
+/// let unchecked_params = BernoulliNbParams::new()
+/// .alpha(1.0);
+///
+/// // fit model with unchecked parameter set
+/// let model = unchecked_params.fit(&ds)?;
+///
+/// // transform into a verified parameter set
+/// let checked_params = unchecked_params.check()?;
+///
+/// // update model with the verified parameters, this only returns
+/// // errors originating from the fitting process
+/// let model = checked_params.fit_with(Some(model), &ds)?;
+/// # Result::Ok(())
+/// ```
+#[derive(Debug, Clone, PartialEq)]
+pub struct BernoulliNb<F: PartialEq, L: Eq + Hash> {
+    class_info: HashMap<L, ClassHistogram<F>>,
+    binarize: Option<F>,
+}
+
+impl<F: Float, L: Label + Ord> BernoulliNb<F, L> {
+    /// Construct a new set of hyperparameters
+    pub fn params() -> BernoulliNbParams<F, L> {
+ BernoulliNbParams::new()
+ }
+
+ // Binarize data if the threshold is set
+    fn binarize<'a, D>(&'a self, x: &'a ArrayBase<D, Ix2>) -> CowArray<'a, F, Ix2>
+    where
+        D: Data<Elem = F>,
+ {
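+        // Returning a `CowArray` avoids copying the input when no thresholding is needed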
+ if let Some(thr) = self.binarize {
+ let xbin = x.map(|v| if v > &thr { F::one() } else { F::zero() });
+ CowArray::from(xbin)
+ } else {
+ CowArray::from(x)
+ }
+ }
+}
+
+impl<F, L> NaiveBayes<'_, F, L> for BernoulliNb<F, L>
+where
+ F: Float,
+ L: Label + Ord,
+{
+ // Compute unnormalized posterior log probability
+    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
+ let mut joint_log_likelihood = HashMap::new();
+ for (class, info) in self.class_info.iter() {
+ // Combine feature log probabilities, their negatives, and class priors to
+ // get log-likelihood for each class
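+            // For binary features: ln P(x|c) = Σ_i [x_i ln(p_ci) + (1 - x_i) ln(1 - p_ci)],
+            // i.e. x dot (ln(p_c) - ln(1 - p_c)) plus Σ_i ln(1 - p_ci)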
+ let neg_prob = info.feature_log_prob.map(|lp| (F::one() - lp.exp()).ln());
+ let feature_log_prob = &info.feature_log_prob - &neg_prob;
+ let jll = x.dot(&feature_log_prob);
+ joint_log_likelihood.insert(class, jll + info.prior.ln() + neg_prob.sum());
+ }
+
+ joint_log_likelihood
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{BernoulliNb, NaiveBayes, Result};
+ use linfa::{
+ traits::{Fit, Predict},
+ DatasetView,
+ };
+
+ use crate::{BernoulliNbParams, BernoulliNbValidParams};
+ use approx::assert_abs_diff_eq;
+ use ndarray::array;
+ use std::collections::HashMap;
+
+ #[test]
+ fn autotraits() {
+        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
+        has_autotraits::<BernoulliNb<f64, usize>>();
+        has_autotraits::<BernoulliNbValidParams<f64, usize>>();
+        has_autotraits::<BernoulliNbParams<f64, usize>>();
+ }
+
+ #[test]
+ fn test_bernoulli_nb() -> Result<()> {
+ let x = array![[1., 0.], [0., 0.], [1., 1.], [0., 1.]];
+ let y = array![1, 1, 2, 2];
+ let data = DatasetView::new(x.view(), y.view());
+
+ let params = BernoulliNb::params().binarize(None);
+ let fitted_clf = params.fit(&data)?;
+ assert!(&fitted_clf.binarize.is_none());
+
+ let pred = fitted_clf.predict(&x);
+ assert_abs_diff_eq!(pred, y);
+
+ let jll = fitted_clf.joint_log_likelihood(x.view());
+ let mut expected = HashMap::new();
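+        // With alpha = 1 the feature probabilities for class 1 are (1+1)/(2+2) = 0.5 and
+        // (0+1)/(2+2) = 0.25, so e.g. the sample [1, 0] scores 0.5 * 0.5 * (1 - 0.25) = 0.1875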
+ expected.insert(
+ &1usize,
+ (array![0.1875f64, 0.1875, 0.0625, 0.0625]).map(|v| v.ln()),
+ );
+
+ expected.insert(
+ &2usize,
+ (array![0.0625f64, 0.0625, 0.1875, 0.1875,]).map(|v| v.ln()),
+ );
+
+ for (key, value) in jll.iter() {
+ assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_text_class() -> Result<()> {
+ // From https://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html#tab:nbtoy
+ let train = array![
+ // C, B, S, M, T, J
+ [2., 1., 0., 0., 0., 0.0f64],
+ [2., 0., 1., 0., 0., 0.],
+ [1., 0., 0., 1., 0., 0.],
+ [1., 0., 0., 0., 1., 1.],
+ ];
+ let y = array![1, 1, 1, 2];
+ let test = array![[3., 0., 0., 0., 1., 1.0f64]];
+
+ let data = DatasetView::new(train.view(), y.view());
+ let fitted_clf = BernoulliNb::params().fit(&data)?;
+ let pred = fitted_clf.predict(&test);
+
+ assert_abs_diff_eq!(pred, array![2]);
+
+ // See: https://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
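+        // The reference works this example out to P(China|d) ≈ 0.005 and P(not China|d) ≈ 0.022,
+        // where class 1 plays the role of "China"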
+ let jll = fitted_clf.joint_log_likelihood(fitted_clf.binarize(&test).view());
+ assert_abs_diff_eq!(jll.get(&1).unwrap()[0].exp(), 0.005, epsilon = 1e-3);
+ assert_abs_diff_eq!(jll.get(&2).unwrap()[0].exp(), 0.022, epsilon = 1e-3);
+
+ Ok(())
+ }
+}
diff --git a/algorithms/linfa-bayes/src/gaussian_nb.rs b/algorithms/linfa-bayes/src/gaussian_nb.rs
index b89b55d3a..3dda882b3 100644
--- a/algorithms/linfa-bayes/src/gaussian_nb.rs
+++ b/algorithms/linfa-bayes/src/gaussian_nb.rs
@@ -6,8 +6,9 @@ use ndarray_stats::QuantileExt;
use std::collections::HashMap;
use std::hash::Hash;
-use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams};
+use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
+use crate::filter;
use crate::hyperparams::{GaussianNbParams, GaussianNbValidParams};
#[cfg(feature = "serde")]
@@ -33,8 +34,7 @@ where
// Thin wrapper around the corresponding method of NaiveBayesValidParams
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
- let model = NaiveBayesValidParams::fit(self, dataset, None)?;
- Ok(model.unwrap())
+ NaiveBayesValidParams::fit(self, dataset, None)
}
}
@@ -47,7 +47,7 @@ where
T: AsSingleTargets + Labels,
{
    type ObjectIn = Option<GaussianNb<F, L>>;
-    type ObjectOut = Option<GaussianNb<F, L>>;
+    type ObjectOut = GaussianNb<F, L>;
fn fit_with(
&self,
@@ -115,7 +115,7 @@ where
info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
}
- Ok(Some(model))
+ Ok(model)
}
}
@@ -295,7 +295,7 @@ mod tests {
use super::{GaussianNb, NaiveBayes, Result};
use linfa::{
traits::{Fit, FitWith, Predict},
- DatasetView,
+ DatasetView, Error,
};
use crate::gaussian_nb::GaussianClassInfo;
@@ -334,6 +334,7 @@ mod tests {
let jll = fitted_clf.joint_log_likelihood(x.view());
+ // expected values from GaussianNB scikit-learn 1.6.1
let mut expected = HashMap::new();
expected.insert(
&1usize,
@@ -360,6 +361,27 @@ mod tests {
assert_eq!(jll, expected);
+ let expected_proba = array![
+ [1.00000000e+00, 2.31952358e-16],
+ [1.00000000e+00, 3.77513536e-11],
+ [1.00000000e+00, 2.31952358e-16],
+ [3.77513536e-11, 1.00000000e+00],
+ [2.31952358e-16, 1.00000000e+00],
+ [2.31952358e-16, 1.00000000e+00]
+ ];
+
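+        // Posterior probabilities: each row sums to one over the two classes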
+ let (y_pred_proba, classes) = fitted_clf.predict_proba(x.view());
+ assert_eq!(classes, vec![&1usize, &2]);
+ assert_abs_diff_eq!(expected_proba, y_pred_proba, epsilon = 1e-10);
+
+ let (y_pred_log_proba, classes) = fitted_clf.predict_log_proba(x.view());
+ assert_eq!(classes, vec![&1usize, &2]);
+ assert_abs_diff_eq!(
+ y_pred_proba.mapv(f64::ln),
+ y_pred_log_proba,
+ epsilon = 1e-10
+ );
+
Ok(())
}
@@ -381,8 +403,8 @@ mod tests {
.axis_chunks_iter(Axis(0), 2)
.zip(y.axis_chunks_iter(Axis(0), 2))
.map(|(a, b)| DatasetView::new(a, b))
- .fold(None, |current, d| clf.fit_with(current, &d).unwrap())
- .unwrap();
+ .try_fold(None, |current, d| clf.fit_with(current, &d).map(Some))?
+ .ok_or(Error::NotEnoughSamples)?;
let pred = model.predict(&x);
diff --git a/algorithms/linfa-bayes/src/hyperparams.rs b/algorithms/linfa-bayes/src/hyperparams.rs
index eda86665e..33fe5f550 100644
--- a/algorithms/linfa-bayes/src/hyperparams.rs
+++ b/algorithms/linfa-bayes/src/hyperparams.rs
@@ -188,3 +188,104 @@ impl ParamGuard for MultinomialNbParams {
Ok(self.0)
}
}
+
+/// A verified hyper-parameter set ready for the estimation of a [Bernoulli Naive Bayes model](crate::bernoulli_nb::BernoulliNb).
+///
+/// See [`BernoulliNb`](crate::bernoulli_nb::BernoulliNb) for information on the model and [`BernoulliNbParams`](crate::hyperparams::BernoulliNbParams) for information on hyperparameters.
+#[derive(Debug, Clone, PartialEq)]
+pub struct BernoulliNbValidParams<F, L> {
+    // Additive (Laplace/Lidstone) smoothing parameter
+    alpha: F,
+    // Threshold for binarization
+    binarize: Option<F>,
+    // Phantom data for label type
+    label: PhantomData<L>,
+}
+
+impl<F: Float, L> BernoulliNbValidParams<F, L> {
+    /// Get the smoothing parameter
+    pub fn alpha(&self) -> F {
+        self.alpha
+    }
+    /// Get the binarization threshold
+    pub fn binarize(&self) -> Option<F> {
+ self.binarize
+ }
+}
+
+/// A hyper-parameter set during construction for a [Bernoulli Naive Bayes model](crate::bernoulli_nb::BernoulliNb).
+///
+/// The parameter set can be verified into a
+/// [`BernoulliNbValidParams`](crate::hyperparams::BernoulliNbValidParams) by calling
+/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
+/// [Fit::fit](linfa::traits::Fit::fit) or
+/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitly verifies the parameter set
+/// prior to the model estimation and forwards any error.
+///
+/// See [`BernoulliNb`](crate::bernoulli_nb::BernoulliNb) for information on the model.
+///
+/// # Parameters
+/// | Name | Default | Purpose | Range |
+/// | :--- | :--- | :---| :--- |
+/// | [alpha](Self::alpha) | `1` | Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing) | `[0, inf)` |
+/// | [binarize](Self::binarize) | `0.0` | Threshold for binarization (mapping to booleans) of sample features. If `None`, input is presumed to already consist of binary vectors. | `(-inf, inf)` |
+///
+/// # Errors
+///
+/// The following errors can come from invalid hyper-parameters:
+///
+/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
+/// parameter is negative.
+///
+#[derive(Debug, Clone, PartialEq)]
+pub struct BernoulliNbParams<F, L>(BernoulliNbValidParams<F, L>);
+
+impl<F: Float, L: Label> Default for BernoulliNbParams<F, L> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<F: Float, L: Label> BernoulliNbParams<F, L> {
+ /// Create new [BernoulliNbParams] set with default values for its parameters
+ pub fn new() -> Self {
+ Self(BernoulliNbValidParams {
+ alpha: F::one(),
+ binarize: Some(F::zero()),
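+            // the default threshold 0.0 maps strictly positive feature values to 1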
+ label: PhantomData,
+ })
+ }
+
+    /// Set the additive (Laplace/Lidstone) smoothing parameter
+    /// (use 0 for no smoothing)
+    pub fn alpha(mut self, alpha: F) -> Self {
+ self.0.alpha = alpha;
+ self
+ }
+
+ /// Set the binarization threshold
+    pub fn binarize(mut self, threshold: Option<F>) -> Self {
+ self.0.binarize = threshold;
+ self
+ }
+}
+
+impl<F: Float, L: Label> ParamGuard for BernoulliNbParams<F, L> {
+    type Checked = BernoulliNbValidParams<F, L>;
+ type Error = NaiveBayesError;
+
+ fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
+ if self.0.alpha.is_negative() {
+ Err(NaiveBayesError::InvalidSmoothing(
+ self.0.alpha.to_f64().unwrap(),
+ ))
+ } else {
+ Ok(&self.0)
+ }
+ }
+
+    fn check(self) -> Result<Self::Checked, Self::Error> {
+ self.check_ref()?;
+ Ok(self.0)
+ }
+}
diff --git a/algorithms/linfa-bayes/src/lib.rs b/algorithms/linfa-bayes/src/lib.rs
index eb74535ab..273a1e955 100644
--- a/algorithms/linfa-bayes/src/lib.rs
+++ b/algorithms/linfa-bayes/src/lib.rs
@@ -1,13 +1,104 @@
#![doc = include_str!("../README.md")]
mod base_nb;
+mod bernoulli_nb;
mod error;
mod gaussian_nb;
mod hyperparams;
mod multinomial_nb;
+pub use base_nb::NaiveBayes;
+pub use bernoulli_nb::BernoulliNb;
pub use error::{NaiveBayesError, Result};
pub use gaussian_nb::GaussianNb;
+pub use hyperparams::{BernoulliNbParams, BernoulliNbValidParams};
pub use hyperparams::{GaussianNbParams, GaussianNbValidParams};
pub use hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
pub use multinomial_nb::MultinomialNb;
+
+use linfa::{Float, Label};
+use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
+
+#[cfg(feature = "serde")]
+use serde_crate::{Deserialize, Serialize};
+
+/// Histogram of class occurrences for multinomial and binomial parameter estimation
+#[derive(Debug, Default, Clone, PartialEq)]
+#[cfg_attr(
+ feature = "serde",
+ derive(Serialize, Deserialize),
+ serde(crate = "serde_crate")
+)]
+pub(crate) struct ClassHistogram<F> {
+    class_count: usize,
+    prior: F,
+    feature_count: Array1<F>,
+    feature_log_prob: Array1<F>,
+}
+
+impl<F: Float> ClassHistogram<F> {
+    // Update log probabilities of features given class
+    fn update_with_smoothing(&mut self, x_new: ArrayView2<F>, alpha: F, total_count: bool) {
+ // If incoming data is empty no updates required
+ if x_new.nrows() == 0 {
+ return;
+ }
+
+ // unpack old class information
+ let ClassHistogram {
+ class_count,
+ feature_count,
+ feature_log_prob,
+ ..
+ } = self;
+
+ // count new feature occurrences
+        let feature_count_new: Array1<F> = x_new.sum_axis(Axis(0));
+
+ // if previous batch was empty, we send the new feature count calculated
+ if *class_count > 0 {
+ *feature_count = feature_count_new + feature_count.view();
+ } else {
+ *feature_count = feature_count_new;
+ }
+
+ // apply smoothing to feature counts
+ let feature_count_smoothed = feature_count.mapv(|x| x + alpha);
+
+ // compute total count (smoothed)
+ let count = if total_count {
+ F::cast(x_new.nrows()) + alpha * F::cast(2)
+ } else {
+ feature_count_smoothed.sum()
+ };
+
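+        // The resulting estimate is (feature count + alpha) / denominator, where the
+        // denominator is n_samples + 2 * alpha (Bernoulli, `total_count = true`) or the
+        // smoothed total feature count (multinomial)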
+ // compute log probabilities of each feature
+ *feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - count.ln());
+ // update class count
+ *class_count += x_new.nrows();
+ }
+}
+
+/// Returns a subset of x corresponding to the class specified by `ycondition`
+pub(crate) fn filter<F: Float, L: Label + Ord>(
+    x: ArrayView2<F>,
+    y: ArrayView1<L>,
+    ycondition: &L,
+) -> Array2<F> {
+ // We identify the row numbers corresponding to the class we are interested in
+ let index = y
+ .into_iter()
+ .enumerate()
+ .filter(|&(_, y)| (*ycondition == *y))
+ .map(|(i, _)| i)
+        .collect::<Vec<_>>();
+
+ // We subset x to only records corresponding to the class represented in `ycondition`
+ let mut xsubset = Array2::zeros((index.len(), x.ncols()));
+ index
+ .into_iter()
+ .enumerate()
+ .for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..])));
+
+ xsubset
+}
diff --git a/algorithms/linfa-bayes/src/multinomial_nb.rs b/algorithms/linfa-bayes/src/multinomial_nb.rs
index 1fc852d26..c06733e0d 100644
--- a/algorithms/linfa-bayes/src/multinomial_nb.rs
+++ b/algorithms/linfa-bayes/src/multinomial_nb.rs
@@ -1,13 +1,14 @@
use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
use linfa::traits::{Fit, FitWith, PredictInplace};
use linfa::{Float, Label};
-use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
+use ndarray::{Array1, ArrayBase, ArrayView2, Data, Ix2};
use std::collections::HashMap;
use std::hash::Hash;
-use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams};
+use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
use crate::hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
+use crate::{filter, ClassHistogram};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
@@ -31,8 +32,7 @@ where
    type Object = MultinomialNb<F, L>;
// Thin wrapper around the corresponding method of NaiveBayesValidParams
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
- let model = NaiveBayesValidParams::fit(self, dataset, None)?;
- Ok(model.unwrap())
+ NaiveBayesValidParams::fit(self, dataset, None)
}
}
@@ -45,7 +45,7 @@ where
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type ObjectIn = Option<MultinomialNb<F, L>>;
-    type ObjectOut = Option<MultinomialNb<F, L>>;
+    type ObjectOut = MultinomialNb<F, L>;
fn fit_with(
&self,
@@ -65,34 +65,31 @@ where
let yunique = dataset.labels();
for class in yunique {
- // We filter for records that correspond to the current class
+ // filter dataset for current class
let xclass = filter(x.view(), y.view(), &class);
- // We count the number of occurences of the class
- let nclass = xclass.nrows();
- // We compute the feature log probabilities and feature counts on the slice corresponding to the current class
- let class_info = model
+ // compute feature log probabilities and counts
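+            // `total_count = false`: normalise by the total smoothed feature count,
+            // i.e. the multinomial estimate of P(feature | class)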
+ model
.class_info
- .entry(class)
- .or_insert_with(MultinomialClassInfo::default);
- let (feature_log_prob, feature_count) =
- self.update_feature_log_prob(class_info, xclass.view());
- // We now update the total counts of each feature, feature log probabilities, and class count
- class_info.feature_log_prob = feature_log_prob;
- class_info.feature_count = feature_count;
- class_info.class_count += nclass;
+                .entry(class)
+                .or_insert_with(ClassHistogram::default)
+                .update_with_smoothing(xclass.view(), self.alpha(), false);
}
- // We update the priors
+ // update priors
let class_count_sum = model
.class_info
.values()
.map(|x| x.class_count)
.sum::();
+
for info in model.class_info.values_mut() {
info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
}
- Ok(Some(model))
+
+ Ok(model)
}
}
@@ -110,49 +107,6 @@ where
}
}
-impl<F, L> MultinomialNbValidParams<F, L>
-where
- F: Float,
-{
- // Update log probabilities of features given class
- fn update_feature_log_prob(
- &self,
-        info_old: &MultinomialClassInfo<F>,
-        x_new: ArrayView2<F>,
-    ) -> (Array1<F>, Array1<F>) {
- // Deconstruct old state
- let (count_old, feature_log_prob_old, feature_count_old) = (
- &info_old.class_count,
- &info_old.feature_log_prob,
- &info_old.feature_count,
- );
-
- // If incoming data is empty no updates required
- if x_new.nrows() == 0 {
- return (
- feature_log_prob_old.to_owned(),
- feature_count_old.to_owned(),
- );
- }
-
- let feature_count_new = x_new.sum_axis(Axis(0));
-
- // If previous batch was empty, we send the new feature count calculated
- let feature_count = if count_old > &0 {
- feature_count_old + feature_count_new
- } else {
- feature_count_new
- };
- // Apply smoothing to feature counts
- let feature_count_smoothed = feature_count.clone() + self.alpha();
- // Compute total count over all (smoothed) features
- let count = feature_count_smoothed.sum();
- // Compute log probabilities of each feature
- let feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - F::cast(count).ln());
- (feature_log_prob.to_owned(), feature_count.to_owned())
- }
-}
-
/// Fitted Multinomial Naive Bayes classifier.
///
/// See [MultinomialNbParams] for more information on the hyper-parameters.
@@ -206,20 +160,7 @@ where
)]
#[derive(Debug, Clone, PartialEq)]
pub struct MultinomialNb<F: PartialEq, L: Eq + Hash> {
-    class_info: HashMap<L, MultinomialClassInfo<F>>,
-}
-
-#[cfg_attr(
- feature = "serde",
- derive(Serialize, Deserialize),
- serde(crate = "serde_crate")
-)]
-#[derive(Debug, Default, Clone, PartialEq)]
-struct MultinomialClassInfo<F> {
- class_count: usize,
- prior: F,
-    feature_count: Array1<F>,
-    feature_log_prob: Array1<F>,
+    class_info: HashMap<L, ClassHistogram<F>>,
}
impl<F: Float, L: Label + Ord> MultinomialNb<F, L> {
@@ -253,10 +194,9 @@ mod tests {
use super::{MultinomialNb, NaiveBayes, Result};
use linfa::{
traits::{Fit, FitWith, Predict},
- DatasetView,
+ Dataset, DatasetView, Error,
};
- use crate::multinomial_nb::MultinomialClassInfo;
use crate::{MultinomialNbParams, MultinomialNbValidParams};
use approx::assert_abs_diff_eq;
use ndarray::{array, Axis};
@@ -266,23 +206,23 @@ mod tests {
fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<MultinomialNb<f64, usize>>();
-        has_autotraits::<MultinomialClassInfo<f64>>();
        has_autotraits::<MultinomialNbValidParams<f64, usize>>();
        has_autotraits::<MultinomialNbParams<f64, usize>>();
}
#[test]
fn test_multinomial_nb() -> Result<()> {
- let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]];
- let y = array![1, 1, 1, 2, 2, 2];
+ let ds = Dataset::new(
+ array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]],
+ array![1, 1, 1, 2, 2, 2],
+ );
- let data = DatasetView::new(x.view(), y.view());
- let fitted_clf = MultinomialNb::params().fit(&data)?;
- let pred = fitted_clf.predict(&x);
+ let fitted_clf = MultinomialNb::params().fit(&ds)?;
+ let pred = fitted_clf.predict(ds.records());
- assert_abs_diff_eq!(pred, y);
+ assert_abs_diff_eq!(pred, ds.targets());
- let jll = fitted_clf.joint_log_likelihood(x.view());
+ let jll = fitted_clf.joint_log_likelihood(ds.records().view());
let mut expected = HashMap::new();
// Computed with sklearn.naive_bayes.MultinomialNB
expected.insert(
@@ -327,8 +267,8 @@ mod tests {
.axis_chunks_iter(Axis(0), 2)
.zip(y.axis_chunks_iter(Axis(0), 2))
.map(|(a, b)| DatasetView::new(a, b))
- .fold(None, |current, d| clf.fit_with(current, &d).unwrap())
- .unwrap();
+ .try_fold(None, |current, d| clf.fit_with(current, &d).map(Some))?
+ .ok_or(Error::NotEnoughSamples)?;
let pred = model.predict(&x);
diff --git a/algorithms/linfa-clustering/src/k_means/algorithm.rs b/algorithms/linfa-clustering/src/k_means/algorithm.rs
index 846aea8d9..b537af64f 100644
--- a/algorithms/linfa-clustering/src/k_means/algorithm.rs
+++ b/algorithms/linfa-clustering/src/k_means/algorithm.rs
@@ -46,7 +46,7 @@ use serde_crate::{Deserialize, Serialize};
/// There are three steps in the standard algorithm:
/// - initialisation step: select initial centroids using one of our provided algorithms.
/// - assignment step: assign each observation to the nearest cluster
-/// (minimum distance between the observation and the cluster's centroid);
+///   (minimum distance between the observation and the cluster's centroid);
/// - update step: recompute the centroid of each cluster.
///
/// The initialisation step is a one-off, done at the very beginning.
diff --git a/algorithms/linfa-clustering/src/lib.rs b/algorithms/linfa-clustering/src/lib.rs
index a9aa72858..7418bc64a 100644
--- a/algorithms/linfa-clustering/src/lib.rs
+++ b/algorithms/linfa-clustering/src/lib.rs
@@ -16,7 +16,7 @@
//! * [K-Means](KMeans)
//! * [DBSCAN](Dbscan)
//! * [Approximated DBSCAN](AppxDbscan) (Currently an alias for DBSCAN, due to its superior
-//! performance)
+//!   performance)
//! * [Gaussian-Mixture-Model](GaussianMixtureModel)
//! * [OPTICS](OpticsAnalysis)
//!
diff --git a/algorithms/linfa-kernel/src/lib.rs b/algorithms/linfa-kernel/src/lib.rs
index b2361ab8c..a2bb37997 100644
--- a/algorithms/linfa-kernel/src/lib.rs
+++ b/algorithms/linfa-kernel/src/lib.rs
@@ -398,7 +398,7 @@ impl
///
/// A new dataset with:
/// - records: a kernel build from `x.records()` according to the parameters on which
- /// this method is called
+    ///   this method is called
/// - targets: same as `x.targets()`
///
/// ## Panics
@@ -425,7 +425,7 @@ impl<'a, F: Float, L: 'a, T: AsTargets + FromTargetArray<'a>, N: Neare
///
/// A new dataset with:
/// - records: a kernel build from `x.records()` according to the parameters on which
- /// this method is called
+    ///   this method is called
/// - targets: same as `x.targets()`
///
/// ## Panics
@@ -460,7 +460,7 @@ impl<
///
/// A new dataset with:
/// - records: a kernel build from `x.records()` according to the parameters on which
- /// this method is called
+    ///   this method is called
/// - targets: a slice of `x.targets()`
///
/// ## Panics
diff --git a/algorithms/linfa-linear/src/ols.rs b/algorithms/linfa-linear/src/ols.rs
index 25968719d..3f915c0cc 100644
--- a/algorithms/linfa-linear/src/ols.rs
+++ b/algorithms/linfa-linear/src/ols.rs
@@ -243,9 +243,9 @@ mod tests {
/// We can't fit a line through three points in general
/// - in this case we should find the solution that minimizes
- /// the squares. Fitting a line with intercept through the
- /// points (0, 0), (1, 0), (2, 2) has the least-squares solution
- /// f(x) = -1./3. + x
+    ///   the squares. Fitting a line with intercept through the
+    ///   points (0, 0), (1, 0), (2, 2) has the least-squares solution
+    ///   f(x) = -1./3. + x
#[test]
fn fits_least_squares_line_through_three_dots() {
let lin_reg = LinearRegression::new();
diff --git a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs
index e051e1722..1d0522b55 100644
--- a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs
+++ b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs
@@ -51,11 +51,11 @@ impl SerdeRegex {
/// they will still be used by the [CountVectorizer](crate::CountVectorizer) to transform any text to be examined.
///
/// * `split_regex`: the regex espression used to split decuments into tokens. Defaults to r"\\b\\w\\w+\\b", which selects "words", using whitespaces and
-/// punctuation symbols as separators.
+///   punctuation symbols as separators.
/// * `convert_to_lowercase`: if true, all documents used for fitting will be converted to lowercase. Defaults to `true`.
/// * `n_gram_range`: if set to `(1,1)` single tokens will be candidate vocabulary entries, if `(2,2)` then adjacent token pairs will be considered,
-/// if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of token depends on the
-/// regex used fpr splitting the documents. The default value is `(1,1)`.
+///   if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of token depends on the
+///   regex used for splitting the documents. The default value is `(1,1)`.
/// * `normalize`: if true, all charachters in the documents used for fitting will be normalized according to unicode's NFKD normalization. Defaults to `true`.
/// * `document_frequency`: specifies the minimum and maximum (relative) document frequencies that each vocabulary entry must satisfy. Defaults to `(0., 1.)` (i.e. 0% minimum and 100% maximum)
/// * `stopwords`: optional list of entries to be excluded from the generated vocabulary. Defaults to `None`
diff --git a/algorithms/linfa-trees/benches/decision_tree.rs b/algorithms/linfa-trees/benches/decision_tree.rs
index fdaeeb0c9..075149402 100644
--- a/algorithms/linfa-trees/benches/decision_tree.rs
+++ b/algorithms/linfa-trees/benches/decision_tree.rs
@@ -39,6 +39,7 @@ fn decision_tree_bench(c: &mut Criterion) {
Array2::random_using((n_classes, n_features), Uniform::new(-30., 30.), &mut rng);
        let train_x = generate_blobs(&centroids, *n, &mut rng);
+ #[allow(clippy::manual_repeat_n)]
        let train_y: Array1<usize> = (0..n_classes)
            .flat_map(|x| std::iter::repeat(x).take(*n).collect::<Vec<usize>>())
            .collect::<Array1<usize>>();
diff --git a/algorithms/linfa-trees/src/decision_trees/algorithm.rs b/algorithms/linfa-trees/src/decision_trees/algorithm.rs
index 5c47fdcee..4ce54ebba 100644
--- a/algorithms/linfa-trees/src/decision_trees/algorithm.rs
+++ b/algorithms/linfa-trees/src/decision_trees/algorithm.rs
@@ -431,7 +431,7 @@ impl TreeNode {
/// ### Structure
/// A decision tree structure is a binary tree where:
/// * Each internal node specifies a decision, represented by a choice of a feature and a "split value" such that all observations for which
-/// `feature <= split_value` is true fall in the left subtree, while the others fall in the right subtree.
+///   `feature <= split_value` is true fall in the left subtree, while the others fall in the right subtree.
///
/// * leaf nodes make predictions, and their prediction is the most popular label in the node
///
diff --git a/src/correlation.rs b/src/correlation.rs
index 4962e787b..31a01c1a7 100644
--- a/src/correlation.rs
+++ b/src/correlation.rs
@@ -128,7 +128,7 @@ impl PearsonCorrelation {
///
/// * `dataset`: Data for the correlation analysis
/// * `num_iter`: optionally number of iterations of the p-value test, if none then no p-value
- /// are calculate
+    ///   are calculated
///
/// # Example
///
diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs
index 0d019174b..a43944057 100644
--- a/src/dataset/impl_dataset.rs
+++ b/src/dataset/impl_dataset.rs
@@ -712,8 +712,8 @@ where
/// - `k`: the number of folds to apply to the dataset
/// - `params`: the desired parameters for the fittable algorithm at hand
/// - `fit_closure`: a closure of the type `(params, training_data) -> fitted_model`
- /// that will be used to produce the trained model for each fold. The training data given in input
- /// won't outlive the closure.
+    ///   that will be used to produce the trained model for each fold. The training data given in input
+    ///   won't outlive the closure.
///
/// ## Returns
///
@@ -826,9 +826,9 @@ where
/// - `k`: the number of folds to apply
/// - `parameters`: a list of models to compare
/// - `eval`: closure used to evaluate the performance of each trained model. This closure is
- /// called on the model output and validation targets of each fold and outputs the performance
- /// score for each target. For single-target dataset the signature is `(Array1, Array1) ->
- /// Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`.
+    ///   called on the model output and validation targets of each fold and outputs the performance
+    ///   score for each target. For single-target dataset the signature is `(Array1, Array1) ->
+    ///   Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`.
///
/// ### Returns
///
diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs
index 21d50aeff..2f275a553 100644
--- a/src/dataset/mod.rs
+++ b/src/dataset/mod.rs
@@ -161,7 +161,7 @@ impl Deref for Pr {
/// # Fields
///
/// * `records`: a two-dimensional matrix with dimensionality (nsamples, nfeatures), in case of
-/// kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse
+///   kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse
/// * `targets`: a two-/one-dimension matrix with dimensionality (nsamples, ntargets)
/// * `weights`: optional weights for each sample with dimensionality (nsamples)
/// * `feature_names`: optional descriptive feature names with dimensionality (nfeatures)
@@ -171,7 +171,7 @@ impl Deref for Pr {
///
/// * `R: Records`: generic over feature matrices or kernel matrices
/// * `T`: generic over any `ndarray` matrix which can be used as targets. The `AsTargets` trait
-/// bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs`
+///   bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs`
#[derive(Debug, Clone, PartialEq)]
pub struct DatasetBase<R, T>
where