Skip to content

Commit a5907fa

Browse files
committed
Add `attention_bias` config field and handle it in Qwen3Attention
1 parent fe46d98 commit a5907fa

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

backends/candle/src/models/flash_qwen3.rs

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -41,26 +41,40 @@ impl Qwen3Attention {
4141
(num_attention_heads * attention_head_size, hidden_size),
4242
"weight",
4343
)?;
44-
let query_bias = vb.pp("q_proj").get(hidden_size, "bias")?;
45-
let q_proj = Linear::new(query_weight, Some(query_bias), None);
44+
let query_bias = if config.attention_bias {
45+
Some(vb.pp("q_proj").get(hidden_size, "bias")?)
46+
} else {
47+
None
48+
};
49+
let q_proj = Linear::new(query_weight, query_bias, None);
4650

4751
let key_weight = vb.pp("k_proj").get(
4852
(num_key_value_heads * attention_head_size, hidden_size),
4953
"weight",
5054
)?;
51-
let key_bias = vb
52-
.pp("k_proj")
53-
.get(num_key_value_heads * attention_head_size, "bias")?;
54-
let k_proj = Linear::new(key_weight, Some(key_bias), None);
55+
let key_bias = if config.attention_bias {
56+
Some(
57+
vb.pp("k_proj")
58+
.get(num_key_value_heads * attention_head_size, "bias")?,
59+
)
60+
} else {
61+
None
62+
};
63+
let k_proj = Linear::new(key_weight, key_bias, None);
5564

5665
let value_weight = vb.pp("v_proj").get(
5766
(num_key_value_heads * attention_head_size, hidden_size),
5867
"weight",
5968
)?;
60-
let value_bias = vb
61-
.pp("v_proj")
62-
.get(num_key_value_heads * attention_head_size, "bias")?;
63-
let v_proj = Linear::new(value_weight, Some(value_bias), None);
69+
let value_bias = if config.attention_bias {
70+
Some(
71+
vb.pp("v_proj")
72+
.get(num_key_value_heads * attention_head_size, "bias")?,
73+
)
74+
} else {
75+
None
76+
};
77+
let v_proj = Linear::new(value_weight, value_bias, None);
6478

6579
let o_proj_weight = vb.pp("o_proj").get(
6680
(num_attention_heads * attention_head_size, hidden_size),

backends/candle/src/models/qwen3.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ use serde::Deserialize;
33

44
#[derive(Debug, Clone, PartialEq, Deserialize)]
55
pub struct Qwen3Config {
6+
pub attention_bias: bool,
67
pub vocab_size: usize,
78
pub head_dim: Option<usize>,
89
pub hidden_size: usize,

0 commit comments

Comments
 (0)