 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, MultiQueryAttentionV2
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2
 
 import importlib
 import os
@@ -120,6 +120,7 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None
 
+
 @pytest.mark.parametrize("dim", [128])
 @pytest.mark.parametrize("dim_out", [128, 256])
 @pytest.mark.parametrize("use_m", [True, False])
@@ -134,4 +135,26 @@ def test_mqa_v2(dim, dim_out, use_m):
 
     y = mqa(x, m=m)
 
-    assert (y.shape) == (1, dim_out, 32, 48)
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("expand_first", [True, False])
+@pytest.mark.parametrize("head_first", [True, False])
+@pytest.mark.parametrize("attn_mask", [True, False])
+def test_attn2d(bias, expand_first, head_first, attn_mask):
+    x = torch.randn(1, 128, 32, 48)
+    attn = Attention2d(
+        128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
+    )
+
+    if attn_mask:
+        mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
+    else:
+        mask = None
+
+    o1 = attn(x, mask)
+    attn.fused_attn = False
+    o2 = attn(x, mask)
+
+    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
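For readers unfamiliar with the layer under test, here is a minimal standalone sketch of the call pattern the new test exercises. It relies only on the constructor arguments, the optional `attn_mask` argument, and the `fused_attn` toggle that appear in the diff above; the tensor shapes and the all-zero mask are illustrative, not part of any documented API contract.

```python
import torch
from timm.layers import Attention2d

# NCHW feature map, same shape as in test_attn2d above
x = torch.randn(1, 128, 32, 48)

# Constructor arguments mirror those used in the test;
# any other keyword defaults are assumptions, not verified here.
attn = Attention2d(128, 128, num_heads=4, bias=True, expand_first=False, head_first=False)

# Optional additive float mask over flattened spatial positions (H*W x H*W);
# zeros make it a no-op, matching the degenerate mask built in the test.
mask = torch.zeros(32 * 48, 32 * 48)

out_fused = attn(x, mask)   # default (fused) attention path
attn.fused_attn = False     # force the explicit matmul/softmax path
out_unfused = attn(x, mask)

# The test asserts the two paths agree to within a small tolerance.
print(torch.allclose(out_fused, out_unfused, atol=1e-5))
```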