Commit 4e3a0bc

Fixing metal backend. (#655)
1 parent 26fe7d7 · commit 4e3a0bc

File tree

1 file changed: +4 −3 lines changed

backends/candle/src/models/qwen3.rs

Lines changed: 4 additions & 3 deletions
```diff
@@ -451,16 +451,17 @@ impl Qwen3Model {
             .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
             .collect();

-        let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
+        let device = attention_bias.device();
+        let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;

         let negatives =
-            Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?.to_dtype(self.dtype)?;
+            Tensor::full(f32::MIN, attention_bias.shape(), device)?.to_dtype(self.dtype)?;
         let zeros = Tensor::zeros_like(&attention_bias)?.to_dtype(self.dtype)?;

         let causal_mask = causal_mask
             .where_cond(&negatives, &zeros)?
-            .to_device(&self.device)?;
+            .to_device(device)?;

         attention_bias.broadcast_add(&causal_mask)
     }
```
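The change replaces every hard-coded `&Device::Cpu` in the causal-mask construction with the device of the incoming `attention_bias` tensor. Below is a minimal standalone sketch of the same pattern; the function name `apply_causal_mask`, its parameter list, and the `main` smoke test are illustrative assumptions, not code from the repository:

```rust
use candle_core::{DType, Device, Result, Tensor};

/// Hypothetical standalone version of the masking step this commit fixes:
/// every intermediate tensor is created on `attention_bias`'s own device,
/// so the same code runs unchanged on CPU, CUDA, and Metal.
fn apply_causal_mask(
    attention_bias: &Tensor, // shape (bs, dim, seq_len, seq_len)
    bs: usize,
    dim: usize,
    seq_len: usize,
    dtype: DType,
) -> Result<Tensor> {
    // 1 where j > i (future positions to mask out), 0 elsewhere.
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
        .collect();

    // The fix: take the device from the existing bias tensor instead of
    // hard-coding Device::Cpu.
    let device = attention_bias.device();
    let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
    let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;

    // f32::MIN where the mask is set, 0 elsewhere, in the model dtype.
    let negatives = Tensor::full(f32::MIN, attention_bias.shape(), device)?.to_dtype(dtype)?;
    let zeros = attention_bias.zeros_like()?.to_dtype(dtype)?;
    let causal_mask = causal_mask.where_cond(&negatives, &zeros)?;

    attention_bias.broadcast_add(&causal_mask)
}

fn main() -> Result<()> {
    // Tiny smoke test on CPU; on a Mac the same call would work with a
    // device from Device::new_metal(0).
    let (bs, dim, seq_len) = (1usize, 1, 4);
    let bias = Tensor::zeros((bs, dim, seq_len, seq_len), DType::F32, &Device::Cpu)?;
    let biased = apply_causal_mask(&bias, bs, dim, seq_len, DType::F32)?;
    println!("{biased}");
    Ok(())
}
```

The likely failure mode before this patch: `zeros` came from `Tensor::zeros_like(&attention_bias)` and therefore lived on the model device (Metal), while `causal_mask` and `negatives` were created on `Device::Cpu`, so `where_cond` mixed operands from two devices, which candle rejects with a device-mismatch error. Deriving `device` from `attention_bias` keeps every operand on one device, and the trailing `to_device(device)` becomes a no-op.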
