Add mean pooling strategy for Modernbert classifier #616
@@ -202,3 +202,61 @@ fn test_modernbert_classification() -> Result<()> {

     Ok(())
 }

+#[test]
I'm having a bit of trouble running this test (unable to download the HF model, hitting a 429, just lacking an HF token) and I'm not super pleased with it, but it's just there to check that cls and mean pooling produce different outputs, as expected.
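For illustration only, here is a minimal sketch of why the two strategies are expected to disagree. It uses plain `Vec<f32>` rows instead of candle tensors, and the helper names `cls_pool` and `mean_pool` are hypothetical, not functions from this PR or from TEI. It also ignores padding masks, which a real implementation would need to handle.

```rust
/// CLS pooling: keep only the hidden state of the first token.
fn cls_pool(hidden: &[Vec<f32>]) -> Vec<f32> {
    hidden[0].clone()
}

/// Mean pooling: average the hidden states over all tokens.
/// Assumes every row has the same length and no padding tokens are present.
fn mean_pool(hidden: &[Vec<f32>]) -> Vec<f32> {
    let n = hidden.len() as f32;
    let dim = hidden[0].len();
    let mut out = vec![0.0f32; dim];
    for token in hidden {
        for (acc, v) in out.iter_mut().zip(token) {
            *acc += v;
        }
    }
    out.iter_mut().for_each(|v| *v /= n);
    out
}
```

With three tokens `[1, 2]`, `[3, 4]`, `[5, 9]`, `cls_pool` returns the first row while `mean_pool` returns the column-wise average, so a classifier head fed with one or the other will generally score differently.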
How about testing directly with the reranker-ModernBERT-large-gooaq-bce model, which uses mean classifier pooling, instead of the gte-reranker model with mean pooling?
I think we can check whether classifier pooling works correctly by running the reranker-ModernBERT-gooqa-bce model, since cls classifier pooling has already been verified in the test above! (This implicitly assumes cls and mean behave differently.)
Regarding "unable to download the HF model, hitting a 429": when I faced that issue, I manually downloaded the model via git or huggingface-cli to my local machine, and then set HUGGINGFACE_HUB_CACHE to point to that local path.
That sounds like a good idea, will try this!
It couldn't find the model in the cache, but I managed by just exporting a Hugging Face token, all good!
successes:
test_modernbert_classification_mean_pooling
test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 3 filtered out; finished in 3.94s
@@ -260,7 +260,15 @@ impl FlashModernBertModel {

         let (pool, classifier) = match model_type {
             ModelType::Classifier => {
-                let pool = Pool::Cls;
+                let pool = if let Some(pooling) = &config.classifier_pooling {
It might make more sense to handle this as shown below:
ModelType::Classifier(pool) => {
    if pool == Pool::Splade {
        candle::bail!("`splade` is not supported for ModernBert")
    }
    if pool == Pool::LastToken {
        candle::bail!("`LastToken` is not supported for ModernBert")
    }
}
But this would be a bigger change that I'm not as confident about.
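To make the idea concrete, here is a self-contained sketch of that validation. Everything in it is an assumption for illustration: the `Pool` and `ModelType` enums are simplified stand-ins for the real types, `classifier_pool` is a hypothetical helper, and a plain `String` error replaces `candle::bail!` so the example compiles without candle.

```rust
// Hypothetical stand-in for the real `Pool` enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Pool {
    Cls,
    Mean,
    Splade,
    LastToken,
}

// Hypothetical variant of `ModelType` where the classifier carries its pooling.
#[allow(dead_code)]
enum ModelType {
    Classifier(Pool),
    Embedding(Pool),
}

/// Returns the pooling to use, rejecting strategies ModernBert does not support.
fn classifier_pool(model_type: &ModelType) -> Result<Pool, String> {
    match model_type {
        ModelType::Classifier(pool) => match pool {
            Pool::Splade => Err("`splade` is not supported for ModernBert".to_string()),
            Pool::LastToken => Err("`LastToken` is not supported for ModernBert".to_string()),
            Pool::Cls | Pool::Mean => Ok(*pool),
        },
        // Embedding models are passed through unchanged in this sketch.
        ModelType::Embedding(pool) => Ok(*pool),
    }
}
```

The trade-off is the one mentioned above: carrying the pool inside `ModelType::Classifier` touches every place that matches on the enum, which is why it is a larger change than reading the value from the config.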
I'm also not confident about this. Maybe we could consider implementing std::str::FromStr for the Pool enum, as below. It would be great if the maintainers could provide some guidance or feedback on this :)
impl std::str::FromStr for Pool {
    // `FromStr` requires an associated error type; `()` keeps the sketch minimal.
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "cls" => Ok(Pool::Cls),
            "mean" => Ok(Pool::Mean),
            "splade" => Ok(Pool::Splade),
            "last_token" => Ok(Pool::LastToken),
            _ => Err(()),
        }
    }
}
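As a rough sketch of how such a `FromStr` impl could be used, the following resolves an optional `classifier_pooling` string to a pooling strategy and falls back to CLS when it is absent. The `Pool` enum here is a simplified stand-in, `resolve_classifier_pool` is a hypothetical helper, and the `String` error type is chosen only to give readable messages; none of this is the PR's actual code.

```rust
use std::str::FromStr;

// Hypothetical stand-in for the real `Pool` enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Pool {
    Cls,
    Mean,
    Splade,
    LastToken,
}

impl FromStr for Pool {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "cls" => Ok(Pool::Cls),
            "mean" => Ok(Pool::Mean),
            "splade" => Ok(Pool::Splade),
            "last_token" => Ok(Pool::LastToken),
            other => Err(format!("unknown pooling strategy `{other}`")),
        }
    }
}

/// Resolves the classifier pooling from an optional config value.
/// Defaults to CLS and rejects strategies ModernBert does not support.
fn resolve_classifier_pool(classifier_pooling: Option<&str>) -> Result<Pool, String> {
    match classifier_pooling {
        None => Ok(Pool::Cls),
        Some(name) => match name.parse::<Pool>()? {
            pool @ (Pool::Cls | Pool::Mean) => Ok(pool),
            other => Err(format!("`{other:?}` is not supported for ModernBert")),
        },
    }
}
```

Parsing in one place keeps the string matching out of the model constructor, and unknown or unsupported values surface as errors instead of silently falling back to CLS.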
I'm just a passerby lol. I left some comments in the hope that they help in some way :) Thanks!
What does this PR do?
Apologies in advance, I'm new to this area and to Rust 😅.
I noticed differences between the Transformers and TEI libraries when running rerankers with a mean classifier pooling strategy. More details here.
Fixes #615
Who can review?
Not sure who might be best to review 😄 @OlivierDehaene OR @Narsil