File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
backends/candle/src/models Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -451,16 +451,17 @@ impl Qwen3Model {
451
451
. flat_map ( |i| ( 0 ..seq_len) . map ( move |j| ( j > i) as u8 ) )
452
452
. collect ( ) ;
453
453
454
- let causal_mask = Tensor :: from_slice ( & mask, ( seq_len, seq_len) , & Device :: Cpu ) ?;
454
+ let device = attention_bias. device ( ) ;
455
+ let causal_mask = Tensor :: from_slice ( & mask, ( seq_len, seq_len) , device) ?;
455
456
let causal_mask = causal_mask. expand ( & [ bs, dim, seq_len, seq_len] ) ?;
456
457
457
458
let negatives =
458
- Tensor :: full ( f32:: MIN , attention_bias. shape ( ) , & Device :: Cpu ) ?. to_dtype ( self . dtype ) ?;
459
+ Tensor :: full ( f32:: MIN , attention_bias. shape ( ) , device ) ?. to_dtype ( self . dtype ) ?;
459
460
let zeros = Tensor :: zeros_like ( & attention_bias) ?. to_dtype ( self . dtype ) ?;
460
461
461
462
let causal_mask = causal_mask
462
463
. where_cond ( & negatives, & zeros) ?
463
- . to_device ( & self . device ) ?;
464
+ . to_device ( device) ?;
464
465
465
466
attention_bias. broadcast_add ( & causal_mask)
466
467
}
You can’t perform that action at this time.
0 commit comments