
Commit f0fb471

Remove separate ConvNormActAa class, merge with ConvNormAct
1 parent 5efa15b commit f0fb471

File tree

5 files changed: +28 -103 lines changed


timm/layers/conv_bn_act.py

Lines changed: 9 additions & 74 deletions
@@ -26,7 +26,8 @@ def __init__(
             apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
-            act_layer: LayerType = nn.ReLU,
+            act_layer: Optional[LayerType] = nn.ReLU,
+            aa_layer: Optional[LayerType] = None,
             drop_layer: Optional[Type[nn.Module]] = None,
             conv_kwargs: Optional[Dict[str, Any]] = None,
             norm_kwargs: Optional[Dict[str, Any]] = None,
@@ -36,12 +37,13 @@ def __init__(
         conv_kwargs = conv_kwargs or {}
         norm_kwargs = norm_kwargs or {}
         act_kwargs = act_kwargs or {}
+        use_aa = aa_layer is not None and stride > 1

         self.conv = create_conv2d(
             in_channels,
             out_channels,
             kernel_size,
-            stride=stride,
+            stride=1 if use_aa else stride,
             padding=padding,
             dilation=dilation,
             groups=groups,
@@ -67,6 +69,8 @@ def __init__(
                 norm_kwargs['drop_layer'] = drop_layer
                 self.bn.add_module('drop', drop_layer())

+        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None)
+
     @property
     def in_channels(self):
         return self.conv.in_channels
@@ -78,79 +82,10 @@ def out_channels(self):
     def forward(self, x):
         x = self.conv(x)
         x = self.bn(x)
+        if self.aa is not None:
+            x = self.aa(x)
         return x


 ConvBnAct = ConvNormAct
-
-
-class ConvNormActAa(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            kernel_size: int = 1,
-            stride: int = 1,
-            padding: PadType = '',
-            dilation: int = 1,
-            groups: int = 1,
-            bias: bool = False,
-            apply_norm: bool = True,
-            apply_act: bool = True,
-            norm_layer: LayerType = nn.BatchNorm2d,
-            act_layer: LayerType = nn.ReLU,
-            aa_layer: Optional[LayerType] = None,
-            drop_layer: Optional[Type[nn.Module]] = None,
-            conv_kwargs: Optional[Dict[str, Any]] = None,
-            norm_kwargs: Optional[Dict[str, Any]] = None,
-            act_kwargs: Optional[Dict[str, Any]] = None,
-    ):
-        super(ConvNormActAa, self).__init__()
-        use_aa = aa_layer is not None and stride == 2
-        conv_kwargs = conv_kwargs or {}
-        norm_kwargs = norm_kwargs or {}
-        act_kwargs = act_kwargs or {}
-
-        self.conv = create_conv2d(
-            in_channels, out_channels, kernel_size,
-            stride=1 if use_aa else stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-            bias=bias,
-            **conv_kwargs,
-        )
-
-        if apply_norm:
-            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-            if drop_layer:
-                norm_kwargs['drop_layer'] = drop_layer
-            self.bn = norm_act_layer(
-                out_channels,
-                apply_act=apply_act,
-                act_kwargs=act_kwargs,
-                **norm_kwargs,
-            )
-        else:
-            self.bn = nn.Sequential()
-            if drop_layer:
-                norm_kwargs['drop_layer'] = drop_layer
-                self.bn.add_module('drop', drop_layer())
-
-        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
-
-    @property
-    def in_channels(self):
-        return self.conv.in_channels
-
-    @property
-    def out_channels(self):
-        return self.conv.out_channels
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        x = self.aa(x)
-        return x
+ConvNormActAa = ConvNormAct  # backwards compat, when they were separate
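
With the merge, anti-aliased downsampling is requested directly on ConvNormAct via aa_layer, and the old ConvNormActAa name resolves to the same class. A minimal usage sketch (not part of the commit; shapes and the choice of BlurPool2d as the aa_layer are illustrative only):

    import torch
    from timm.layers import BlurPool2d
    from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa

    # stride > 1 with an aa_layer: the conv runs at stride 1 and the blur-pool applies the stride
    block = ConvNormAct(32, 64, kernel_size=3, stride=2, aa_layer=BlurPool2d)
    x = torch.randn(1, 32, 56, 56)
    print(block(x).shape)  # expected: torch.Size([1, 64, 28, 28])

    # without an aa_layer, self.aa is None and forward skips it; the conv itself strides
    plain = ConvNormAct(32, 64, kernel_size=3, stride=2)
    print(plain(x).shape)  # expected: torch.Size([1, 64, 28, 28])

    # the old name is now just an alias, so existing call sites keep working
    assert ConvNormActAa is ConvNormAct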

timm/layers/selective_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 import torch
 from torch import nn as nn

-from .conv_bn_act import ConvNormActAa
+from .conv_bn_act import ConvNormAct
 from .helpers import make_divisible
 from .trace_utils import _assert

@@ -100,7 +100,7 @@ def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, d
             stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
             aa_layer=aa_layer, drop_layer=drop_layer)
         self.paths = nn.ModuleList([
-            ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
+            ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
             for k, d in zip(kernel_size, dilation)])

         attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)

timm/models/_efficientnet_blocks.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 from torch.nn import functional as F

 from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\
-    ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d
+    ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d

 __all__ = [
     'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
@@ -345,7 +345,7 @@ def __init__(
         if dw_kernel_size_start:
             dw_start_stride = stride if not dw_kernel_size_mid else 1
             dw_start_groups = num_groups(group_size, in_chs)
-            self.dw_start = ConvNormActAa(
+            self.dw_start = ConvNormAct(
                 in_chs, in_chs, dw_kernel_size_start,
                 stride=dw_start_stride,
                 dilation=dilation,  # FIXME
@@ -373,7 +373,7 @@ def __init__(
         # Middle depth-wise convolution
         if dw_kernel_size_mid:
             groups = num_groups(group_size, mid_chs)
-            self.dw_mid = ConvNormActAa(
+            self.dw_mid = ConvNormAct(
                 mid_chs, mid_chs, dw_kernel_size_mid,
                 stride=stride,
                 dilation=dilation,  # FIXME

timm/models/cspnet.py

Lines changed: 7 additions & 7 deletions
@@ -20,7 +20,7 @@
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, MATCH_PREV_GROUP
 from ._registry import register_model, generate_default_cfgs
@@ -296,10 +296,10 @@ def __init__(
             if avg_down:
                 self.conv_down = nn.Sequential(
                     nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                    ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                 )
             else:
-                self.conv_down = ConvNormActAa(
+                self.conv_down = ConvNormAct(
                     in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                     aa_layer=aa_layer, **conv_kwargs)
             prev_chs = down_chs
@@ -375,10 +375,10 @@ def __init__(
             if avg_down:
                 self.conv_down = nn.Sequential(
                     nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                    ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                 )
             else:
-                self.conv_down = ConvNormActAa(
+                self.conv_down = ConvNormAct(
                     in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                     aa_layer=aa_layer, **conv_kwargs)
             prev_chs = down_chs
@@ -442,10 +442,10 @@ def __init__(
             if avg_down:
                 self.conv_down = nn.Sequential(
                     nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                    ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                 )
             else:
-                self.conv_down = ConvNormActAa(
+                self.conv_down = ConvNormAct(
                     in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                     aa_layer=aa_layer, **conv_kwargs)


timm/models/tresnet.py

Lines changed: 7 additions & 17 deletions
@@ -12,8 +12,7 @@
 import torch
 import torch.nn as nn

-from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule,\
-    ConvNormActAa, ConvNormAct, DropPath
+from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -39,13 +38,8 @@ def __init__(
         self.stride = stride
         act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)

-        if stride == 1:
-            self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
-        else:
-            self.conv1 = ConvNormActAa(
-                inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
-
-        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
+        self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False)
         self.act = nn.ReLU(inplace=True)

         rd_chs = max(planes * self.expansion // 4, 64)
@@ -87,18 +81,14 @@ def __init__(

         self.conv1 = ConvNormAct(
             inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
-        if stride == 1:
-            self.conv2 = ConvNormAct(
-                planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
-        else:
-            self.conv2 = ConvNormActAa(
-                planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(
+            planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)

         reduction_chs = max(planes * self.expansion // 8, 64)
         self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None

         self.conv3 = ConvNormAct(
-            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)
+            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False)

         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
         self.act = nn.ReLU(inplace=True)
@@ -204,7 +194,7 @@ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=Non
                 # avg pooling before 1x1 conv
                 layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
             layers += [ConvNormAct(
-                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
+                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)]
             downsample = nn.Sequential(*layers)

         layers = []
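
The tresnet blocks no longer branch on stride: in the merged ConvNormAct the aa path is only enabled when aa_layer is set and stride > 1, so passing both unconditionally is equivalent to the old if/else. A rough sketch of the collapsed pattern (not part of the commit; channel sizes and the BlurPool2d choice are illustrative assumptions):

    from functools import partial

    import torch
    import torch.nn as nn
    from timm.layers import BlurPool2d
    from timm.layers.conv_bn_act import ConvNormAct

    act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)

    # stride == 1: use_aa is False, aa_layer is ignored, matching the old plain ConvNormAct branch
    conv_s1 = ConvNormAct(64, 64, kernel_size=3, stride=1, act_layer=act_layer, aa_layer=BlurPool2d)

    # stride == 2: conv at stride 1 followed by blur-pool, matching the old ConvNormActAa branch
    conv_s2 = ConvNormAct(64, 128, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=BlurPool2d)

    x = torch.randn(2, 64, 32, 32)
    print(conv_s1(x).shape)  # expected: torch.Size([2, 64, 32, 32])
    print(conv_s2(x).shape)  # expected: torch.Size([2, 128, 16, 16])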
