
Commit 57c4490

Fix RMSNorm forward pass
1 parent d29da35


backends/candle/src/models/flash_qwen3.rs

Lines changed: 28 additions & 5 deletions
@@ -12,6 +12,9 @@ struct Qwen3Attention {
     v_proj: Linear,
     o_proj: Linear,
 
+    q_norm: RMSNorm,
+    k_norm: RMSNorm,
+
     num_attention_heads: usize,
     num_key_value_heads: usize,
     attention_head_size: usize,
@@ -68,6 +71,8 @@ impl Qwen3Attention {
             k_proj,
             v_proj,
             o_proj,
+            q_norm,
+            k_norm,
             num_attention_heads,
             num_key_value_heads,
             attention_head_size,
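The reshapes in the next hunk keep every leading dimension of the input and split the trailing hidden dimension into [num_heads, head_size]. A quick sketch of that `[input_shape, &[heads, head_size]].concat()` shape arithmetic, with made-up sizes:

// Sketch of the shape arithmetic used by the reshapes in the next hunk:
// keep the leading dims, split the trailing hidden dim into
// [num_heads, head_size]. All sizes below are made up for illustration.
fn main() {
    let input_dims: Vec<usize> = vec![8, 128, 1024]; // [batch, seq_len, hidden]
    let input_shape = &input_dims[..input_dims.len() - 1];
    let (num_heads, head_size) = (16, 64); // hidden = num_heads * head_size
    let target = [input_shape, &[num_heads, head_size]].concat();
    assert_eq!(target, vec![8, 128, 16, 64]);
    println!("{:?}", target);
}

This also shows why the fix matters: the new code reads `self.attention_head_size`, the field the struct actually defines, where the old code referenced `self.head_dim`.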
@@ -94,13 +99,31 @@ impl Qwen3Attention {
         let input_dims = hidden_states.dims();
         let input_shape = &input_dims[..input_dims.len() - 1];
 
-        let q = q.reshape([input_shape, &[self.num_attention_heads, self.head_dim]].concat())?;
-        let k = k.reshape([input_shape, &[self.num_key_value_heads, self.head_dim]].concat())?;
-        let v = v.reshape([input_shape, &[self.num_key_value_heads, self.head_dim]].concat())?;
+        let q = q.reshape(
+            [
+                input_shape,
+                &[self.num_attention_heads, self.attention_head_size],
+            ]
+            .concat(),
+        )?;
+        let k = k.reshape(
+            [
+                input_shape,
+                &[self.num_key_value_heads, self.attention_head_size],
+            ]
+            .concat(),
+        )?;
+        let v = v.reshape(
+            [
+                input_shape,
+                &[self.num_key_value_heads, self.attention_head_size],
+            ]
+            .concat(),
+        )?;
 
         // Apply normalization layers
-        let q = self.q_norm.forward(q)?;
-        let k = self.k_norm.forward(k)?;
+        let (q, _res) = self.q_norm.forward(&q, None)?;
+        let (k, _res) = self.k_norm.forward(&k, None)?;
 
         // Transpose to [batch, heads, seq_len, head_dim]
         let q = q.transpose(1, 2)?;
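The call-site change suggests that `RMSNorm::forward` in this backend takes an optional residual and returns an (output, residual) tuple, which is why the fixed code destructures and discards `_res`. A hypothetical sketch of an API with that shape, read off the diff rather than the backend's source; plain `Vec<f32>` stands in for candle tensors, and all names here are illustrative:

// Hypothetical sketch of the forward shape implied by the call sites:
// an optional residual is added before normalizing, and the pre-norm sum
// is handed back so callers can chain it. This is not the actual
// signature from the codebase, only the pattern the diff suggests.
struct RmsNorm {
    weight: Vec<f32>,
    eps: f32,
}

impl RmsNorm {
    fn forward(&self, x: &[f32], residual: Option<&[f32]>) -> (Vec<f32>, Option<Vec<f32>>) {
        // Add the residual first, if one was passed.
        let xs: Vec<f32> = match residual {
            Some(r) => x.iter().zip(r).map(|(a, b)| a + b).collect(),
            None => x.to_vec(),
        };
        let mean_sq = xs.iter().map(|v| v * v).sum::<f32>() / xs.len() as f32;
        let scale = 1.0 / (mean_sq + self.eps).sqrt();
        let out: Vec<f32> = xs.iter().zip(&self.weight).map(|(v, w)| v * scale * w).collect();
        // Return the pre-norm sum only when a residual was supplied,
        // mirroring how the q/k call sites pass None and ignore `_res`.
        (out, residual.map(|_| xs))
    }
}

fn main() {
    let norm = RmsNorm { weight: vec![1.0; 4], eps: 1e-6 };
    let (q, _res) = norm.forward(&[1.0, 2.0, 3.0, 4.0], None);
    println!("{:?}", q);
}

Under this reading, the old single-argument, single-return calls would not have compiled against the tuple-returning signature, which matches the commit message's framing as a forward-pass fix.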
