@@ -487,21 +487,23 @@ impl Qwen3Model {
                 let seq_length = end - start;
                 input_lengths.push(seq_length);
 
-                for j in start..end {
-                    input_ids.push(batch.input_ids[j]);
-                    position_ids.push(batch.position_ids[j]);
-                    attention_bias.push(0.0);
-                }
-
+                // Left padding for Qwen3-Embedding (pad at the beginning)
                 let padding = max_length - seq_length;
                 if padding > 0 {
                     masking = true;
                     for _ in 0..padding {
-                        input_ids.insert(start, self.pad_token_id);
-                        position_ids.insert(start, 0);
-                        attention_bias.insert(start, f32::MIN);
+                        input_ids.push(self.pad_token_id);
+                        position_ids.push(0);
+                        attention_bias.push(f32::MIN);
                     }
                 }
+
+                // Then add the actual sequence
+                for j in start..end {
+                    input_ids.push(batch.input_ids[j]);
+                    position_ids.push(batch.position_ids[j]);
+                    attention_bias.push(0.0);
+                }
             }
 
             let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
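For reference, a minimal standalone sketch of the left-padding scheme introduced in this hunk, using plain `Vec`s rather than the model's `Batch` fields; the function name `left_pad_batch` and the synthetic per-sequence position ids are illustrative assumptions, not part of the patch.

```rust
/// Illustrative only: flatten a batch of token sequences with left padding,
/// mirroring the loop above (pad first, then append the real tokens).
fn left_pad_batch(
    sequences: &[Vec<u32>],
    pad_token_id: u32,
) -> (Vec<u32>, Vec<u32>, Vec<f32>) {
    let max_length = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
    let mut input_ids = Vec::new();
    let mut position_ids = Vec::new();
    let mut attention_bias = Vec::new();

    for seq in sequences {
        // Padding goes at the beginning so the real tokens end at max_length - 1.
        for _ in 0..(max_length - seq.len()) {
            input_ids.push(pad_token_id);
            position_ids.push(0);
            attention_bias.push(f32::MIN);
        }
        // Then the actual sequence, with positions counted from 0.
        for (pos, &id) in seq.iter().enumerate() {
            input_ids.push(id);
            position_ids.push(pos as u32);
            attention_bias.push(0.0);
        }
    }
    (input_ids, position_ids, attention_bias)
}
```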
@@ -533,7 +535,15 @@ impl Qwen3Model {
             )?;
             let input_lengths = vec![batch.input_ids.len()];
 
-            (input_ids, position_ids, input_lengths, None)
+            let seq_len = batch.input_ids.len();
+            // Create attention bias for causal masking even for single sequences
+            let attention_bias = Tensor::zeros(
+                (1, self.num_attention_heads, seq_len, seq_len),
+                candle::DType::F32,
+                &self.device,
+            )?;
+
+            (input_ids, position_ids, input_lengths, Some(attention_bias))
         };
 
         let attention_bias = if let Some(attn_bias) = attention_bias {
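This hunk only allocates an all-zeros bias; how the surrounding code turns it into a causal mask is outside the diff. Below is a hedged sketch of one way such a bias could be combined with a lower-triangular mask, assuming direct use of `candle_core` (aliased as `candle` in this codebase) and a hypothetical helper name.

```rust
use candle_core::{Device, Result, Tensor};

/// Hypothetical helper: add a causal (lower-triangular) mask to an
/// all-zeros attention bias of shape (1, num_heads, seq_len, seq_len).
fn with_causal_mask(bias: &Tensor, seq_len: usize, device: &Device) -> Result<Tensor> {
    // Positions j > i receive f32::MIN so they vanish after softmax.
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0.0 }))
        .collect();
    let mask = Tensor::from_vec(mask, (1, 1, seq_len, seq_len), device)?;
    // Broadcasts over the batch and head dimensions of `bias`.
    bias.broadcast_add(&mask)
}
```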
@@ -597,14 +607,16 @@ impl Qwen3Model {
                 .iter()
                 .map(|&i| {
                     let i = i as usize;
+                    // With left padding, the last token is always at max_length - 1
                     let last_token_idx = max_length - 1;
                     outputs.i((i, last_token_idx))?.unsqueeze(0)
                 })
                 .collect();
 
             Some(Tensor::cat(&results?, 0)?)
         } else {
-            let last_idx = input_lengths[0] - 1;
+            // For single inference, use the actual last token position from cumulative_seq_lengths
+            let last_idx = batch.cumulative_seq_lengths[1] as usize - 1;
             Some(outputs.i((0, last_idx))?.unsqueeze(0)?)
         }
     }
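A small illustrative sketch of the index arithmetic the two branches rely on (the helper name `last_token_index` is hypothetical): left padding pushes every padded row flush against `max_length - 1`, while a single unpadded sequence ends at `cumulative_seq_lengths[1] - 1`.

```rust
/// Illustrative index arithmetic for last-token pooling under left padding.
fn last_token_index(
    batch_size: usize,
    max_length: usize,
    cumulative_seq_lengths: &[u32],
) -> usize {
    if batch_size > 1 {
        // Left padding places every sequence's last real token at the right edge.
        max_length - 1
    } else {
        // Single sequence: cumulative_seq_lengths = [0, seq_len].
        cumulative_seq_lengths[1] as usize - 1
    }
}
```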
@@ -617,7 +629,9 @@ impl Qwen3Model {
                     let i = i as usize;
                     let length = input_lengths[i];
 
-                    let embeddings = outputs.i((i, ..length))?;
+                    // With left padding, actual tokens are at the end
+                    let padding = max_length - length;
+                    let embeddings = outputs.i((i, padding..))?;
                     let sum = embeddings.sum_keepdim(0)?;
                     sum / (length as f64)
                 })
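Finally, a plain-`Vec` sketch of the mean-pooling change (the name `mean_pool_left_padded` and the row-per-token layout are assumptions rather than the model's tensor code): only the trailing `length` positions of a left-padded row contribute to the average.

```rust
/// Illustrative only: mean-pool one left-padded row of per-token embeddings.
/// `hidden` has max_length rows; the first (max_length - length) are padding.
fn mean_pool_left_padded(hidden: &[Vec<f32>], length: usize) -> Vec<f32> {
    let max_length = hidden.len();
    let padding = max_length - length;
    let dim = hidden.first().map(|row| row.len()).unwrap_or(0);

    let mut sum = vec![0.0f32; dim];
    for row in &hidden[padding..] {
        for (s, v) in sum.iter_mut().zip(row) {
            *s += *v;
        }
    }
    // Divide by the real token count, not max_length.
    sum.iter().map(|s| *s / length as f32).collect()
}
```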