Commit 9e05d89

Source transform for HF RoPE in static attention

Differential Revision: D78353775
Pull Request resolved: #12500

1 parent 63bad47 commit 9e05d89

File tree: 2 files changed (+61, -41 lines)

examples/models/llama/static_attention.py (28 additions, 0 deletions)
@@ -730,6 +730,34 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
             self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
             self.k_norm.load_state_dict(other.k_norm_fn.state_dict())

+    def adopt_hf_rope(self):
+        if self.rope.use_hf_rope:
+            return
+
+        if self.use_conv2d:
+            raise RuntimeError(
+                "adopt_hf_rope needs to be called before linear_to_conv2d"
+            )
+
+        # Permute weights of qk projections and norms to match HF RoPE's channel order.
+        def permute(w):
+            shape = w.shape
+            return (
+                w.view((-1, 2) + shape[1:]).transpose(0, 1).reshape(shape).contiguous()
+            )
+
+        for wq in self.wqs:
+            wq.weight.data.copy_(permute(wq.weight.data))
+
+        for wk in self.wks:
+            wk.weight.data.copy_(permute(wk.weight.data))
+
+        if self.use_qk_norm:
+            self.q_norm.weight.data.copy_(permute(self.q_norm.weight.data))
+            self.k_norm.weight.data.copy_(permute(self.k_norm.weight.data))
+
+        self.rope.use_hf_rope = True
+
     def linear_to_conv2d(self):
         def transfer_weight(linear, conv2d):
             conv2d.weight.data.copy_(linear.weight[:, :, None, None])

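As a side note, here is a minimal standalone sketch (not part of the commit) of the reordering that `permute` applies to a projection weight: even-indexed output rows are moved ahead of odd-indexed ones, which in effect converts interleaved rotary pairs into the half-split channel layout used by HF-style RoPE. The 8x1 tensor below is a hypothetical stand-in for a real weight.

import torch

def permute(w):
    # Same reshuffle as in adopt_hf_rope: group rows into consecutive pairs,
    # then place all first-of-pair rows before all second-of-pair rows.
    shape = w.shape
    return w.view((-1, 2) + shape[1:]).transpose(0, 1).reshape(shape).contiguous()

w = torch.arange(8, dtype=torch.float32).unsqueeze(1)  # rows labeled 0..7
print(w.squeeze(1).tolist())           # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(permute(w).squeeze(1).tolist())  # [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]
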
examples/models/llama/tests/test_static_attention.py (33 additions, 41 deletions)

@@ -1,3 +1,4 @@
+import itertools
 import unittest
 from collections import defaultdict

@@ -19,13 +20,18 @@ def setUp(self):
         torch.manual_seed(42)

     def test_without_cache(self):
-        def test(use_qk_norm, use_conv2d):
+        def test(use_qk_norm, qk_norm_before_rope, adopt_hf_rope, use_conv2d):
+            if not use_qk_norm and qk_norm_before_rope:
+                # Redundant test.
+                return
+
             config = ModelArgs(
                 dim=64,
                 n_heads=4,
                 n_kv_heads=2,
                 max_seq_len=8,
                 use_qk_norm=use_qk_norm,
+                qk_norm_before_rope=qk_norm_before_rope,
             )
             layer_id = 0
             rope = Rope(config)
@@ -40,12 +46,19 @@ def test(use_qk_norm, use_conv2d):
                     torch.rand(config.head_dim) * 0.2 + 0.9
                 )
             static_attn.load_weights_from_attention_mha(attn_mha)
+            if adopt_hf_rope:
+                static_attn.adopt_hf_rope()
             if use_conv2d:
                 static_attn.linear_to_conv2d()

             x = torch.rand(1, config.max_seq_len, config.dim)
             freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
             expected, _ = attn_mha(x, freqs_cos, freqs_sin)
+
+            if adopt_hf_rope:
+                config.use_hf_rope = True
+                rope = Rope(config)
+                freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
             mask = torch.triu(
                 torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
                 diagonal=1,
@@ -56,45 +69,16 @@ def test(use_qk_norm, use_conv2d):
                 freqs_sin,
                 mask=mask,
             )
-            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
-
-        test(True, True)
-        test(True, False)
-        test(False, True)
-        test(False, False)
-
-    def test_hf_rope_without_cache(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=8,
-            use_qk_norm=True,
-            use_hf_rope=True,
-        )
-        layer_id = 0
-        rope = Rope(config)
-        attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        with torch.no_grad():
-            attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
-            attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+            self.assertTrue(
+                torch.isclose(y, expected, rtol=1e-3).all(),
+                f"Failed for use_qk_norm={use_qk_norm}, "
+                f"qk_norm_before_rope={qk_norm_before_rope}, "
+                f"adopt_hf_rope={adopt_hf_rope}, "
+                f"use_conv2d={use_conv2d}",
+            )

-        x = torch.rand(1, config.max_seq_len, config.dim)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
-        expected, _ = attn_mha(x, freqs_cos, freqs_sin)
-        mask = torch.triu(
-            torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
-            diagonal=1,
-        )
-        y, _ = static_attn(
-            x,
-            freqs_cos.unsqueeze(0),
-            freqs_sin.unsqueeze(0),
-            mask=mask,
-        )
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+        for args in itertools.product([False, True], repeat=4):
+            test(*args)

     def test_with_cache(self):
         config = ModelArgs(
@@ -108,6 +92,7 @@ def test_with_cache(self):
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
         static_attn = StaticAttention(config, layer_id, rope).eval()
         static_attn.load_weights_from_attention_mha(attn_mha)
+        static_attn.adopt_hf_rope()

         x = torch.rand(1, config.max_seq_len, config.dim)
         freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
@@ -117,6 +102,10 @@ def test_with_cache(self):
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len

+        config.use_hf_rope = True
+        hf_rope = Rope(config)
+        hf_freqs_cos, hf_freqs_sin = hf_rope.get_freqs(None, config.max_seq_len)
+
         def test_with_style(style):
             mask = StaticAttentionMask(chunk_len, cache_len, style=style)
             mask.tensor[:, :, cache_len:] = torch.triu(
@@ -139,8 +128,8 @@ def test_with_style(style):
             for i in range(n_chunks):
                 y_i, attn_update = static_attn(
                     x[:, i * chunk_len : (i + 1) * chunk_len, :],
-                    freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                    freqs_sin[i * chunk_len : (i + 1) * chunk_len],
+                    hf_freqs_cos[i * chunk_len : (i + 1) * chunk_len],
+                    hf_freqs_sin[i * chunk_len : (i + 1) * chunk_len],
                     mask=mask.tensor,
                     in_cache_state=(k_caches, v_caches),
                     out_cache_state=({}, {}),
@@ -175,6 +164,7 @@ def _get_test_transformers(self, config):
             mha_transformer.layers, static_transformer.layers
         ):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
+            static_layer.attention.adopt_hf_rope()

         return mha_transformer, static_transformer

@@ -196,6 +186,7 @@ def test_within_transformer(self):
         cache_len = config.max_seq_len - chunk_len

         def test_with_style(style):
+            config.use_hf_rope = True
             mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
@@ -222,6 +213,7 @@ def test_lookahead_decode(self):
         )
         _, static_transformer = self._get_test_transformers(config)

+        config.use_hf_rope = True
         input_len = 32
         cache_len = config.max_seq_len - input_len
         prefill_input = torch.randint(config.vocab_size, (input_len,))
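
Taken together, the tests suggest the following call order when adopting the HF RoPE layout. This is a hedged usage sketch that mirrors the test code above, not additional API from the commit; the surrounding objects (`config`, `layer_id`, `rope`, `attn_mha`) are the same ones the tests construct.

static_attn = StaticAttention(config, layer_id, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)
static_attn.adopt_hf_rope()      # must run before any conv2d conversion
static_attn.linear_to_conv2d()   # optional, afterwards

# After the transform, rotary frequencies should come from an HF-style Rope.
config.use_hf_rope = True
hf_rope = Rope(config)
freqs_cos, freqs_sin = hf_rope.get_freqs(None, config.max_seq_len)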
