Commit f7aa35b

Fix Qwen3-Embedding batch vs single inference inconsistency (#648)

Authored by lance-miles and kozistr
Co-authored-by: Lance Miles <[email protected]>
Co-authored-by: Hyeongchan Kim <[email protected]>

1 parent cad4b55

File tree: 3 files changed (+4093 / -4079 lines)


backends/candle/src/models/qwen3.rs

Lines changed: 26 additions & 12 deletions
@@ -487,21 +487,23 @@ impl Qwen3Model {
                 let seq_length = end - start;
                 input_lengths.push(seq_length);

-                for j in start..end {
-                    input_ids.push(batch.input_ids[j]);
-                    position_ids.push(batch.position_ids[j]);
-                    attention_bias.push(0.0);
-                }
-
+                // Left padding for Qwen3-Embedding (pad at the beginning)
                 let padding = max_length - seq_length;
                 if padding > 0 {
                     masking = true;
                     for _ in 0..padding {
-                        input_ids.insert(start, self.pad_token_id);
-                        position_ids.insert(start, 0);
-                        attention_bias.insert(start, f32::MIN);
+                        input_ids.push(self.pad_token_id);
+                        position_ids.push(0);
+                        attention_bias.push(f32::MIN);
                     }
                 }
+
+                // Then add the actual sequence
+                for j in start..end {
+                    input_ids.push(batch.input_ids[j]);
+                    position_ids.push(batch.position_ids[j]);
+                    attention_bias.push(0.0);
+                }
             }

             let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
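
For illustration only (not part of the commit): a minimal, self-contained sketch of the left-padding layout the hunk above switches to, with made-up token values and a placeholder pad_token_id. Pad tokens are pushed first, so the last real token of every row lands at index max_length - 1.

// Minimal sketch of left padding (assumed toy values, not the crate's code).
fn left_pad(seqs: &[Vec<u32>], pad_token_id: u32) -> (Vec<u32>, usize) {
    let max_length = seqs.iter().map(|s| s.len()).max().unwrap_or(0);
    let mut input_ids = Vec::new();
    for seq in seqs {
        // Pad tokens go first ...
        let padding = max_length - seq.len();
        input_ids.extend(std::iter::repeat(pad_token_id).take(padding));
        // ... then the actual sequence, so its last token always lands
        // at index max_length - 1 of its row.
        input_ids.extend_from_slice(seq);
    }
    (input_ids, max_length)
}

fn main() {
    let seqs = vec![vec![11, 12, 13], vec![21, 22]];
    let (flat, max_length) = left_pad(&seqs, 0);
    for row in flat.chunks(max_length) {
        println!("{row:?}");
    }
    // [11, 12, 13]
    // [0, 21, 22]
}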
@@ -533,7 +535,15 @@ impl Qwen3Model {
             )?;
             let input_lengths = vec![batch.input_ids.len()];

-            (input_ids, position_ids, input_lengths, None)
+            let seq_len = batch.input_ids.len();
+            // Create attention bias for causal masking even for single sequences
+            let attention_bias = Tensor::zeros(
+                (1, self.num_attention_heads, seq_len, seq_len),
+                candle::DType::F32,
+                &self.device,
+            )?;
+
+            (input_ids, position_ids, input_lengths, Some(attention_bias))
         };

         let attention_bias = if let Some(attn_bias) = attention_bias {
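
For illustration only: a standalone sketch of the all-zero bias tensor the single-sequence path now returns. It is written against the candle-core package (the diff refers to the crate as candle inside the repo) with made-up head and sequence sizes; the real model reads these from its config. A zero bias is a no-op when added to attention scores, so it simply routes single sequences through the same masked attention path as batches.

// Standalone sketch (candle-core API, assumed toy sizes).
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let num_attention_heads = 4; // placeholder; the model takes this from its config
    let seq_len = 3; // placeholder sequence length
    // All-zero bias: neutral when added to the attention scores, but it keeps
    // the single-sequence case on the same code path as the batched one.
    let attention_bias = Tensor::zeros(
        (1, num_attention_heads, seq_len, seq_len),
        DType::F32,
        &device,
    )?;
    println!("{:?}", attention_bias.dims()); // [1, 4, 3, 3]
    Ok(())
}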
@@ -597,14 +607,16 @@ impl Qwen3Model {
                 .iter()
                 .map(|&i| {
                     let i = i as usize;
+                    // With left padding, the last token is always at max_length - 1
                     let last_token_idx = max_length - 1;
                     outputs.i((i, last_token_idx))?.unsqueeze(0)
                 })
                 .collect();

             Some(Tensor::cat(&results?, 0)?)
         } else {
-            let last_idx = input_lengths[0] - 1;
+            // For single inference, use the actual last token position from cumulative_seq_lengths
+            let last_idx = batch.cumulative_seq_lengths[1] as usize - 1;
             Some(outputs.i((0, last_idx))?.unsqueeze(0)?)
         }
     }
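
For illustration only: the index arithmetic behind the last-token pooling change above, with made-up lengths. With left padding, padding + length - 1 equals max_length - 1 for every row, while an unpadded single sequence ends at cumulative_seq_lengths[1] - 1.

// Index arithmetic behind last-token pooling (assumed toy values).
fn main() {
    let max_length = 5usize;

    // Batched path with left padding: regardless of how much padding a row
    // received, its last real token sits at max_length - 1.
    for length in [3usize, 5] {
        let padding = max_length - length;
        let last_token_idx = max_length - 1;
        assert_eq!(padding + length - 1, last_token_idx);
        println!("length {length}: {padding} pads, last token at {last_token_idx}");
    }

    // Single-sequence path: no padding, so the last token index comes
    // straight from the cumulative sequence lengths.
    let cumulative_seq_lengths = [0usize, 4];
    let last_idx = cumulative_seq_lengths[1] - 1;
    println!("single sequence: last token at {last_idx}");
}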
@@ -617,7 +629,9 @@ impl Qwen3Model {
                     let i = i as usize;
                     let length = input_lengths[i];

-                    let embeddings = outputs.i((i, ..length))?;
+                    // With left padding, actual tokens are at the end
+                    let padding = max_length - length;
+                    let embeddings = outputs.i((i, padding..))?;
                     let sum = embeddings.sum_keepdim(0)?;
                     sum / (length as f64)
                 })
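
For illustration only: a standalone mean-pooling sketch over a tiny made-up tensor, using the candle-core API, that skips the leading pad positions the same way the hunk above does with outputs.i((i, padding..)).

// Mean pooling over only the non-padded positions (candle-core API,
// assumed toy tensor; the real code slices per pooled request).
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let max_length = 4usize;
    let length = 3usize; // real tokens in this row

    // (batch = 1, max_length = 4, hidden = 2); position 0 is a pad output.
    let outputs = Tensor::new(
        &[[[9.0f32, 9.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]],
        &device,
    )?;

    // With left padding the real tokens sit at the end of the row,
    // so skip the first `padding` positions before averaging.
    let padding = max_length - length;
    let embeddings = outputs.i((0, padding..))?;
    let mean = (embeddings.sum_keepdim(0)? / (length as f64))?;
    println!("{:?}", mean.to_vec2::<f32>()?); // [[3.0, 4.0]]
    Ok(())
}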
