@@ -31,11 +31,16 @@ impl Qwen3Attention {
31
31
}
32
32
33
33
let num_attention_heads = config. num_attention_heads ;
34
- let attention_head_size = config. hidden_size / config. num_attention_heads ;
34
+ let attention_head_size = config
35
+ . head_dim
36
+ . unwrap_or ( config. hidden_size / config. num_attention_heads ) ;
35
37
let num_key_value_heads = config. num_key_value_heads ;
36
38
let hidden_size = config. hidden_size ;
37
39
38
- let query_weight = vb. pp ( "q_proj" ) . get ( ( hidden_size, hidden_size) , "weight" ) ?;
40
+ let query_weight = vb. pp ( "q_proj" ) . get (
41
+ ( num_attention_heads * attention_head_size, hidden_size) ,
42
+ "weight" ,
43
+ ) ?;
39
44
let query_bias = vb. pp ( "q_proj" ) . get ( hidden_size, "bias" ) ?;
40
45
let q_proj = Linear :: new ( query_weight, Some ( query_bias) , None ) ;
41
46
@@ -57,8 +62,10 @@ impl Qwen3Attention {
57
62
. get ( num_key_value_heads * attention_head_size, "bias" ) ?;
58
63
let v_proj = Linear :: new ( value_weight, Some ( value_bias) , None ) ;
59
64
60
- let o_proj_weight = vb. pp ( "o_proj" ) . get ( ( hidden_size, hidden_size) , "weight" ) ?;
61
-
65
+ let o_proj_weight = vb. pp ( "o_proj" ) . get (
66
+ ( num_attention_heads * attention_head_size, hidden_size) ,
67
+ "weight" ,
68
+ ) ?;
62
69
let o_proj = Linear :: new ( o_proj_weight, None , None ) ;
63
70
64
71
let q_norm = RMSNorm :: load ( vb. pp ( "q_norm" ) , attention_head_size, config. rms_norm_eps ) ?;
0 commit comments