Skip to content

Commit dbbbf24

Browse files
committed
Add different pooling strategies for ModernBERT classifier
1 parent 5bcc41f commit dbbbf24

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

backends/candle/src/models/flash_modernbert.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,15 @@ impl FlashModernBertModel {
260260

261261
let (pool, classifier) = match model_type {
262262
ModelType::Classifier => {
263-
let pool = Pool::Cls;
263+
let pool = if let Some(pooling) = &config.classifier_pooling {
264+
match pooling.as_str() {
265+
"cls" => Pool::Cls,
266+
"mean" => Pool::Mean,
267+
_ => Pool::Cls,
268+
}
269+
} else {
270+
Pool::Cls
271+
};
264272

265273
let classifier: Box<dyn ClassificationHead + Send> =
266274
Box::new(ModernBertClassificationHead::load(vb.clone(), config)?);

backends/candle/src/models/modernbert.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,15 @@ impl ModernBertModel {
484484
pub fn load(vb: VarBuilder, config: &ModernBertConfig, model_type: ModelType) -> Result<Self> {
485485
let (pool, classifier) = match model_type {
486486
ModelType::Classifier => {
487-
let pool = Pool::Cls;
487+
let pool = if let Some(pooling) = &config.classifier_pooling {
488+
match pooling.as_str() {
489+
"cls" => Pool::Cls,
490+
"mean" => Pool::Mean,
491+
_ => Pool::Cls,
492+
}
493+
} else {
494+
Pool::Cls
495+
};
488496

489497
let classifier: Box<dyn ClassificationHead + Send> =
490498
Box::new(ModernBertClassificationHead::load(vb.clone(), config)?);

backends/candle/tests/test_modernbert.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,61 @@ fn test_modernbert_classification() -> Result<()> {
202202

203203
Ok(())
204204
}
205+
206+
#[test]
#[serial_test::serial]
fn test_modernbert_classifier_pooling_strategy() -> Result<()> {
    /// Writes the saved text back to `path` when dropped.
    ///
    /// The test temporarily rewrites the model's `config.json` on disk. Using a
    /// drop guard means the original file is put back even when the test exits
    /// early through `?` or unwinds from a panic, so a failure here cannot leave
    /// a modified config behind for later runs.
    struct ConfigRestoreGuard {
        path: std::path::PathBuf,
        original: String,
    }

    impl Drop for ConfigRestoreGuard {
        fn drop(&mut self) {
            // Errors cannot be propagated from `drop`; report and carry on.
            if let Err(err) = std::fs::write(&self.path, &self.original) {
                eprintln!(
                    "failed to restore {}: {}",
                    self.path.display(),
                    err
                );
            }
        }
    }

    let model_root = download_artifacts("Alibaba-NLP/gte-reranker-modernbert-base", None)?;
    let tokenizer = load_tokenizer(&model_root)?;

    // Keep the exact original text so it can be restored byte-for-byte.
    let config_path = model_root.join("config.json");
    let original_content = fs::read_to_string(&config_path)?;
    let original_config: serde_json::Value = serde_json::from_str(&original_content)?;

    // The test relies on the published config selecting mean pooling.
    assert_eq!(original_config["classifier_pooling"], "mean");

    let input_single = batch(
        vec![tokenizer
            .encode(("What is Deep Learning?", "Deep Learning is not..."), true)
            .unwrap()],
        [0].to_vec(),
        vec![],
    );

    // Predictions with the config as shipped ("mean").
    let backend_mean =
        CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
    let predictions_mean: Vec<Vec<f32>> = backend_mean
        .predict(input_single.clone())?
        .into_iter()
        .map(|(_, v)| v)
        .collect();

    // Predictions with the config rewritten to "cls". The guard is created
    // before the file is touched and restores it when this scope ends,
    // whether it ends normally, via `?`, or by panic.
    let predictions_cls: Vec<Vec<f32>> = {
        let _restore = ConfigRestoreGuard {
            path: config_path.clone(),
            original: original_content,
        };

        let mut cls_config = original_config;
        cls_config["classifier_pooling"] = json!("cls");
        fs::write(&config_path, serde_json::to_string_pretty(&cls_config)?)?;

        let backend_cls =
            CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
        let scores: Vec<Vec<f32>> = backend_cls
            .predict(input_single)?
            .into_iter()
            .map(|(_, v)| v)
            .collect();
        scores
    };

    assert_ne!(
        predictions_mean[0], predictions_cls[0],
        "Mean and CLS pooling should produce different predictions"
    );

    let matcher = relative_matcher();
    let predictions_snapshot = SnapshotScores::from(predictions_mean);
    insta::assert_yaml_snapshot!(
        "modernbert_classifier_mean_pooling",
        predictions_snapshot,
        &matcher
    );

    Ok(())
}

0 commit comments

Comments
 (0)