Commit d7d3538

Add so400m model size for test, few tweaks.

Parent: 7bfe606
4 files changed: 21 additions & 6 deletions

timm/data/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -8,7 +8,7 @@
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
-from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_dataset import VariableSeqMapWrapper, calculate_naflex_batch_size
 from .naflex_loader import create_naflex_loader
 from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
 from .naflex_transforms import (
@@ -17,6 +17,8 @@
     RandomCropToSequence,
     RandomResizedCropToSequence,
     ResizeKeepRatioToSequence,
+    Patchify,
+    patchify_image,
 )
 from .readers import create_reader
 from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
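The net effect of these two hunks is that the renamed helpers become importable from the package root. A quick check, assuming a timm build that includes this commit:

import timm.data

# All four names re-exported by timm/data/__init__.py after this change
from timm.data import (
    VariableSeqMapWrapper,
    calculate_naflex_batch_size,
    Patchify,
    patchify_image,
)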

timm/data/naflex_dataset.py

Lines changed: 3 additions & 3 deletions

@@ -24,10 +24,10 @@
 from PIL import Image


-from .naflex_transforms import Patchify, patchify
+from .naflex_transforms import Patchify, patchify_image


-def calculate_batch_size(
+def calculate_naflex_batch_size(
     tokens_per_batch: int,
     seq_len: int,
     max_size: Optional[int] = None,
@@ -240,7 +240,7 @@ def _create_canonical_schedule(self):
        seq_len = self.seq_lens[seq_idx]

        # Calculate batch size
-       batch_size = calculate_naflex_batch_size(
+       batch_size = calculate_naflex_batch_size(
            tokens_per_batch=self.max_tokens_per_batch,
            seq_len=seq_len,
            # max_size should be remaining_samples to avoid overshooting
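For context, the renamed helper packs as many fixed-length sequences as fit in a per-batch token budget. A minimal sketch of the presumed logic, written against only the parameters visible in the hunk above; this is an illustration, not the actual timm implementation:

from typing import Optional

def calculate_naflex_batch_size_sketch(
        tokens_per_batch: int,
        seq_len: int,
        max_size: Optional[int] = None,
) -> int:
    # How many sequences of seq_len tokens fit in the token budget
    batch_size = max(1, tokens_per_batch // seq_len)
    # Cap at max_size (e.g. remaining samples) to avoid overshooting,
    # per the comment at the call site above
    if max_size is not None:
        batch_size = min(batch_size, max_size)
    return batch_size

# e.g. a 16384-token budget with 256-token sequences -> 64 samples per batch
assert calculate_naflex_batch_size_sketch(16384, 256) == 64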

timm/data/naflex_transforms.py

Lines changed: 2 additions & 2 deletions

@@ -738,7 +738,7 @@ def __repr__(self) -> str:
         return format_string


-def patchify(
+def patchify_image(
     img: torch.Tensor,
     patch_size: Tuple[int, int],
     pad: bool = True,
@@ -794,7 +794,7 @@ def forward(self, img):
        # Convert PIL Image to tensor [C, H, W]
        img = transforms.functional.to_tensor(img)

-       patches, coord, valid = patchify(img, self.patch_size)
+       patches, coord, valid = patchify_image(img, self.patch_size)

        return {
            'patches': patches,
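From the Patchify.forward call site, patchify_image takes a [C, H, W] tensor plus a (patch_h, patch_w) tuple and returns patch tokens, per-patch coordinates, and a validity mask. A hedged usage sketch; the output shapes are inferred from the call site, not verified against the function body:

import torch
from timm.data import patchify_image

img = torch.rand(3, 224, 160)  # [C, H, W]; NaFlex allows non-square inputs
patches, coords, valid = patchify_image(img, (16, 16))
# Presumably: patches holds (224/16) * (160/16) = 140 flattened 16x16 patches,
# coords their grid positions, and valid a mask marking real patches vs padding
# (pad=True is the default per the signature above; no padding needed here
# since both dims divide evenly by 16).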

timm/models/vision_transformer_flex.py

Lines changed: 13 additions & 0 deletions

@@ -1027,6 +1027,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 default_cfgs = generate_default_cfgs({
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
+    'vit_naflex_so400m_patch16_map': _cfg(),

     # sbb model testing
     'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
@@ -1110,3 +1111,15 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
         global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
     model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+@register_model
+def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs):
+    """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, init_values=1e-5,
+        global_pool='map', class_token=False, reg_tokens=1, act_layer='gelu_tanh')
+    model = _create_vision_transformer_flex(
+        'vit_naflex_so400m_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
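The new entry point registers an SO400M-shaped NaFlex ViT (the SigLIP "shape-optimized" ~400M-parameter configuration: width 1152, depth 27, MLP ratio 3.7362) with MAP attention pooling and one register token. A quick smoke test, assuming a timm build with this commit; the default_cfg is empty, so there are no pretrained weights to load yet:

import timm

model = timm.create_model('vit_naflex_so400m_patch16_map', pretrained=False)
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.0f}M parameters')  # roughly 400M, hence the SO400M name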
