
Commit 5c91435

Fix static attention non-HF RoPE implementation
Differential Revision: D76951243
Pull Request resolved: #11808
1 parent aae0dba commit 5c91435

2 files changed: +21 −2 lines


examples/models/llama/static_attention.py

Lines changed: 9 additions & 2 deletions
@@ -352,7 +352,7 @@ def forward(
         x_r, x_i = x[..., ::2], x[..., 1::2]
         x_out_r = x_r * freqs_cos - x_i * freqs_sin
         x_out_i = x_r * freqs_sin + x_i * freqs_cos
-        x_out = torch.cat([x_out_r, x_out_i], dim=-1)
+        x_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)
         return x_out


@@ -378,6 +378,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
         self.use_qk_norm = config.use_qk_norm
+        self.qk_norm_before_rope = config.qk_norm_before_rope
         self.use_conv2d = False

         self.wqs = nn.ModuleList(
@@ -449,12 +450,17 @@ def from_conv2ds(ts):
         new_ks = from_conv2ds(new_ks)
         new_vs = from_conv2ds(new_vs)

-        if self.use_qk_norm:
+        if self.use_qk_norm and self.qk_norm_before_rope:
             new_qs = [self.q_norm(q) for q in new_qs]
             new_ks = [self.k_norm(k) for k in new_ks]

         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
+
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            new_qs = [self.q_norm(q) for q in new_qs]
+            new_ks = [self.k_norm(k) for k in new_ks]
+
         all_ks = []
         all_vs = []
         for i in range(self.n_kv_heads):
@@ -505,6 +511,7 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):

         if other.use_qk_norm:
             self.use_qk_norm = True
+            self.qk_norm_before_rope = other.qk_norm_before_rope
             self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
             self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
             self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
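
The first hunk above is the actual RoPE fix: in the non-HF (interleaved) layout, the rotated real and imaginary components must be woven back into the even/odd positions they were sliced from, whereas torch.cat left them in two contiguous halves. A minimal standalone sketch of the difference (tensor shapes and names here are illustrative assumptions, not taken from the file):

import torch

# One attention head with (batch, seq_len, head_dim) input in the
# interleaved (non-HF) layout: even indices are "real", odd are "imaginary".
bsz, seqlen, head_dim = 1, 4, 8
x = torch.randn(bsz, seqlen, head_dim)
freqs_cos = torch.randn(seqlen, head_dim // 2)
freqs_sin = torch.randn(seqlen, head_dim // 2)

x_r, x_i = x[..., ::2], x[..., 1::2]
x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos

# Old behavior: all rotated real parts first, then all imaginary parts,
# a half-split layout that no longer matches how the input was sliced.
old_out = torch.cat([x_out_r, x_out_i], dim=-1)

# Fixed behavior: re-interleave the pairs so element 2*j is x_out_r[..., j]
# and element 2*j + 1 is x_out_i[..., j], matching the even/odd input layout.
new_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)

assert torch.equal(new_out[..., ::2], x_out_r)
assert torch.equal(new_out[..., 1::2], x_out_i)
print(torch.equal(old_out, new_out))  # False: the two layouts differ

torch.stack(..., dim=-1) pairs each rotated real component with its imaginary counterpart along a new trailing axis, and flatten(2) collapses that axis back into head_dim.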

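The remaining hunks add a qk_norm_before_rope flag so QK RMSNorm can run on either side of the rotary embedding, matching whichever ordering the source AttentionMHA used; load_weights_from_attention_mha now copies the flag as well. A simplified sketch of the resulting control flow, written as a hypothetical free function rather than the real method:

import torch
from typing import Callable, List

def norm_and_rope(
    qs: List[torch.Tensor],
    ks: List[torch.Tensor],
    q_norm: Callable[[torch.Tensor], torch.Tensor],
    k_norm: Callable[[torch.Tensor], torch.Tensor],
    rope: Callable[[torch.Tensor], torch.Tensor],
    use_qk_norm: bool,
    qk_norm_before_rope: bool,
):
    # Optionally normalize per-head queries/keys before the rotary embedding...
    if use_qk_norm and qk_norm_before_rope:
        qs = [q_norm(q) for q in qs]
        ks = [k_norm(k) for k in ks]
    # Apply RoPE to every head.
    qs = [rope(q) for q in qs]
    ks = [rope(k) for k in ks]
    # ...or after it, depending on the flag.
    if use_qk_norm and not qk_norm_before_rope:
        qs = [q_norm(q) for q in qs]
        ks = [k_norm(k) for k in ks]
    return qs, ks
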
examples/models/llama/tests/test_static_attention.py

Lines changed: 12 additions & 0 deletions
@@ -30,6 +30,14 @@ def test(use_qk_norm, use_conv2d):
             rope = Rope(config)
             attn_mha = AttentionMHA(config, layer_id, rope).eval()
             static_attn = StaticAttention(config, layer_id, rope).eval()
+            if use_qk_norm:
+                with torch.no_grad():
+                    attn_mha.q_norm_fn.weight.copy_(
+                        torch.rand(config.head_dim) * 0.2 + 0.9
+                    )
+                    attn_mha.k_norm_fn.weight.copy_(
+                        torch.rand(config.head_dim) * 0.2 + 0.9
+                    )
             static_attn.load_weights_from_attention_mha(attn_mha)
             if use_conv2d:
                 static_attn.linear_to_conv2d()
@@ -60,11 +68,15 @@ def test_hf_rope_without_cache(self):
             n_heads=4,
             n_kv_heads=2,
             max_seq_len=8,
+            use_qk_norm=True,
             use_hf_rope=True,
         )
         layer_id = 0
         rope = Rope(config)
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
+        with torch.no_grad():
+            attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
+            attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
         static_attn = StaticAttention(config, layer_id, rope).eval()
         static_attn.load_weights_from_attention_mha(attn_mha)
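
The weight initialization added to the tests is what makes them sensitive to the changes above: a freshly constructed torch.nn.RMSNorm has an all-ones weight, and a pure RMS rescaling commutes with the per-pair rotation RoPE applies, so with default weights the norm-before-rope and norm-after-rope orderings produce identical outputs and a misplaced norm would go unnoticed. Copying random weights near 1.0 into q_norm_fn/k_norm_fn breaks that symmetry. A small sketch of the effect (helper and variable names are illustrative):

import torch

def rope(x, cos, sin):
    # Interleaved rotary embedding, as in the fixed implementation above.
    x_r, x_i = x[..., ::2], x[..., 1::2]
    return torch.stack([x_r * cos - x_i * sin, x_r * sin + x_i * cos], dim=-1).flatten(-2)

seqlen, head_dim = 4, 8
x = torch.randn(seqlen, head_dim)
angles = torch.randn(seqlen, head_dim // 2)
cos, sin = torch.cos(angles), torch.sin(angles)

norm_default = torch.nn.RMSNorm(head_dim)   # weight is all ones by default
norm_random = torch.nn.RMSNorm(head_dim)
with torch.no_grad():
    norm_random.weight.copy_(torch.rand(head_dim) * 0.2 + 0.9)

# All-ones weight: the two orderings agree, so they cannot be told apart.
print(torch.allclose(norm_default(rope(x, cos, sin)),
                     rope(norm_default(x), cos, sin), atol=1e-6))  # True
# Non-trivial weight: the orderings differ, so a placement bug is detectable.
print(torch.allclose(norm_random(rope(x, cos, sin)),
                     rope(norm_random(x), cos, sin), atol=1e-6))   # False (in general)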
