Skip to content

Commit a599e5b

Browse files
relf, wildart, and Lorenz Schmidt
authored and committed
Add bernoulli naive bayes implementation (rust-ml#388)
* Add bernoulli Naive Bayes algorithm * Remove optional from `FitWith` trait implementation * Move class histogram to root path; Improve example * Make NaiveBayes public * Add predict_log_proba and predict_proba * Review * Linting * Revert lint to preserve MSRV * Add serde for ClassHistogram * Bump MSRV to 1.82.0 --------- Co-authored-by: Art Wild <[email protected]> Co-authored-by: Lorenz Schmidt <[email protected]>
1 parent 776e311 commit a599e5b

File tree

20 files changed

+669
-147
lines changed

20 files changed

+669
-147
lines changed

.github/workflows/checking.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
toolchain:
13-
- 1.81.0
13+
- 1.82.0
1414
- stable
1515
- nightly
1616
os:

.github/workflows/testing.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
toolchain:
13-
- 1.81.0
13+
- 1.82.0
1414
- stable
1515
os:
1616
- ubuntu-latest
@@ -35,7 +35,7 @@ jobs:
3535
fail-fast: false
3636
matrix:
3737
toolchain:
38-
- 1.81.0
38+
- 1.82.0
3939
- stable
4040
os:
4141
- ubuntu-latest

algorithms/linfa-bayes/README.md

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
`linfa-bayes` currently provides an implementation of the following methods:
1212

1313
- Gaussian Naive Bayes ([`GaussianNb`])
14-
- Multinomial Naive Nayes ([`MultinomialNb`]))
14+
- Multinomial Naive Bayes ([`MultinomialNb`])
15+
- Bernoulli Naive Bayes ([`BernoulliNb`])
1516

1617
## Examples
1718

