@@ -26,27 +26,24 @@ def mix_batch_variable_size(
2626 cutmix_alpha : float = 1.0 ,
2727 switch_prob : float = 0.5 ,
2828 local_shuffle : int = 4 ,
29- ) -> Tuple [List [torch .Tensor ], List [float ], Dict [int , int ], bool ]:
29+ ) -> Tuple [List [torch .Tensor ], List [float ], Dict [int , int ]]:
3030 """Apply Mixup or CutMix on a batch of variable‑sized images.
3131
3232 The function first sorts images by aspect ratio and pairs neighbouring
3333 samples (optionally shuffling within small windows so pairs vary between
3434 epochs). Only the mutual central‑overlap region of each pair is mixed
3535
3636 Args:
37- imgs: List of transformed images shaped (C, H, W). Heights and
38- widths may differ between samples.
39- mixup_alpha: Beta‑distribution *α* for Mixup. Set to 0 to disable Mixup.
40- cutmix_alpha: Beta‑distribution *α* for CutMix. Set to 0 to disable CutMix.
37+ imgs: List of transformed images shaped (C, H, W). Heights and widths may differ between samples.
38+ mixup_alpha: Beta‑distribution alpha for Mixup. Set to 0 to disable Mixup.
39+ cutmix_alpha: Beta‑distribution alpha for CutMix. Set to 0 to disable CutMix.
4140 switch_prob: Probability of using CutMix when both Mixup and CutMix are enabled.
42- local_shuffle: Size of local windows that are randomly shuffled after aspect sorting.
43- A value of 0 turns shuffling off.
41+ local_shuffle: Size of local windows that are randomly shuffled after aspect sorting. Off if <= 1.
4442
4543 Returns:
4644 mixed_imgs: List of mixed images.
4745 lam_list: Per‑sample lambda values representing the degree of mixing.
4846 pair_to: Mapping i -> j describing which sample was mixed with which (absent for unmatched odd sample).
49- use_cutmix: True if CutMix was used for this call, False if Mixup was used.
5047 """
5148 if len (imgs ) < 2 :
5249 raise ValueError ("Need at least two images to perform Mixup/CutMix." )
@@ -71,7 +68,7 @@ def mix_batch_variable_size(
7168 order = sorted (range (len (imgs )), key = lambda i : imgs [i ].shape [2 ] / imgs [i ].shape [1 ])
7269 if local_shuffle > 1 :
7370 for start in range (0 , len (order ), local_shuffle ):
74- random .shuffle (order [start : start + local_shuffle ])
71+ random .shuffle (order [start :start + local_shuffle ])
7572
7673 pair_to : Dict [int , int ] = {}
7774 for a , b in zip (order [::2 ], order [1 ::2 ]):
@@ -119,22 +116,41 @@ def mix_batch_variable_size(
119116 #print(i, 'Doing cutmix', yl_i, xl_i, yl_j, xl_j, ch, cw, lam_raw, corrected_lam)
120117 else :
121118 # Mixup: blend the entire overlap region
122- patch_i = xi [:, top_i : top_i + oh , left_i : left_i + ow ]
123- patch_j = xj [:, top_j : top_j + oh , left_j : left_j + ow ]
119+ patch_i = xi [:, top_i :top_i + oh , left_i :left_i + ow ]
120+ patch_j = xj [:, top_j :top_j + oh , left_j :left_j + ow ]
124121
125122 blended = patch_i .mul (lam_raw ).add_ (patch_j , alpha = 1.0 - lam_raw )
126- xi [:, top_i : top_i + oh , left_i : left_i + ow ] = blended
123+ xi [:, top_i :top_i + oh , left_i :left_i + ow ] = blended
127124 mixed_imgs [i ] = xi
128125
129126 corrected_lam = (dest_area - overlap_area ) / dest_area + lam_raw * overlap_area / dest_area
130127 lam_list [i ] = corrected_lam
131128 #print(i, 'Doing mixup', top_i, left_i, top_j, left_j, (oh, ow), (hi, wi), (hj, wj), lam_raw, corrected_lam)
132129
133- return mixed_imgs , lam_list , pair_to , use_cutmix
130+ return mixed_imgs , lam_list , pair_to
131+
132+
def smoothed_sparse_target(
        targets: torch.Tensor,
        *,
        num_classes: int,
        smoothing: float = 0.0,
) -> torch.Tensor:
    """Convert integer class indices into label-smoothed one-hot rows.

    Every entry of a row starts at ``smoothing / num_classes``; the entry at
    the target class is then set to ``1 - smoothing + smoothing / num_classes``
    so that each row sums to 1.

    Args:
        targets: (B,) tensor of integer class indices. Any integer dtype is
            accepted; it is cast to int64 before indexing.
        num_classes: Number of columns in the output.
        smoothing: Label-smoothing amount. 0 yields plain one-hot rows.

    Returns:
        float32 tensor of shape (B, num_classes) on the same device as
        ``targets``.
    """
    off_val = smoothing / num_classes
    on_val = 1.0 - smoothing + off_val

    y_onehot = torch.full(
        (targets.size(0), num_classes),
        off_val,
        dtype=torch.float32,
        device=targets.device,
    )
    # scatter_ only accepts int64 indices; casting lets int32/uint8 label
    # tensors work instead of raising a dtype error.
    index = targets.to(torch.int64).unsqueeze(1)
    y_onehot.scatter_(1, index, on_val)
    return y_onehot
134150
135151
136152def pairwise_mixup_target (
137- labels : torch .Tensor ,
153+ targets : torch .Tensor ,
138154 pair_to : Dict [int , int ],
139155 lam_list : List [float ],
140156 * ,
@@ -144,21 +160,16 @@ def pairwise_mixup_target(
144160 """Create soft targets that match the pixel‑level mixing performed.
145161
146162 Args:
147- labels : (B,) tensor of integer class indices.
163+ targets : (B,) tensor of integer class indices.
148164 pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size().
149- lam_list: Per‑sample fractions of self pixels, also from the mixer.
165+ lam_list: Per‑sample fractions of own pixels, also from the mixer.
150166 num_classes: Total number of classes in the dataset.
151167 smoothing: Label‑smoothing value in the range [0, 1).
152168
153169 Returns:
154170 Tensor of shape (B, num_classes) whose rows sum to 1.
155171 """
156- off_val = smoothing / num_classes
157- on_val = 1.0 - smoothing + off_val
158-
159- y_onehot = torch .full ((labels .size (0 ), num_classes ), off_val , dtype = torch .float32 , device = labels .device )
160- y_onehot .scatter_ (1 , labels .unsqueeze (1 ), on_val )
161-
172+ y_onehot = smoothed_sparse_target (targets , num_classes = num_classes , smoothing = smoothing )
162173 targets = y_onehot .clone ()
163174 for i , j in pair_to .items ():
164175 lam = lam_list [i ]
@@ -177,8 +188,9 @@ def __init__(
177188 mixup_alpha : float = 0.8 ,
178189 cutmix_alpha : float = 1.0 ,
179190 switch_prob : float = 0.5 ,
191+ prob : float = 1.0 ,
180192 local_shuffle : int = 4 ,
181- smoothing : float = 0.0 ,
193+ label_smoothing : float = 0.0 ,
182194 ) -> None :
183195 """Configure the augmentation.
184196
@@ -187,35 +199,41 @@ def __init__(
187199 mixup_alpha: Beta α for Mixup. 0 disables Mixup.
188200 cutmix_alpha: Beta α for CutMix. 0 disables CutMix.
189201 switch_prob: Probability of selecting CutMix when both modes are enabled.
202+ prob: Probability of applying any mixing per batch.
190203 local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs.
191204 smoothing: Label‑smoothing value. 0 disables smoothing.
192205 """
193206 self .num_classes = num_classes
194207 self .mixup_alpha = mixup_alpha
195208 self .cutmix_alpha = cutmix_alpha
196209 self .switch_prob = switch_prob
210+ self .prob = prob
197211 self .local_shuffle = local_shuffle
198- self .smoothing = smoothing
212+ self .smoothing = label_smoothing
199213
200214 def __call__ (
201215 self ,
202216 imgs : List [torch .Tensor ],
203- labels : torch .Tensor ,
204- ) -> Tuple [List [torch .Tensor ], torch .Tensor ]:
217+ targets : torch .Tensor ,
218+ ) -> Tuple [List [torch .Tensor ], List [ torch .Tensor ] ]:
205219 """Apply the augmentation and generate matching targets.
206220
207221 Args:
208- imgs: List of already‑ transformed images shaped (C, H, W).
209- labels : Hard labels with shape (B,).
222+ imgs: List of already transformed images shaped (C, H, W).
223+ targets : Hard labels with shape (B,).
210224
211225 Returns:
212226 mixed_imgs: List of mixed images in the same order and shapes as the input.
213227 targets: Soft‑label tensor shaped (B, num_classes) suitable for cross‑entropy with soft targets.
214228 """
215- if isinstance (labels , (list , tuple )):
216- labels = torch .tensor (labels )
229+ if not isinstance (targets , torch .Tensor ):
230+ targets = torch .tensor (targets )
231+
232+ if random .random () > self .prob :
233+ targets = smoothed_sparse_target (targets , num_classes = self .num_classes , smoothing = self .smoothing )
234+ return imgs , targets .unbind (0 )
217235
218- mixed_imgs , lam_list , pair_to , _ = mix_batch_variable_size (
236+ mixed_imgs , lam_list , pair_to = mix_batch_variable_size (
219237 imgs ,
220238 mixup_alpha = self .mixup_alpha ,
221239 cutmix_alpha = self .cutmix_alpha ,
@@ -224,7 +242,7 @@ def __call__(
224242 )
225243
226244 targets = pairwise_mixup_target (
227- labels ,
245+ targets ,
228246 pair_to ,
229247 lam_list ,
230248 num_classes = self .num_classes ,
0 commit comments