
Commit 7624389

Mixup cleanup, add prob support and train script integration. Add working loader-based, patch-compatible RandomErasing for NaFlex mode.
1 parent 8fcbceb commit 7624389

7 files changed: +591, -74 lines
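
Taken together, the changes move mixing ahead of patchification (NaFlex images vary in size, so Mixup/CutMix must operate on whole images) and defer random erasing until patches are batched on device. A rough sketch of the resulting training data flow; step names refer to the code in the diffs below:

# Conceptual NaFlex training pipeline after this commit (sketch, not code from the diff):
# 1. VariableSeqMapWrapper.__iter__     -> per-sample transforms, variable (C, H, W) images
# 2. NaFlexMixup (optional, prob-gated) -> mixed images + per-sample soft targets
# 3. Patchify                           -> patch dict per image (patches, coords, valid mask)
# 4. NaFlexCollator                     -> padded batch tensors at the scheduled seq_len
# 5. NaFlexPrefetchLoader               -> normalize patches, then PatchRandomErasing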

timm/data/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 from .mixup import Mixup, FastCollateMixup
 from .naflex_dataset import VariableSeqMapWrapper
 from .naflex_loader import create_naflex_loader
+from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
 from .naflex_transforms import (
     ResizeToSequence,
     CenterCropToSequence,

timm/data/naflex_dataset.py

Lines changed: 17 additions & 6 deletions
@@ -83,8 +83,11 @@ def __call__(self, batch):
         batch_size = len(batch)

         # Extract targets
-        # FIXME need to handle dense (float) targets or always done downstream of this?
-        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
+        targets = [item[1] for item in batch]
+        if isinstance(targets[0], torch.Tensor):
+            targets = torch.stack(targets)
+        else:
+            targets = torch.tensor(targets, dtype=torch.int64)

         # Get patch dictionaries
         patch_dicts = [item[0] for item in batch]
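
The collator now accepts either integer class labels or dense per-sample target tensors (e.g. soft labels produced by mixup upstream). A standalone sketch of the two branches, mirroring the collator logic above:

import torch

def stack_targets(targets):
    # Dense target tensors are stacked; plain ints become one int64 tensor.
    if isinstance(targets[0], torch.Tensor):
        return torch.stack(targets)
    return torch.tensor(targets, dtype=torch.int64)

stack_targets([3, 7]).shape                            # torch.Size([2])
stack_targets([torch.rand(10), torch.rand(10)]).shape  # torch.Size([2, 10])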
@@ -139,6 +142,7 @@ def __init__(
             seq_lens: List[int] = (128, 256, 576, 784, 1024),
             max_tokens_per_batch: int = 4096 * 4,  # Example: 16k tokens
             transform_factory: Optional[Callable] = None,
+            mixup_fn: Optional[Callable] = None,
             seed: int = 42,
             shuffle: bool = True,
             distributed: bool = False,
@@ -172,6 +176,7 @@ def __init__(
             else:
                 self.transforms[seq_len] = None  # No transform
             self.collate_fns[seq_len] = NaFlexCollator(seq_len)
+        self.mixup_fn = mixup_fn
         self.patchifier = Patchify(self.patch_size)

         # --- Canonical Schedule Calculation (Done Once) ---
@@ -393,6 +398,8 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
             transform = self.transforms.get(seq_len)

             batch_samples = []
+            batch_imgs = []
+            batch_targets = []
             for idx in indices:
                 try:
                     # Get original image and label from map-style dataset
@@ -405,9 +412,8 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
                         warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
                         continue

-                    # Apply patching
-                    patch_data = self.patchifier(processed_img)
-                    batch_samples.append((patch_data, label))
+                    batch_imgs.append(processed_img)
+                    batch_targets.append(label)

                 except IndexError:
                     warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
@@ -417,8 +423,13 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
                     warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
                     continue  # Skip problematic sample

-            # Collate the processed samples into a batch
+            if self.mixup_fn is not None:
+                batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)
+
+            batch_imgs = [self.patchifier(img) for img in batch_imgs]
+            batch_samples = list(zip(batch_imgs, batch_targets))
             if batch_samples:  # Only yield if we successfully processed samples
+                # Collate the processed samples into a batch
                 yield self.collate_fns[seq_len](batch_samples)

            # If batch_samples is empty after processing 'indices', an empty batch is skipped.
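
With these changes, patchification happens after optional mixing, so Mixup/CutMix sees whole variable-size images rather than patch sequences. The per-batch flow inside __iter__ now reduces to roughly the following (pseudocode sketch; load_and_transform stands in for the try/except loop above):

# imgs, targets = load_and_transform(indices)       # variable-size (C, H, W) tensors
# if self.mixup_fn is not None:
#     imgs, targets = self.mixup_fn(imgs, targets)  # mix pixels, build soft targets
# patch_dicts = [self.patchifier(img) for img in imgs]
# yield self.collate_fns[seq_len](list(zip(patch_dicts, targets)))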

timm/data/naflex_loader.py

Lines changed: 49 additions & 10 deletions
@@ -3,11 +3,13 @@
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

+
 import torch

 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .loader import _worker_init
+from .loader import _worker_init, adapt_to_chs
 from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
+from .naflex_random_erasing import PatchRandomErasing
 from .transforms_factory import create_transform


@@ -16,19 +18,41 @@ class NaFlexPrefetchLoader:

     def __init__(
             self,
-            loader,
-            mean=(0.485, 0.456, 0.406),
-            std=(0.229, 0.224, 0.225),
-            img_dtype=torch.float32,
-            device=torch.device('cuda')
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
     ):
         self.loader = loader
         self.device = device
         self.img_dtype = img_dtype or torch.float32

         # Create mean/std tensors for normalization (will be applied to patches)
-        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
-        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)
+        mean = adapt_to_chs(mean, channels)
+        std = adapt_to_chs(std, channels)
+        normalization_shape = (1, 1, channels)
+        self.channels = channels
+        self.mean = torch.tensor(
+            [x * 255 for x in mean], device=device, dtype=self.img_dtype).view(normalization_shape)
+        self.std = torch.tensor(
+            [x * 255 for x in std], device=device, dtype=self.img_dtype).view(normalization_shape)
+
+        if re_prob > 0.:
+            self.random_erasing = PatchRandomErasing(
+                erase_prob=re_prob,
+                mode=re_mode,
+                max_count=re_count,
+                num_splits=re_num_splits,
+                device=device,
+            )
+        else:
+            self.random_erasing = None

         # Check for CUDA/NPU availability
         self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
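
With channels now configurable, adapt_to_chs expands or trims the mean/std tuples to the channel count, and the (1, 1, channels) view lets them broadcast against patches flattened to (B, N, P*P, C). A quick shape check (standalone sketch, values illustrative):

import torch

channels = 3
mean = torch.tensor([x * 255 for x in (0.485, 0.456, 0.406)]).view(1, 1, channels)
patches = torch.rand(2, 196, 256, channels) * 255  # (B, N, P*P, C)
normed = patches.sub(mean)  # (1, 1, C) broadcasts over batch, patch, and pixel dims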
@@ -62,9 +86,18 @@ def __iter__(self):

             # Normalize patch values (assuming patches are in format [B, N, P*P*C])
             batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
-            patches = next_input_dict['patches'].view(batch_size, -1, 3)  # to [B*N, P*P, C] for normalization
+
+            # To [B*N, P*P, C] for normalization and erasing
+            patches = next_input_dict['patches'].view(batch_size, num_patches, -1, self.channels)
             patches = patches.sub(self.mean).div(self.std)

+            if self.random_erasing is not None:
+                patches = self.random_erasing(
+                    patches,
+                    patch_coord=next_input_dict['patch_coord'],
+                    patch_valid=next_input_dict.get('patch_valid', None),
+                )
+
             # Reshape back
             next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)

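PatchRandomErasing itself lands in naflex_random_erasing.py, one of the commit's files not shown on this page. Judging from the call site above, it consumes patches in (B, N, P*P, C) layout together with patch coordinates and the validity mask, so erasing can be confined to real (non-padding) image content. A hedged sketch of the assumed interface, inferred only from the constructor and call in this diff:

# Assumed interface (inferred from the call site, not confirmed by this page):
# eraser = PatchRandomErasing(erase_prob=0.25, mode='const', max_count=1, device=device)
# patches = eraser(patches, patch_coord=batch['patch_coord'], patch_valid=batch['patch_valid'])
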
@@ -103,6 +136,7 @@ def create_naflex_loader(
         max_seq_len: int = 576,  # Fixed sequence length for validation
         batch_size: int = 32,  # Used for max_seq_len and max(train_seq_lens)
         is_training: bool = False,
+        mixup_fn: Optional[Callable] = None,

         no_aug: bool = False,
         re_prob: float = 0.,
@@ -141,7 +175,8 @@ def create_naflex_loader(
         persistent_workers: bool = True,
         worker_seeding: str = 'all',
 ):
-    """Create a data loader with dynamic sequence length sampling for training."""
+    """Create a data loader with dynamic sequence length sampling for training.
+    """

     if is_training:
         # For training, use the dynamic sequence length mechanism
@@ -186,6 +221,7 @@ def create_naflex_loader(
             patch_size=patch_size,
             seq_lens=train_seq_lens,
             max_tokens_per_batch=max_tokens_per_batch,
+            mixup_fn=mixup_fn,
             seed=seed,
             distributed=distributed,
             rank=rank,
@@ -219,6 +255,9 @@ def create_naflex_loader(
             std=std,
             img_dtype=img_dtype,
             device=device,
+            re_prob=re_prob,
+            re_mode=re_mode,
+            re_count=re_count,
         )

     else:
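
On the train-script side (the "train script integration" in the commit message), the wiring would look roughly like this; the argument names below are illustrative and not taken from this diff:

# Hypothetical train-script wiring (arg names illustrative):
mixup_fn = None
if args.mixup > 0. or args.cutmix > 0.:
    mixup_fn = NaFlexMixup(
        num_classes=args.num_classes,
        mixup_alpha=args.mixup,
        cutmix_alpha=args.cutmix,
        switch_prob=args.mixup_switch_prob,
        prob=args.mixup_prob,
        label_smoothing=args.smoothing,
    )
loader_train = create_naflex_loader(
    dataset_train,
    is_training=True,
    mixup_fn=mixup_fn,
    re_prob=args.reprob,
    re_mode=args.remode,
    re_count=args.recount,
)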

timm/data/naflex_mixup.py

Lines changed: 50 additions & 32 deletions
@@ -26,27 +26,24 @@ def mix_batch_variable_size(
         cutmix_alpha: float = 1.0,
         switch_prob: float = 0.5,
         local_shuffle: int = 4,
-) -> Tuple[List[torch.Tensor], List[float], Dict[int, int], bool]:
+) -> Tuple[List[torch.Tensor], List[float], Dict[int, int]]:
     """Apply Mixup or CutMix on a batch of variable‑sized images.

     The function first sorts images by aspect ratio and pairs neighbouring
     samples (optionally shuffling within small windows so pairs vary between
     epochs). Only the mutual central‑overlap region of each pair is mixed.

     Args:
-        imgs: List of transformed images shaped (C, H, W). Heights and
-            widths may differ between samples.
-        mixup_alpha: Beta‑distribution *α* for Mixup. Set to 0 to disable Mixup.
-        cutmix_alpha: Beta‑distribution *α* for CutMix. Set to 0 to disable CutMix.
+        imgs: List of transformed images shaped (C, H, W). Heights and widths may differ between samples.
+        mixup_alpha: Beta‑distribution alpha for Mixup. Set to 0 to disable Mixup.
+        cutmix_alpha: Beta‑distribution alpha for CutMix. Set to 0 to disable CutMix.
         switch_prob: Probability of using CutMix when both Mixup and CutMix are enabled.
-        local_shuffle: Size of local windows that are randomly shuffled after aspect sorting.
-            A value of 0 turns shuffling off.
+        local_shuffle: Size of local windows that are randomly shuffled after aspect sorting. Off if <= 1.

     Returns:
         mixed_imgs: List of mixed images.
         lam_list: Per‑sample lambda values representing the degree of mixing.
         pair_to: Mapping i -> j describing which sample was mixed with which (absent for unmatched odd sample).
-        use_cutmix: True if CutMix was used for this call, False if Mixup was used.
     """
     if len(imgs) < 2:
         raise ValueError("Need at least two images to perform Mixup/CutMix.")
@@ -71,7 +68,7 @@ def mix_batch_variable_size(
     order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1])
     if local_shuffle > 1:
         for start in range(0, len(order), local_shuffle):
-            random.shuffle(order[start: start + local_shuffle])
+            random.shuffle(order[start:start + local_shuffle])

     pair_to: Dict[int, int] = {}
     for a, b in zip(order[::2], order[1::2]):
@@ -119,22 +116,41 @@ def mix_batch_variable_size(
             #print(i, 'Doing cutmix', yl_i, xl_i, yl_j, xl_j, ch, cw, lam_raw, corrected_lam)
         else:
             # Mixup: blend the entire overlap region
-            patch_i = xi[:, top_i: top_i + oh, left_i: left_i + ow]
-            patch_j = xj[:, top_j: top_j + oh, left_j: left_j + ow]
+            patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow]
+            patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow]

             blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw)
-            xi[:, top_i: top_i + oh, left_i: left_i + ow] = blended
+            xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended
             mixed_imgs[i] = xi

         corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
         lam_list[i] = corrected_lam
         #print(i, 'Doing mixup', top_i, left_i, top_j, left_j, (oh, ow), (hi, wi), (hj, wj), lam_raw, corrected_lam)

-    return mixed_imgs, lam_list, pair_to, use_cutmix
+    return mixed_imgs, lam_list, pair_to
+
+
+def smoothed_sparse_target(
+        targets: torch.Tensor,
+        *,
+        num_classes: int,
+        smoothing: float = 0.0,
+) -> torch.Tensor:
+    off_val = smoothing / num_classes
+    on_val = 1.0 - smoothing + off_val
+
+    y_onehot = torch.full(
+        (targets.size(0), num_classes),
+        off_val,
+        dtype=torch.float32,
+        device=targets.device
+    )
+    y_onehot.scatter_(1, targets.unsqueeze(1), on_val)
+    return y_onehot


 def pairwise_mixup_target(
-        labels: torch.Tensor,
+        targets: torch.Tensor,
         pair_to: Dict[int, int],
         lam_list: List[float],
         *,
@@ -144,21 +160,16 @@ def pairwise_mixup_target(
     """Create soft targets that match the pixel‑level mixing performed.

     Args:
-        labels: (B,) tensor of integer class indices.
+        targets: (B,) tensor of integer class indices.
         pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size().
-        lam_list: Per‑sample fractions of self pixels, also from the mixer.
+        lam_list: Per‑sample fractions of own pixels, also from the mixer.
         num_classes: Total number of classes in the dataset.
         smoothing: Label‑smoothing value in the range [0, 1).

     Returns:
         Tensor of shape (B, num_classes) whose rows sum to 1.
     """
-    off_val = smoothing / num_classes
-    on_val = 1.0 - smoothing + off_val
-
-    y_onehot = torch.full((labels.size(0), num_classes), off_val, dtype=torch.float32, device=labels.device)
-    y_onehot.scatter_(1, labels.unsqueeze(1), on_val)
-
+    y_onehot = smoothed_sparse_target(targets, num_classes=num_classes, smoothing=smoothing)
     targets = y_onehot.clone()
     for i, j in pair_to.items():
         lam = lam_list[i]
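
Two details above merit a concrete check: the raw Beta-sampled lambda is corrected for partial overlap (pixels outside the mixed region always keep their own image), and the smoothing/one-hot construction is now shared via smoothed_sparse_target. A small numeric sketch:

import torch
from timm.data.naflex_mixup import smoothed_sparse_target

# Lambda correction: a 224x224 destination with a 160x160 mixed overlap, lam_raw = 0.6.
dest_area, overlap_area, lam_raw = 224 * 224, 160 * 160, 0.6
corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
# ~0.49 of pixels lie outside the overlap (pure self) + 0.6 of the remaining ~0.51 -> ~0.796

# Smoothed one-hot: off = 0.1 / 4 = 0.025, on = 1 - 0.1 + 0.025 = 0.925; rows sum to 1.
y = smoothed_sparse_target(torch.tensor([1, 3]), num_classes=4, smoothing=0.1)
# tensor([[0.0250, 0.9250, 0.0250, 0.0250],
#         [0.0250, 0.0250, 0.0250, 0.9250]])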
@@ -177,8 +188,9 @@ def __init__(
             mixup_alpha: float = 0.8,
             cutmix_alpha: float = 1.0,
             switch_prob: float = 0.5,
+            prob: float = 1.0,
             local_shuffle: int = 4,
-            smoothing: float = 0.0,
+            label_smoothing: float = 0.0,
     ) -> None:
         """Configure the augmentation.

@@ -187,35 +199,41 @@ def __init__(
             mixup_alpha: Beta α for Mixup. 0 disables Mixup.
             cutmix_alpha: Beta α for CutMix. 0 disables CutMix.
             switch_prob: Probability of selecting CutMix when both modes are enabled.
+            prob: Probability of applying any mixing per batch.
             local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs.
             smoothing: Label‑smoothing value. 0 disables smoothing.
         """
         self.num_classes = num_classes
         self.mixup_alpha = mixup_alpha
         self.cutmix_alpha = cutmix_alpha
         self.switch_prob = switch_prob
+        self.prob = prob
         self.local_shuffle = local_shuffle
-        self.smoothing = smoothing
+        self.smoothing = label_smoothing

     def __call__(
             self,
             imgs: List[torch.Tensor],
-            labels: torch.Tensor,
-    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
+            targets: torch.Tensor,
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """Apply the augmentation and generate matching targets.

         Args:
-            imgs: List of already‑transformed images shaped (C, H, W).
-            labels: Hard labels with shape (B,).
+            imgs: List of already transformed images shaped (C, H, W).
+            targets: Hard labels with shape (B,).

         Returns:
             mixed_imgs: List of mixed images in the same order and shapes as the input.
             targets: Soft‑label tensor shaped (B, num_classes) suitable for cross‑entropy with soft targets.
         """
-        if isinstance(labels, (list, tuple)):
-            labels = torch.tensor(labels)
+        if not isinstance(targets, torch.Tensor):
+            targets = torch.tensor(targets)
+
+        if random.random() > self.prob:
+            targets = smoothed_sparse_target(targets, num_classes=self.num_classes, smoothing=self.smoothing)
+            return imgs, targets.unbind(0)

-        mixed_imgs, lam_list, pair_to, _ = mix_batch_variable_size(
+        mixed_imgs, lam_list, pair_to = mix_batch_variable_size(
             imgs,
             mixup_alpha=self.mixup_alpha,
             cutmix_alpha=self.cutmix_alpha,
@@ -224,7 +242,7 @@ def __call__(
         )

         targets = pairwise_mixup_target(
-            labels,
+            targets,
             pair_to,
             lam_list,
             num_classes=self.num_classes,
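
End to end, the callable now takes variable-size images plus hard labels and returns mixed images with matching soft targets, skipping the mix entirely (while still applying label smoothing) with probability 1 - prob. A usage sketch (shapes illustrative):

import torch
from timm.data import NaFlexMixup

mixup = NaFlexMixup(num_classes=10, prob=0.5, label_smoothing=0.1)
imgs = [torch.rand(3, 224, 224), torch.rand(3, 192, 256)]  # per-image sizes may differ
mixed_imgs, soft_targets = mixup(imgs, [0, 7])
# mixed_imgs keeps each input's shape; soft_targets holds one (num_classes,) row per sample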
