@@ -1,3 +1,4 @@
+import itertools
 import unittest
 from collections import defaultdict
 
@@ -19,13 +20,18 @@ def setUp(self):
         torch.manual_seed(42)
 
     def test_without_cache(self):
-        def test(use_qk_norm, use_conv2d):
+        def test(use_qk_norm, qk_norm_before_rope, adopt_hf_rope, use_conv2d):
+            if not use_qk_norm and qk_norm_before_rope:
+                # Redundant test.
+                return
+
             config = ModelArgs(
                 dim=64,
                 n_heads=4,
                 n_kv_heads=2,
                 max_seq_len=8,
                 use_qk_norm=use_qk_norm,
+                qk_norm_before_rope=qk_norm_before_rope,
             )
             layer_id = 0
             rope = Rope(config)
@@ -40,12 +46,19 @@ def test(use_qk_norm, use_conv2d):
                         torch.rand(config.head_dim) * 0.2 + 0.9
                     )
             static_attn.load_weights_from_attention_mha(attn_mha)
+            if adopt_hf_rope:
+                static_attn.adopt_hf_rope()
             if use_conv2d:
                 static_attn.linear_to_conv2d()
 
             x = torch.rand(1, config.max_seq_len, config.dim)
             freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
             expected, _ = attn_mha(x, freqs_cos, freqs_sin)
+
+            if adopt_hf_rope:
+                config.use_hf_rope = True
+                rope = Rope(config)
+                freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
             mask = torch.triu(
                 torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
                 diagonal=1,
@@ -56,45 +69,16 @@ def test(use_qk_norm, use_conv2d):
                 freqs_sin,
                 mask=mask,
             )
-            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
-
-        test(True, True)
-        test(True, False)
-        test(False, True)
-        test(False, False)
-
-    def test_hf_rope_without_cache(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=8,
-            use_qk_norm=True,
-            use_hf_rope=True,
-        )
-        layer_id = 0
-        rope = Rope(config)
-        attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        with torch.no_grad():
-            attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
-            attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+            self.assertTrue(
+                torch.isclose(y, expected, rtol=1e-3).all(),
+                f"Failed for use_qk_norm={use_qk_norm}, "
+                f"qk_norm_before_rope={qk_norm_before_rope}, "
+                f"adopt_hf_rope={adopt_hf_rope}, "
+                f"use_conv2d={use_conv2d}",
+            )
 
-        x = torch.rand(1, config.max_seq_len, config.dim)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
-        expected, _ = attn_mha(x, freqs_cos, freqs_sin)
-        mask = torch.triu(
-            torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
-            diagonal=1,
-        )
-        y, _ = static_attn(
-            x,
-            freqs_cos.unsqueeze(0),
-            freqs_sin.unsqueeze(0),
-            mask=mask,
-        )
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+        for args in itertools.product([False, True], repeat=4):
+            test(*args)
 
     def test_with_cache(self):
         config = ModelArgs(
@@ -108,6 +92,7 @@ def test_with_cache(self):
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
         static_attn = StaticAttention(config, layer_id, rope).eval()
         static_attn.load_weights_from_attention_mha(attn_mha)
+        static_attn.adopt_hf_rope()
 
         x = torch.rand(1, config.max_seq_len, config.dim)
         freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
@@ -117,6 +102,10 @@ def test_with_cache(self):
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
 
+        config.use_hf_rope = True
+        hf_rope = Rope(config)
+        hf_freqs_cos, hf_freqs_sin = hf_rope.get_freqs(None, config.max_seq_len)
+
         def test_with_style(style):
             mask = StaticAttentionMask(chunk_len, cache_len, style=style)
             mask.tensor[:, :, cache_len:] = torch.triu(
@@ -139,8 +128,8 @@ def test_with_style(style):
             for i in range(n_chunks):
                 y_i, attn_update = static_attn(
                     x[:, i * chunk_len : (i + 1) * chunk_len, :],
-                    freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                    freqs_sin[i * chunk_len : (i + 1) * chunk_len],
+                    hf_freqs_cos[i * chunk_len : (i + 1) * chunk_len],
+                    hf_freqs_sin[i * chunk_len : (i + 1) * chunk_len],
                     mask=mask.tensor,
                     in_cache_state=(k_caches, v_caches),
                     out_cache_state=({}, {}),
@@ -175,6 +164,7 @@ def _get_test_transformers(self, config):
             mha_transformer.layers, static_transformer.layers
         ):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
+            static_layer.attention.adopt_hf_rope()
 
         return mha_transformer, static_transformer
 
@@ -196,6 +186,7 @@ def test_within_transformer(self):
         cache_len = config.max_seq_len - chunk_len
 
         def test_with_style(style):
+            config.use_hf_rope = True
             mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
@@ -222,6 +213,7 @@ def test_lookahead_decode(self):
         )
         _, static_transformer = self._get_test_transformers(config)
 
+        config.use_hf_rope = True
         input_len = 32
         cache_len = config.max_seq_len - input_len
         prefill_input = torch.randint(config.vocab_size, (input_len,))
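
Note on the pattern in this diff: every call to `adopt_hf_rope()` is paired with regenerating `freqs_cos`/`freqs_sin` under `config.use_hf_rope = True`, i.e. once the weights are adapted the module is compared against the reference using HuggingFace-style (half-split) rotary embeddings rather than interleaved ones. As a rough illustration of the kind of weight permutation such an adaptation typically involves, here is a minimal sketch; the function name is hypothetical and this is not ExecuTorch's actual `adopt_hf_rope` implementation, just the standard interleaved-to-half-split row reordering used when converting Llama-style checkpoints to the HF layout.

```python
import torch


def permute_for_hf_rope(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Illustrative sketch (hypothetical helper, not from ExecuTorch):
    # reorder the rows of a q/k projection from the interleaved RoPE layout
    # (pairs (0, 1), (2, 3), ...) to the half-split layout HF RoPE expects
    # (all even row indices first, then all odd ones) within each head.
    out_dim, in_dim = w.shape
    head_dim = out_dim // n_heads
    return (
        w.view(n_heads, head_dim // 2, 2, in_dim)
        .transpose(1, 2)
        .reshape(out_dim, in_dim)
    )


# Example shapes matching the test config: dim=64, n_heads=4 -> head_dim=16.
wq = torch.randn(4 * 16, 64)
wq_hf = permute_for_hf_rope(wq, n_heads=4)
```

After a permutation of this kind, applying HF's half-rotation to the permuted projection matches applying the interleaved rotation to the original one, which is why the tests swap in `hf_freqs_cos`/`hf_freqs_sin` after adapting the weights.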