
Commit 2d5277e

Merge branch 'main' into fix-mqa-v2
2 parents 6171e75 + 2d734d9 · commit 2d5277e

File tree

2 files changed: +29 -6 lines changed

tests/test_layers.py

Lines changed: 25 additions & 2 deletions

@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, MultiQueryAttentionV2
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2
 
 import importlib
 import os
@@ -120,6 +120,7 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None
 
+
 @pytest.mark.parametrize("dim", [128])
 @pytest.mark.parametrize("dim_out", [128, 256])
 @pytest.mark.parametrize("use_m", [True, False])
@@ -134,4 +135,26 @@ def test_mqa_v2(dim, dim_out, use_m):
 
     y = mqa(x, m=m)
 
-    assert (y.shape) == (1, dim_out, 32, 48)
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("expand_first", [True, False])
+@pytest.mark.parametrize("head_first", [True, False])
+@pytest.mark.parametrize("attn_mask", [True, False])
+def test_attn2d(bias, expand_first, head_first, attn_mask):
+    x = torch.randn(1, 128, 32, 48)
+    attn = Attention2d(
+        128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
+    )
+
+    if attn_mask:
+        mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
+    else:
+        mask = None
+
+    o1 = attn(x, mask)
+    attn.fused_attn = False
+    o2 = attn(x, mask)
+
+    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
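
The new test_attn2d case builds an Attention2d block over a 1x128x32x48 feature map and asserts that the fused and manual attention paths agree to within 1e-5, with and without an additive float mask. A minimal standalone sketch of the same usage follows; the mask values chosen here are illustrative assumptions, not part of the commit:

import torch
from timm.layers import Attention2d

# Same layout as the test: NCHW feature map, attention over the flattened H*W positions.
x = torch.randn(1, 128, 32, 48)
attn = Attention2d(128, 128, num_heads=4)

# The mask is additive and float-valued, shaped (H*W, H*W); large negative
# entries suppress attention. The -100.0 value is an illustrative choice.
n = 32 * 48
mask = torch.zeros(n, n)
mask[:, n // 2:] = -100.0

y_fused = attn(x, mask)     # fused scaled_dot_product_attention path (when available)
attn.fused_attn = False
y_manual = attn(x, mask)    # manual matmul/softmax path changed in this commit

print(torch.abs(y_fused - y_manual).max())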

timm/layers/attention2d.py

Lines changed: 4 additions & 4 deletions

@@ -312,7 +312,6 @@ def __init__(
         self.num_heads = num_heads
         self.dim_head = dim_attn // num_heads
         self.head_first = head_first
-        self.scale = num_heads ** -0.5
         self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
@@ -337,14 +336,15 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
                 dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
-            q = q * self.scale
-            attn = q.transpose(-2, -1) @ k
+            q = q.transpose(-1, -2)
+            v = v.transpose(-1, -2)
+            attn = q @ k * q.size(-1) ** -0.5
             if attn_mask is not None:
                 # NOTE: assumes mask is float and in correct shape
                 attn = attn + attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
-            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
+            x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)
 
         x = self.proj(x)
         x = self.proj_drop(x)
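
In the manual (non-fused) branch, the removed self.scale = num_heads ** -0.5 scaled the attention logits by the head count; after the change they are scaled by q.size(-1) ** -0.5, which is dim_head ** -0.5 once q has been transposed to (B, num_heads, N, dim_head), matching the default scaling F.scaled_dot_product_attention applies on the fused path. A rough standalone sketch of that equivalence on plain tensors (shapes and names here are illustrative assumptions, not timm code):

import torch
import torch.nn.functional as F

# Illustrative shapes: B batches, 4 heads, head dim 32, N = H * W flattened positions.
B, num_heads, dim_head, N = 2, 4, 32, 64
q = torch.randn(B, num_heads, dim_head, N)
k = torch.randn(B, num_heads, dim_head, N)
v = torch.randn(B, num_heads, dim_head, N)

# Manual path as written after the fix: transpose q/v to (B, heads, N, dim_head),
# scale logits by dim_head ** -0.5 (q.size(-1) after the transpose), softmax over keys.
qt, vt = q.transpose(-1, -2), v.transpose(-1, -2)
attn = (qt @ k) * qt.size(-1) ** -0.5
out_manual = attn.softmax(dim=-1) @ vt

# Fused path: SDPA expects (B, heads, N, dim_head) and uses the same 1/sqrt(dim_head) scale.
out_fused = F.scaled_dot_product_attention(qt, k.transpose(-1, -2), vt)

print(torch.allclose(out_manual, out_fused, atol=1e-5))  # expected: True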
