Skip to content

Commit fe46d98

Browse files
committed
Add head_dim in Qwen3Config and update Qwen3Attention
1 parent 57c4490 commit fe46d98

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

backends/candle/src/models/flash_qwen3.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,16 @@ impl Qwen3Attention {
3131
}
3232

3333
let num_attention_heads = config.num_attention_heads;
34-
let attention_head_size = config.hidden_size / config.num_attention_heads;
34+
let attention_head_size = config
35+
.head_dim
36+
.unwrap_or(config.hidden_size / config.num_attention_heads);
3537
let num_key_value_heads = config.num_key_value_heads;
3638
let hidden_size = config.hidden_size;
3739

38-
let query_weight = vb.pp("q_proj").get((hidden_size, hidden_size), "weight")?;
40+
let query_weight = vb.pp("q_proj").get(
41+
(num_attention_heads * attention_head_size, hidden_size),
42+
"weight",
43+
)?;
3944
let query_bias = vb.pp("q_proj").get(hidden_size, "bias")?;
4045
let q_proj = Linear::new(query_weight, Some(query_bias), None);
4146

@@ -57,8 +62,10 @@ impl Qwen3Attention {
5762
.get(num_key_value_heads * attention_head_size, "bias")?;
5863
let v_proj = Linear::new(value_weight, Some(value_bias), None);
5964

60-
let o_proj_weight = vb.pp("o_proj").get((hidden_size, hidden_size), "weight")?;
61-
65+
let o_proj_weight = vb.pp("o_proj").get(
66+
(num_attention_heads * attention_head_size, hidden_size),
67+
"weight",
68+
)?;
6269
let o_proj = Linear::new(o_proj_weight, None, None);
6370

6471
let q_norm = RMSNorm::load(vb.pp("q_norm"), attention_head_size, config.rms_norm_eps)?;

backends/candle/src/models/qwen3.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use serde::Deserialize;
44
#[derive(Debug, Clone, PartialEq, Deserialize)]
55
pub struct Qwen3Config {
66
pub vocab_size: usize,
7+
pub head_dim: Option<usize>,
78
pub hidden_size: usize,
89
pub intermediate_size: usize,
910
pub num_hidden_layers: usize,

0 commit comments

Comments
 (0)