@@ -26,27 +26,24 @@ def mix_batch_variable_size(
26
26
cutmix_alpha : float = 1.0 ,
27
27
switch_prob : float = 0.5 ,
28
28
local_shuffle : int = 4 ,
29
- ) -> Tuple [List [torch .Tensor ], List [float ], Dict [int , int ], bool ]:
29
+ ) -> Tuple [List [torch .Tensor ], List [float ], Dict [int , int ]]:
30
30
"""Apply Mixup or CutMix on a batch of variable‑sized images.
31
31
32
32
The function first sorts images by aspect ratio and pairs neighbouring
33
33
samples (optionally shuffling within small windows so pairs vary between
34
34
epochs). Only the mutual central‑overlap region of each pair is mixed
35
35
36
36
Args:
37
- imgs: List of transformed images shaped (C, H, W). Heights and
38
- widths may differ between samples.
39
- mixup_alpha: Beta‑distribution *α* for Mixup. Set to 0 to disable Mixup.
40
- cutmix_alpha: Beta‑distribution *α* for CutMix. Set to 0 to disable CutMix.
37
+ imgs: List of transformed images shaped (C, H, W). Heights and widths may differ between samples.
38
+ mixup_alpha: Beta‑distribution alpha for Mixup. Set to 0 to disable Mixup.
39
+ cutmix_alpha: Beta‑distribution alpha for CutMix. Set to 0 to disable CutMix.
41
40
switch_prob: Probability of using CutMix when both Mixup and CutMix are enabled.
42
- local_shuffle: Size of local windows that are randomly shuffled after aspect sorting.
43
- A value of 0 turns shuffling off.
41
+ local_shuffle: Size of local windows that are randomly shuffled after aspect sorting. Off if <= 1.
44
42
45
43
Returns:
46
44
mixed_imgs: List of mixed images.
47
45
lam_list: Per‑sample lambda values representing the degree of mixing.
48
46
pair_to: Mapping i -> j describing which sample was mixed with which (absent for unmatched odd sample).
49
- use_cutmix: True if CutMix was used for this call, False if Mixup was used.
50
47
"""
51
48
if len (imgs ) < 2 :
52
49
raise ValueError ("Need at least two images to perform Mixup/CutMix." )
@@ -71,7 +68,7 @@ def mix_batch_variable_size(
71
68
order = sorted (range (len (imgs )), key = lambda i : imgs [i ].shape [2 ] / imgs [i ].shape [1 ])
72
69
if local_shuffle > 1 :
73
70
for start in range (0 , len (order ), local_shuffle ):
74
- random .shuffle (order [start : start + local_shuffle ])
71
+ random .shuffle (order [start :start + local_shuffle ])
75
72
76
73
pair_to : Dict [int , int ] = {}
77
74
for a , b in zip (order [::2 ], order [1 ::2 ]):
@@ -119,22 +116,41 @@ def mix_batch_variable_size(
119
116
#print(i, 'Doing cutmix', yl_i, xl_i, yl_j, xl_j, ch, cw, lam_raw, corrected_lam)
120
117
else :
121
118
# Mixup: blend the entire overlap region
122
- patch_i = xi [:, top_i : top_i + oh , left_i : left_i + ow ]
123
- patch_j = xj [:, top_j : top_j + oh , left_j : left_j + ow ]
119
+ patch_i = xi [:, top_i :top_i + oh , left_i :left_i + ow ]
120
+ patch_j = xj [:, top_j :top_j + oh , left_j :left_j + ow ]
124
121
125
122
blended = patch_i .mul (lam_raw ).add_ (patch_j , alpha = 1.0 - lam_raw )
126
- xi [:, top_i : top_i + oh , left_i : left_i + ow ] = blended
123
+ xi [:, top_i :top_i + oh , left_i :left_i + ow ] = blended
127
124
mixed_imgs [i ] = xi
128
125
129
126
corrected_lam = (dest_area - overlap_area ) / dest_area + lam_raw * overlap_area / dest_area
130
127
lam_list [i ] = corrected_lam
131
128
#print(i, 'Doing mixup', top_i, left_i, top_j, left_j, (oh, ow), (hi, wi), (hj, wj), lam_raw, corrected_lam)
132
129
133
- return mixed_imgs , lam_list , pair_to , use_cutmix
130
+ return mixed_imgs , lam_list , pair_to
131
+
132
+
133
def smoothed_sparse_target(
    targets: torch.Tensor,
    *,
    num_classes: int,
    smoothing: float = 0.0,
) -> torch.Tensor:
    """Convert hard class indices into label-smoothed one-hot rows.

    Every entry of the result starts at ``smoothing / num_classes``; the entry
    at each sample's class index is then set to
    ``1 - smoothing + smoothing / num_classes``, so each row sums to 1.

    Args:
        targets: (B,) tensor of integer class indices.
        num_classes: Total number of classes; must be positive.
        smoothing: Label-smoothing value. 0 yields plain one-hot rows.

    Returns:
        float32 tensor of shape (B, num_classes) on the same device as
        ``targets``.

    Raises:
        ValueError: If ``num_classes`` is not positive.
    """
    if num_classes <= 0:
        raise ValueError("num_classes must be a positive integer.")

    off_val = smoothing / num_classes
    on_val = 1.0 - smoothing + off_val

    y_onehot = torch.full(
        (targets.size(0), num_classes),
        off_val,
        dtype=torch.float32,
        device=targets.device,
    )
    # scatter_ only accepts an int64 index tensor; cast so that labels stored
    # as int32/uint8/etc. work instead of raising.
    index = targets.to(torch.int64).unsqueeze(1)
    y_onehot.scatter_(1, index, on_val)
    return y_onehot
134
150
135
151
136
152
def pairwise_mixup_target (
137
- labels : torch .Tensor ,
153
+ targets : torch .Tensor ,
138
154
pair_to : Dict [int , int ],
139
155
lam_list : List [float ],
140
156
* ,
@@ -144,21 +160,16 @@ def pairwise_mixup_target(
144
160
"""Create soft targets that match the pixel‑level mixing performed.
145
161
146
162
Args:
147
- labels : (B,) tensor of integer class indices.
163
+ targets : (B,) tensor of integer class indices.
148
164
pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size().
149
- lam_list: Per‑sample fractions of self pixels, also from the mixer.
165
+ lam_list: Per‑sample fractions of own pixels, also from the mixer.
150
166
num_classes: Total number of classes in the dataset.
151
167
smoothing: Label‑smoothing value in the range [0, 1).
152
168
153
169
Returns:
154
170
Tensor of shape (B, num_classes) whose rows sum to 1.
155
171
"""
156
- off_val = smoothing / num_classes
157
- on_val = 1.0 - smoothing + off_val
158
-
159
- y_onehot = torch .full ((labels .size (0 ), num_classes ), off_val , dtype = torch .float32 , device = labels .device )
160
- y_onehot .scatter_ (1 , labels .unsqueeze (1 ), on_val )
161
-
172
+ y_onehot = smoothed_sparse_target (targets , num_classes = num_classes , smoothing = smoothing )
162
173
targets = y_onehot .clone ()
163
174
for i , j in pair_to .items ():
164
175
lam = lam_list [i ]
@@ -177,8 +188,9 @@ def __init__(
177
188
mixup_alpha : float = 0.8 ,
178
189
cutmix_alpha : float = 1.0 ,
179
190
switch_prob : float = 0.5 ,
191
+ prob : float = 1.0 ,
180
192
local_shuffle : int = 4 ,
181
- smoothing : float = 0.0 ,
193
+ label_smoothing : float = 0.0 ,
182
194
) -> None :
183
195
"""Configure the augmentation.
184
196
@@ -187,35 +199,41 @@ def __init__(
187
199
mixup_alpha: Beta α for Mixup. 0 disables Mixup.
188
200
cutmix_alpha: Beta α for CutMix. 0 disables CutMix.
189
201
switch_prob: Probability of selecting CutMix when both modes are enabled.
202
+ prob: Probability of applying any mixing per batch.
190
203
local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs.
191
204
label_smoothing: Label‑smoothing value. 0 disables smoothing.
192
205
"""
193
206
self .num_classes = num_classes
194
207
self .mixup_alpha = mixup_alpha
195
208
self .cutmix_alpha = cutmix_alpha
196
209
self .switch_prob = switch_prob
210
+ self .prob = prob
197
211
self .local_shuffle = local_shuffle
198
- self .smoothing = smoothing
212
+ self .smoothing = label_smoothing
199
213
200
214
def __call__ (
201
215
self ,
202
216
imgs : List [torch .Tensor ],
203
- labels : torch .Tensor ,
204
- ) -> Tuple [List [torch .Tensor ], torch .Tensor ]:
217
+ targets : torch .Tensor ,
218
+ ) -> Tuple [List [torch .Tensor ], List [ torch .Tensor ] ]:
205
219
"""Apply the augmentation and generate matching targets.
206
220
207
221
Args:
208
- imgs: List of already‑ transformed images shaped (C, H, W).
209
- labels : Hard labels with shape (B,).
222
+ imgs: List of already transformed images shaped (C, H, W).
223
+ targets : Hard labels with shape (B,).
210
224
211
225
Returns:
212
226
mixed_imgs: List of mixed images in the same order and shapes as the input.
213
227
targets: Soft‑label tensor shaped (B, num_classes) suitable for cross‑entropy with soft targets.
214
228
"""
215
- if isinstance (labels , (list , tuple )):
216
- labels = torch .tensor (labels )
229
+ if not isinstance (targets , torch .Tensor ):
230
+ targets = torch .tensor (targets )
231
+
232
+ if random .random () > self .prob :
233
+ targets = smoothed_sparse_target (targets , num_classes = self .num_classes , smoothing = self .smoothing )
234
+ return imgs , targets .unbind (0 )
217
235
218
- mixed_imgs , lam_list , pair_to , _ = mix_batch_variable_size (
236
+ mixed_imgs , lam_list , pair_to = mix_batch_variable_size (
219
237
imgs ,
220
238
mixup_alpha = self .mixup_alpha ,
221
239
cutmix_alpha = self .cutmix_alpha ,
@@ -224,7 +242,7 @@ def __call__(
224
242
)
225
243
226
244
targets = pairwise_mixup_target (
227
- labels ,
245
+ targets ,
228
246
pair_to ,
229
247
lam_list ,
230
248
num_classes = self .num_classes ,
0 commit comments