@@ -12,6 +12,9 @@ struct Qwen3Attention {
     v_proj: Linear,
     o_proj: Linear,
 
+    q_norm: RMSNorm,
+    k_norm: RMSNorm,
+
     num_attention_heads: usize,
     num_key_value_heads: usize,
     attention_head_size: usize,
@@ -68,6 +71,8 @@ impl Qwen3Attention {
             k_proj,
             v_proj,
             o_proj,
+            q_norm,
+            k_norm,
             num_attention_heads,
             num_key_value_heads,
             attention_head_size,
@@ -94,13 +99,31 @@ impl Qwen3Attention {
         let input_dims = hidden_states.dims();
         let input_shape = &input_dims[..input_dims.len() - 1];
 
-        let q = q.reshape([input_shape, &[self.num_attention_heads, self.head_dim]].concat())?;
-        let k = k.reshape([input_shape, &[self.num_key_value_heads, self.head_dim]].concat())?;
-        let v = v.reshape([input_shape, &[self.num_key_value_heads, self.head_dim]].concat())?;
+        let q = q.reshape(
+            [
+                input_shape,
+                &[self.num_attention_heads, self.attention_head_size],
+            ]
+            .concat(),
+        )?;
+        let k = k.reshape(
+            [
+                input_shape,
+                &[self.num_key_value_heads, self.attention_head_size],
+            ]
+            .concat(),
+        )?;
+        let v = v.reshape(
+            [
+                input_shape,
+                &[self.num_key_value_heads, self.attention_head_size],
+            ]
+            .concat(),
+        )?;
 
         // Apply normalization layers
-        let q = self.q_norm.forward(q)?;
-        let k = self.k_norm.forward(k)?;
+        let (q, _res) = self.q_norm.forward(&q, None)?;
+        let (k, _res) = self.k_norm.forward(&k, None)?;
 
         // Transpose to [batch, heads, seq_len, head_dim]
         let q = q.transpose(1, 2)?;