Skip to content

Commit f21a638

Browse files
authored
Fix the weight name in GTEClassificationHead (#606)
1 parent 61c070a commit f21a638

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

backends/candle/src/models/gte.rs

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,6 @@ impl GTEMLP {
193193
let up_gate_proj_weight = vb
194194
.pp("up_gate_proj")
195195
.get((intermediate_size * 2, config.hidden_size), "weight")?;
196-
197196
let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
198197

199198
let down_proj_weight = vb
@@ -216,16 +215,12 @@ impl GTEMLP {
216215

217216
let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
218217
let up_states = up_gate_states.narrow(D::Minus1, 0, self.intermediate_size)?;
219-
let gate_states =
220-
up_gate_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
221218

222-
let gate_states = match self.act {
223-
HiddenAct::Gelu => gate_states.gelu(),
224-
HiddenAct::Relu => gate_states.relu(),
225-
HiddenAct::Swiglu => gate_states.silu(),
226-
}?;
219+
let gate =
220+
up_gate_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
221+
let gate = self.act.forward(&gate)?;
227222

228-
self.down_proj.forward(&(gate_states * up_states)?)
223+
self.down_proj.forward(&(gate * up_states)?)
229224
}
230225
}
231226

@@ -288,22 +283,25 @@ pub struct GTEClassificationHead {
288283
}
289284

290285
impl GTEClassificationHead {
291-
#[allow(dead_code)]
286+
fn inner_load(vb: VarBuilder, config: &GTEConfig) -> Option<Linear> {
287+
let pooler_weight = vb
288+
.pp("pooler.dense")
289+
.get((config.hidden_size, config.hidden_size), "weight")
290+
.ok()?;
291+
let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias").ok()?;
292+
let pooler = Linear::new(pooler_weight, Some(pooler_bias), None);
293+
294+
Some(pooler)
295+
}
296+
292297
pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
293298
let n_classes = match &config.id2label {
294299
None => candle::bail!("`id2label` must be set for classifier models"),
295300
Some(id2label) => id2label.len(),
296301
};
297302

298-
let pooler = if let Ok(pooler_weight) = vb
299-
.pp("pooler.dense")
300-
.get((config.hidden_size, config.hidden_size), "weight")
301-
{
302-
let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
303-
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
304-
} else {
305-
None
306-
};
303+
let pooler =
304+
Self::inner_load(vb.pp("new"), config).or_else(|| Self::inner_load(vb.clone(), config));
307305

308306
let classifier_weight = vb
309307
.pp("classifier")
@@ -322,13 +320,15 @@ impl GTEClassificationHead {
322320
let _enter = self.span.enter();
323321

324322
let mut hidden_states = hidden_states.unsqueeze(1)?;
323+
325324
if let Some(pooler) = self.pooler.as_ref() {
326325
hidden_states = pooler.forward(&hidden_states)?;
327326
hidden_states = hidden_states.tanh()?;
328327
}
329328

330329
let hidden_states = self.classifier.forward(&hidden_states)?;
331330
let hidden_states = hidden_states.squeeze(1)?;
331+
332332
Ok(hidden_states)
333333
}
334334
}
Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
---
22
source: backends/candle/tests/test_flash_gte.rs
3-
assertion_line: 83
43
expression: predictions_single
54
---
6-
- - 0.050048828
5+
- - -0.7426758
Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
---
2+
source: backends/candle/tests/test_gte.rs
3+
expression: predictions_single
4+
---
5+
- - -0.74173266

backends/candle/tests/test_gte.rs

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
mod common;
22

3-
use crate::common::{sort_embeddings, SnapshotEmbeddings};
3+
use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
44
use anyhow::Result;
5-
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
5+
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
66
use text_embeddings_backend_candle::CandleBackend;
77
use text_embeddings_backend_core::{Backend, ModelType, Pool};
88

@@ -137,3 +137,32 @@ fn test_snowflake_gte() -> Result<()> {
137137

138138
Ok(())
139139
}
140+
141+
#[test]
142+
#[serial_test::serial]
143+
fn test_gte_classification() -> Result<()> {
144+
let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?;
145+
let tokenizer = load_tokenizer(&model_root)?;
146+
147+
let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
148+
149+
let input_single = batch(
150+
vec![tokenizer
151+
.encode(("What is Deep Learning?", "Deep Learning is not..."), true)
152+
.unwrap()],
153+
[0].to_vec(),
154+
vec![],
155+
);
156+
157+
let predictions: Vec<Vec<f32>> = backend
158+
.predict(input_single)?
159+
.into_iter()
160+
.map(|(_, v)| v)
161+
.collect();
162+
let predictions_single = SnapshotScores::from(predictions);
163+
164+
let matcher = relative_matcher();
165+
insta::assert_yaml_snapshot!("gte_classification_single", predictions_single, &matcher);
166+
167+
Ok(())
168+
}

0 commit comments

Comments
 (0)