@@ -11,17 +11,15 @@ class PatchRandomErasing:
11
11
12
12
Supports three modes:
13
13
1. 'patch': Simple mode that erases randomly selected valid patches
14
- 2. 'region': Erases spatial regions at patch granularity
15
- 3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
16
- partially erasing patches that are on the boundary of the erased region
14
+ 2. 'region': Erases rectangular regions at patch granularity
17
15
18
16
Args:
19
17
erase_prob: Probability that the Random Erasing operation will be performed.
20
18
patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
21
19
min_area: Minimum fraction (0-1) of valid patches/area to erase.
22
20
max_area: Maximum fraction (0-1) of valid patches/area to erase.
23
- min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
24
- max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
21
+ min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
22
+ max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode).
25
23
mode: Patch content mode, one of 'const', 'rand', or 'pixel'
26
24
'const' - erase patch is constant color of 0 for all channels
27
25
'rand' - erase patch has same random (normal) value across all elements
@@ -45,7 +43,6 @@ def __init__(
45
43
mode : str = 'const' ,
46
44
value : float = 0. ,
47
45
spatial_mode : str = 'region' ,
48
- patch_size : Optional [Union [int , Tuple [int , int ]]] = 16 ,
49
46
num_splits : int = 0 ,
50
47
device : Union [str , torch .device ] = 'cuda' ,
51
48
):
@@ -66,14 +63,13 @@ def __init__(
66
63
67
64
# Strategy mode
68
65
self .spatial_mode = spatial_mode
69
-
70
- # Patch size (needed for subregion mode)
71
- self .patch_size = patch_size if isinstance (patch_size , tuple ) else (patch_size , patch_size )
66
+ assert self .spatial_mode in ('patch' , 'region' )
72
67
73
68
# Value generation mode flags
74
69
self .erase_mode = mode .lower ()
75
70
assert self .erase_mode in ('rand' , 'pixel' , 'const' )
76
71
self .const_value = value
72
+ self .unique_noise_per_patch = True
77
73
78
74
def _get_values (
79
75
self ,
@@ -156,27 +152,27 @@ def _erase_patches(
156
152
return
157
153
158
154
# Get indices of valid patches
159
- valid_indices = torch .nonzero (patch_valid , as_tuple = True )[0 ]. tolist ()
160
- if not valid_indices :
161
- # Skip if no valid patches
155
+ valid_indices = torch .nonzero (patch_valid , as_tuple = True )[0 ]
156
+ num_valid = len ( valid_indices )
157
+ if num_valid == 0 :
162
158
return
163
159
164
- num_valid = len (valid_indices )
165
160
count = random .randint (self .min_count , self .max_count )
166
161
# Determine how many valid patches to erase from the Random Erasing min/max count and area args
167
- max_erase = max (1 , int (num_valid * count * self .max_area ))
162
+ max_erase = min ( num_valid , max (1 , int (num_valid * count * self .max_area ) ))
168
163
min_erase = max (1 , int (num_valid * count * self .min_area ))
169
164
num_erase = random .randint (min_erase , max_erase )
170
165
171
166
# Randomly select valid patches to erase
172
- indices_to_erase = random . sample ( valid_indices , min ( num_erase , num_valid ))
167
+ erase_idx = valid_indices [ torch . randperm ( num_valid , device = patches . device )[: num_erase ]]
173
168
174
- random_value = None
175
- if self .erase_mode == 'rand' :
176
- random_value = torch .empty (patch_shape [- 1 ], dtype = dtype , device = self .device ).normal_ ()
169
+ if self .unique_noise_per_patch and self .erase_mode == 'pixel' :
170
+ # generate unique noise for the whole selection of patches
171
+ fill_shape = (num_erase ,) + patch_shape
172
+ else :
173
+ fill_shape = patch_shape
177
174
178
- for idx in indices_to_erase :
179
- patches [idx ].copy_ (self ._get_values (patch_shape , dtype = dtype , value = random_value ))
175
+ patches [erase_idx ] = self ._get_values (fill_shape , dtype = dtype )
180
176
181
177
def _erase_region (
182
178
self ,
@@ -195,20 +191,14 @@ def _erase_region(
195
191
return
196
192
197
193
# Determine grid dimensions from coordinates
198
- if patch_valid is not None :
199
- valid_coord = patch_coord [patch_valid ]
200
- if len (valid_coord ) == 0 :
201
- return # No valid patches
202
- max_y = valid_coord [:, 0 ].max ().item () + 1
203
- max_x = valid_coord [:, 1 ].max ().item () + 1
204
- else :
205
- max_y = patch_coord [:, 0 ].max ().item () + 1
206
- max_x = patch_coord [:, 1 ].max ().item () + 1
207
-
194
+ valid_coord = patch_coord [patch_valid ]
195
+ if len (valid_coord ) == 0 :
196
+ return # No valid patches
197
+ max_y = valid_coord [:, 0 ].max ().item () + 1
198
+ max_x = valid_coord [:, 1 ].max ().item () + 1
208
199
grid_h , grid_w = max_y , max_x
209
-
210
- # Calculate total area
211
200
total_area = grid_h * grid_w
201
+ ys , xs = patch_coord [:, 0 ], patch_coord [:, 1 ]
212
202
213
203
count = random .randint (self .min_count , self .max_count )
214
204
for _ in range (count ):
@@ -222,132 +212,33 @@ def _erase_region(
222
212
h = int (round (math .sqrt (target_area * aspect_ratio )))
223
213
w = int (round (math .sqrt (target_area / aspect_ratio )))
224
214
225
- # Ensure region fits within grid
226
- if w <= grid_w and h <= grid_h :
227
- # Select random top-left corner
228
- top = random .randint (0 , grid_h - h )
229
- left = random .randint (0 , grid_w - w )
230
-
231
- # Define region bounds
232
- bottom = top + h
233
- right = left + w
234
-
235
- # Create a single random value for all affected patches if using 'rand' mode
236
- if self .erase_mode == 'rand' :
237
- random_value = torch .empty (patch_shape [- 1 ], dtype = dtype , device = self .device ).normal_ ()
238
- else :
239
- random_value = None
240
-
241
- # Find and erase all patches that fall within the region
242
- for i in range (len (patches )):
243
- if patch_valid is None or patch_valid [i ]:
244
- y , x = patch_coord [i ]
245
- if top <= y < bottom and left <= x < right :
246
- patches [i ] = self ._get_values (patch_shape , dtype = dtype , value = random_value )
247
-
248
- # Successfully applied erasing, exit the loop
249
- break
250
-
251
- def _erase_subregion (
252
- self ,
253
- patches : torch .Tensor ,
254
- patch_coord : torch .Tensor ,
255
- patch_valid : torch .Tensor ,
256
- patch_shape : torch .Size ,
257
- patch_size : Tuple [int , int ],
258
- dtype : torch .dtype = torch .float32 ,
259
- ):
260
- """Apply erasing by selecting rectangular regions ignoring patch boundaries.
215
+ if h > grid_h or w > grid_w :
216
+ continue # try again
261
217
262
- Matches or original RandomErasing implementation. Erases spatially contiguous rectangular
263
- regions that are not aligned to patches (erase regions boundaries cut within patches).
218
+ # Calculate region patch bounds
219
+ top = random .randint (0 , grid_h - h )
220
+ left = random .randint (0 , grid_w - w )
221
+ bottom , right = top + h , left + w
264
222
265
- FIXME complexity probably not worth it, may remove.
266
- """
267
- if random .random () > self .erase_prob :
268
- return
269
-
270
- # Get patch dimensions
271
- patch_h , patch_w = patch_size
272
- channels = patch_shape [- 1 ]
273
-
274
- # Determine grid dimensions in patch coordinates
275
- if patch_valid is not None :
276
- valid_coord = patch_coord [patch_valid ]
277
- if len (valid_coord ) == 0 :
278
- return # No valid patches
279
- max_y = valid_coord [:, 0 ].max ().item () + 1
280
- max_x = valid_coord [:, 1 ].max ().item () + 1
281
- else :
282
- max_y = patch_coord [:, 0 ].max ().item () + 1
283
- max_x = patch_coord [:, 1 ].max ().item () + 1
284
-
285
- grid_h , grid_w = max_y , max_x
286
-
287
- # Calculate total area in pixel space
288
- total_area = (grid_h * patch_h ) * (grid_w * patch_w )
223
+ # Region test
224
+ region_mask = (
225
+ (ys >= top ) & (ys < bottom ) &
226
+ (xs >= left ) & (xs < right ) &
227
+ patch_valid
228
+ )
229
+ num_selected = int (region_mask .sum ().item ())
230
+ if not num_selected :
231
+ continue # no patch actually falls inside – try again
289
232
290
- count = random .randint (self .min_count , self .max_count )
291
- for _ in range (count ):
292
- # Try to select a valid region to erase (multiple attempts)
293
- for attempt in range (10 ):
294
- # Sample random area and aspect ratio
295
- target_area = random .uniform (self .min_area , self .max_area ) * total_area
296
- aspect_ratio = math .exp (random .uniform (* self .log_aspect_ratio ))
233
+ if self .unique_noise_per_patch and self .erase_mode == 'pixel' :
234
+ # generate unique noise for the whole region
235
+ fill_shape = (num_selected ,) + patch_shape
236
+ else :
237
+ fill_shape = patch_shape
297
238
298
- # Calculate region height and width in pixel space
299
- pixel_h = int (round (math .sqrt (target_area * aspect_ratio )))
300
- pixel_w = int (round (math .sqrt (target_area / aspect_ratio )))
301
-
302
- # Ensure region fits within total pixel grid
303
- if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h :
304
- # Select random top-left corner in pixel space
305
- pixel_top = random .randint (0 , grid_h * patch_h - pixel_h )
306
- pixel_left = random .randint (0 , grid_w * patch_w - pixel_w )
307
-
308
- # Define region bounds in pixel space
309
- pixel_bottom = pixel_top + pixel_h
310
- pixel_right = pixel_left + pixel_w
311
-
312
- # Create a single random value for the entire region if using 'rand' mode
313
- rand_value = None
314
- if self .erase_mode == 'rand' :
315
- rand_value = torch .empty (patch_shape [- 1 ], dtype = dtype , device = self .device ).normal_ ()
316
-
317
- # For each valid patch, determine if and how it overlaps with the erase region
318
- for i in range (len (patches )):
319
- if patch_valid is None or patch_valid [i ]:
320
- # Convert patch coordinates to pixel space (top-left corner)
321
- y , x = patch_coord [i ]
322
- patch_pixel_top = y * patch_h
323
- patch_pixel_left = x * patch_w
324
- patch_pixel_bottom = patch_pixel_top + patch_h
325
- patch_pixel_right = patch_pixel_left + patch_w
326
-
327
- # Check if this patch overlaps with the erase region
328
- if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or
329
- patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom ):
330
-
331
- # Calculate the overlap region in patch-local coordinates
332
- local_top = max (0 , pixel_top - patch_pixel_top )
333
- local_left = max (0 , pixel_left - patch_pixel_left )
334
- local_bottom = min (patch_h , pixel_bottom - patch_pixel_top )
335
- local_right = min (patch_w , pixel_right - patch_pixel_left )
336
-
337
- # Reshape the patch to [patch_h, patch_w, chans]
338
- patch_data = patches [i ].reshape (patch_h , patch_w , channels )
339
-
340
- erase_shape = (local_bottom - local_top , local_right - local_left , channels )
341
- erase_value = self ._get_values (erase_shape , dtype = dtype , value = rand_value )
342
- patch_data [local_top :local_bottom , local_left :local_right , :] = erase_value
343
-
344
- # Flatten the patch back to [patch_h*patch_w, chans]
345
- if len (patch_shape ) == 2 :
346
- patch_data = patch_data .reshape (- 1 , channels )
347
- patches [i ] = patch_data
348
-
349
- # Successfully applied erasing, exit the loop
350
- break
239
+ patches [region_mask ] = self ._get_values (fill_shape , dtype = dtype )
240
+ # Successfully applied erasing, exit the loop
241
+ break
351
242
352
243
def __call__ (
353
244
self ,
@@ -369,18 +260,12 @@ def __call__(
369
260
"""
370
261
if patches .ndim == 4 :
371
262
batch_size , num_patches , patch_dim , channels = patches .shape
372
- if self .patch_size is not None :
373
- patch_size = self .patch_size
374
- else :
375
- patch_size = None
376
263
elif patches .ndim == 5 :
377
264
batch_size , num_patches , patch_h , patch_w , channels = patches .shape
378
- patch_size = (patch_h , patch_w )
379
265
else :
380
266
assert False
381
267
patch_shape = patches .shape [2 :]
382
268
# patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
383
- # patch_size ==> patch h, w (if available, must be avail for subregion mode)
384
269
385
270
# Create default valid mask if not provided
386
271
if patch_valid is None :
@@ -399,6 +284,7 @@ def __call__(
399
284
patch_valid [i ],
400
285
)
401
286
elif self .spatial_mode == 'patch' :
287
+ # FIXME we could vectorize patch mode across batch, worth the effort?
402
288
self ._erase_patches (
403
289
patches [i ],
404
290
patch_coord [i ],
@@ -414,15 +300,8 @@ def __call__(
414
300
patch_shape ,
415
301
patches .dtype
416
302
)
417
- elif self .spatial_mode == 'subregion' :
418
- self ._erase_subregion (
419
- patches [i ],
420
- patch_coord [i ],
421
- patch_valid [i ],
422
- patch_shape ,
423
- patch_size ,
424
- patches .dtype
425
- )
303
+ else :
304
+ assert False
426
305
427
306
return patches
428
307
0 commit comments