Add bernoulli naive bayes implementation #388

Merged
merged 11 commits into from
May 19, 2025
2 changes: 1 addition & 1 deletion .github/workflows/checking.yml
@@ -10,7 +10,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- 1.81.0
- 1.82.0
- stable
- nightly
os:
4 changes: 2 additions & 2 deletions .github/workflows/testing.yml
@@ -10,7 +10,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- 1.81.0
- 1.82.0
- stable
os:
- ubuntu-latest
@@ -35,7 +35,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- 1.81.0
- 1.82.0
- stable
os:
- ubuntu-latest
43 changes: 42 additions & 1 deletion algorithms/linfa-bayes/README.md
@@ -11,7 +11,8 @@
`linfa-bayes` currently provides an implementation of the following methods:

- Gaussian Naive Bayes ([`GaussianNb`])
- Multinomial Naive Bayes ([`MultinomialNb`]))
- Multinomial Naive Bayes ([`MultinomialNb`])
- Bernoulli Naive Bayes ([`BernoulliNb`])

## Examples

@@ -95,3 +96,43 @@ println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
# Result::Ok(())
```
</details>

To run the Bernoulli Naive Bayes example, use:

```bash
$ cargo run --example winequality_bernoulli --release
```

<details>
<summary style="cursor: pointer; display:list-item;">
Show source code
</summary>

```rust, no_run
use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::{BernoulliNb, Result};

// Read in the dataset and convert targets to binary data
let (train, valid) = linfa_datasets::winequality()
.map_targets(|x| if *x > 6 { "good" } else { "bad" })
.split_with_ratio(0.9);

// Train the model
let model = BernoulliNb::params().fit(&train)?;

// Predict the validation dataset
let pred = model.predict(&valid);

// Construct confusion matrix
let cm = pred.confusion_matrix(&valid)?;
// classes | bad | good
// bad | 142 | 0
// good | 17 | 0

// accuracy 0.8930818, MCC NaN
println!("{:?}", cm);
println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
# Result::Ok(())
```
</details>
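
For readers unfamiliar with the model, the following is a minimal sketch of the quantity a Bernoulli Naive Bayes classifier scores for one class: the log prior plus, for each binary feature, the log probability of that feature being present or absent. It is an illustration under stated assumptions, not the `linfa-bayes` implementation; the function names `binarize` and `bernoulli_joint_log_likelihood`, the threshold argument, and the plain-slice types are all hypothetical.

```rust
/// Hypothetical helper: turn continuous features into binary ones by
/// comparing each value against a threshold (assumption for illustration).
fn binarize(x: &[f64], threshold: f64) -> Vec<bool> {
    x.iter().map(|&v| v > threshold).collect()
}

/// Hypothetical helper: joint log-likelihood of one binary sample under one class,
///   ln P(c) + sum_j [ x_j * ln p_j + (1 - x_j) * ln(1 - p_j) ],
/// where `feature_prob[j]` is the estimated P(x_j = 1 | c).
fn bernoulli_joint_log_likelihood(x: &[bool], log_prior: f64, feature_prob: &[f64]) -> f64 {
    assert_eq!(x.len(), feature_prob.len(), "one probability per feature");
    x.iter()
        .zip(feature_prob)
        .fold(log_prior, |acc, (&present, &p)| {
            // A present feature contributes ln p, an absent one ln(1 - p).
            acc + if present { p.ln() } else { (1.0 - p).ln() }
        })
}
```

Calling `binarize(&features, 0.0)` maps every strictly positive value to `true`, so a dataset whose features are all positive collapses to identical binary rows under that threshold; the predicted class is then the one with the largest joint log-likelihood.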
28 changes: 28 additions & 0 deletions algorithms/linfa-bayes/examples/winequality_bernouilli.rs
@@ -0,0 +1,28 @@
use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::{BernoulliNb, Result};

fn main() -> Result<()> {
// Read in the dataset and convert targets to binary data
let (train, valid) = linfa_datasets::winequality()
.map_targets(|x| if *x > 6 { "good" } else { "bad" })
.split_with_ratio(0.9);

// Train the model
let model = BernoulliNb::params().fit(&train)?;

// Predict the validation dataset
let pred = model.predict(&valid);

// Construct confusion matrix
let cm = pred.confusion_matrix(&valid)?;
// classes | bad | good
// bad | 142 | 0
// good | 17 | 0

// accuracy 0.8930818, MCC NaN
println!("{:?}", cm);
println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());

Ok(())
}
66 changes: 40 additions & 26 deletions algorithms/linfa-bayes/src/base_nb.rs
@@ -1,4 +1,4 @@
use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix2, Zip};
use ndarray_stats::QuantileExt;
use std::collections::HashMap;

@@ -8,13 +8,17 @@
use linfa::{Float, Label};

// Trait computing predictions for fitted Naive Bayes models
pub(crate) trait NaiveBayes<'a, F, L>
pub trait NaiveBayes<'a, F, L>
where
F: Float,
L: Label + Ord,
{
/// Compute the unnormalized posterior log probabilities.
/// The result is returned as a HashMap mapping each class to the log probabilities of every sample (i.e. the rows of `x`),
/// e.g. jll\[class\] -> (n_samples,) array.
fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>>;

#[doc(hidden)]
fn predict_inplace<D: Data<Elem = F>>(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
assert_eq!(
x.nrows(),
@@ -45,6 +49,40 @@
classes[i].clone()
});
}

/// Compute log-probability estimates for each sample with respect to the classes.
/// The columns correspond to the classes in sorted order, which are returned as the second output.
fn predict_log_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
let log_likelihood = self.joint_log_likelihood(x);

let mut classes = log_likelihood.keys().cloned().collect::<Vec<_>>();
classes.sort();

let n_samples = x.nrows();
let n_classes = log_likelihood.len();
let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes));

Zip::from(log_prob_mat.columns_mut())
.and(&classes)
.for_each(|mut jll, &class| jll.assign(log_likelihood.get(class).unwrap()));

let log_prob_x = log_prob_mat
.mapv(|x| x.exp())
.sum_axis(Axis(1))
.mapv(|x| x.ln())
.into_shape((n_samples, 1))
.unwrap();

(log_prob_mat - log_prob_x, classes)
}

/// Compute probability estimates for each sample with respect to the classes.
/// The columns correspond to the classes in sorted order, which are returned as the second output.
fn predict_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
let (log_prob_mat, classes) = self.predict_log_proba(x);

(log_prob_mat.mapv(|v| v.exp()), classes)
}
}

// Common functionality for hyper-parameter sets of Naive Bayes models ready for estimation
@@ -68,27 +106,3 @@
self.fit_with(model_none, dataset)
}
}

// Returns a subset of x corresponding to the class specified by `ycondition`
pub fn filter<F: Float, L: Label + Ord>(
x: ArrayView2<F>,
y: ArrayView1<L>,
ycondition: &L,
) -> Array2<F> {
// We identify the row numbers corresponding to the class we are interested in
let index = y
.into_iter()
.enumerate()
.filter(|(_, y)| (*ycondition == **y))
.map(|(i, _)| i)
.collect::<Vec<_>>();

// We subset x to only records corresponding to the class represented in `ycondition`
let mut xsubset = Array2::zeros((index.len(), x.ncols()));
index
.into_iter()
.enumerate()
.for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..])));

xsubset
}
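
The `predict_log_proba` method added in this diff normalizes each row by exponentiating the joint log-likelihoods, summing, and taking the logarithm. When the log-likelihoods are very negative, the exponentials can underflow to zero and the logarithm becomes negative infinity. A common alternative is the max-shifted log-sum-exp. The sketch below shows that variant on plain `Vec<f64>` rows rather than `ndarray` types; `log_sum_exp` and `normalize_log_proba` are hypothetical names, and this is not what the PR implements.

```rust
/// Hypothetical helper: ln(sum_k exp(row[k])), computed by first subtracting
/// the row maximum so the largest exponent is exactly 0 and cannot underflow.
fn log_sum_exp(row: &[f64]) -> f64 {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if !max.is_finite() {
        // Empty row or all entries -inf (or a +inf entry): the shift is undefined.
        return max;
    }
    max + row.iter().map(|&v| (v - max).exp()).sum::<f64>().ln()
}

/// Hypothetical helper: turn per-sample joint log-likelihoods (one row per
/// sample, one column per class) into normalized log-probabilities.
fn normalize_log_proba(jll: &[Vec<f64>]) -> Vec<Vec<f64>> {
    jll.iter()
        .map(|row| {
            let lse = log_sum_exp(row);
            row.iter().map(|&v| v - lse).collect()
        })
        .collect()
}
```

With a row such as `[-1000.0, -1001.0]`, `(-1000.0f64).exp()` is `0.0` in double precision, so the direct exp/sum/ln route yields a non-finite normalizer, while the shifted version returns finite log-probabilities whose exponentials sum to one.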