@@ -95,3 +96,43 @@ println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
9596
# Result::Ok(())
9697
```
9798
</details>
99+
100+
To run Bernoulli Naive Bayes example, use:
101+
102+
```bash
103+
$ cargo run --example winequality_bernoulli --release
104+
```
105+
106+
<details>
107+
<summary style="cursor: pointer; display:list-item;">
108+
Show source code
109+
</summary>
110+
111+
```rust, no_run
112+
use linfa::metrics::ToConfusionMatrix;
113+
use linfa::traits::{Fit, Predict};
114+
use linfa_bayes::{BernoulliNb, Result};
115+
116+
// Read in the dataset and convert targets to binary data
117+
let (train, valid) = linfa_datasets::winequality()
118+
.map_targets(|x| if *x > 6 { "good" } else { "bad" })
119+
.split_with_ratio(0.9);
120+
121+
// Train the model
122+
let model = BernoulliNb::params().fit(&train)?;
123+
124+
// Predict the validation dataset
125+
let pred = model.predict(&valid);
126+
127+
// Construct confusion matrix
128+
let cm = pred.confusion_matrix(&valid)?;
129+
// classes | bad | good
130+
// bad | 142 | 0
131+
// good | 17 | 0
132+
133+
// accuracy 0.8930818, MCC NaN
134+
println!("{:?}", cm);
135+
println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
136+
# Result::Ok(())
137+
```
138+
</details>
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use linfa::metrics::ToConfusionMatrix;
2+
use linfa::traits::{Fit, Predict};
3+
use linfa_bayes::{BernoulliNb, Result};
4+
5+
fn main() -> Result<()> {
6+
// Read in the dataset and convert targets to binary data
7+
let (train, valid) = linfa_datasets::winequality()
8+
.map_targets(|x| if *x > 6 { "good" } else { "bad" })
9+
.split_with_ratio(0.9);
10+
11+
// Train the model
12+
let model = BernoulliNb::params().fit(&train)?;
13+
14+
// Predict the validation dataset
15+
let pred = model.predict(&valid);
16+
17+
// Construct confusion matrix
18+
let cm = pred.confusion_matrix(&valid)?;
19+
// classes | bad | good
20+
// bad | 142 | 0
21+
// good | 17 | 0
22+
23+
// accuracy 0.8930818, MCC NaN
24+
println!("{:?}", cm);
25+
println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
26+
27+
Ok(())
28+
}

algorithms/linfa-bayes/src/base_nb.rs

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
1+
use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix2, Zip};
22
use ndarray_stats::QuantileExt;
33
use std::collections::HashMap;
44

@@ -8,13 +8,17 @@ use linfa::traits::FitWith;
88
use linfa::{Float, Label};
99

1010
// Trait computing predictions for fitted Naive Bayes models
11-
pub(crate) trait NaiveBayes<'a, F, L>
11+
pub trait NaiveBayes<'a, F, L>
1212
where
1313
F: Float,
1414
L: Label + Ord,
1515
{
16+
/// Compute the unnormalized posterior log probabilities.
17+
/// The result is returned as a HashMap indexing the log probabilities for each sample (e.g. x rows) by class
18+
/// (eg jll\[class\] -> (n_samples,) array)
1619
fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>>;
1720

21+
#[doc(hidden)]
1822
fn predict_inplace<D: Data<Elem = F>>(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
1923
assert_eq!(
2024
x.nrows(),
@@ -45,6 +49,40 @@ where
4549
classes[i].clone()
4650
});
4751
}
52+
53+
/// Compute log-probability estimates for each sample wrt classes.
54+
/// The columns correspond to the classes in sorted order, returned as the second output.
55+
fn predict_log_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
56+
let log_likelihood = self.joint_log_likelihood(x);
57+
58+
let mut classes = log_likelihood.keys().cloned().collect::<Vec<_>>();
59+
classes.sort();
60+
61+
let n_samples = x.nrows();
62+
let n_classes = log_likelihood.len();
63+
let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes));
64+
65+
Zip::from(log_prob_mat.columns_mut())
66+
.and(&classes)
67+
.for_each(|mut jll, &class| jll.assign(log_likelihood.get(class).unwrap()));
68+
69+
let log_prob_x = log_prob_mat
70+
.mapv(|x| x.exp())
71+
.sum_axis(Axis(1))
72+
.mapv(|x| x.ln())
73+
.into_shape((n_samples, 1))
74+
.unwrap();
75+
76+
(log_prob_mat - log_prob_x, classes)
77+
}
78+
79+
/// Compute probability estimates for each sample wrt classes.
80+
/// The columns correspond to the classes in sorted order, returned as the second output.
81+
fn predict_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
82+
let (log_prob_mat, classes) = self.predict_log_proba(x);
83+
84+
(log_prob_mat.mapv(|v| v.exp()), classes)
85+
}
4886
}
4987

5088
// Common functionality for hyper-parameter sets of Naive Bayes models ready for estimation
@@ -68,27 +106,3 @@ where
68106
self.fit_with(model_none, dataset)
69107
}
70108
}
71-
72-
// Returns a subset of x corresponding to the class specified by `ycondition`
73-
pub fn filter<F: Float, L: Label + Ord>(
74-
x: ArrayView2<F>,
75-
y: ArrayView1<L>,
76-
ycondition: &L,
77-
) -> Array2<F> {
78-
// We identify the row numbers corresponding to the class we are interested in
79-
let index = y
80-
.into_iter()
81-
.enumerate()
82-
.filter(|(_, y)| (*ycondition == **y))
83-
.map(|(i, _)| i)
84-
.collect::<Vec<_>>();
85-
86-
// We subset x to only records corresponding to the class represented in `ycondition`
87-
let mut xsubset = Array2::zeros((index.len(), x.ncols()));
88-
index
89-
.into_iter()
90-
.enumerate()
91-
.for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..])));
92-
93-
xsubset
94-
}

0 commit comments

Comments
 (0)