@@ -11,17 +11,15 @@ class PatchRandomErasing:
1111
1212 Supports three modes:
1313 1. 'patch': Simple mode that erases randomly selected valid patches
14- 2. 'region': Erases spatial regions at patch granularity
15- 3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
16- partially erasing patches that are on the boundary of the erased region
14+ 2. 'region': Erases rectangular regions at patch granularity
1715
1816 Args:
1917 erase_prob: Probability that the Random Erasing operation will be performed.
2018 patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
2119 min_area: Minimum percentage of valid patches/area to erase.
2220 max_area: Maximum percentage of valid patches/area to erase.
23- min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
24- max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
21+ min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
22+ max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode).
2523 mode: Patch content mode, one of 'const', 'rand', or 'pixel'
2624 'const' - erase patch is constant color of 0 for all channels
2725 'rand' - erase patch has same random (normal) value across all elements
@@ -45,7 +43,6 @@ def __init__(
4543 mode : str = 'const' ,
4644 value : float = 0. ,
4745 spatial_mode : str = 'region' ,
48- patch_size : Optional [Union [int , Tuple [int , int ]]] = 16 ,
4946 num_splits : int = 0 ,
5047 device : Union [str , torch .device ] = 'cuda' ,
5148 ):
@@ -66,14 +63,13 @@ def __init__(
6663
6764 # Strategy mode
6865 self .spatial_mode = spatial_mode
69-
70- # Patch size (needed for subregion mode)
71- self .patch_size = patch_size if isinstance (patch_size , tuple ) else (patch_size , patch_size )
66+ assert self .spatial_mode in ('patch' , 'region' )
7267
7368 # Value generation mode flags
7469 self .erase_mode = mode .lower ()
7570 assert self .erase_mode in ('rand' , 'pixel' , 'const' )
7671 self .const_value = value
72+ self .unique_noise_per_patch = True
7773
7874 def _get_values (
7975 self ,
@@ -156,27 +152,27 @@ def _erase_patches(
156152 return
157153
158154 # Get indices of valid patches
159- valid_indices = torch .nonzero (patch_valid , as_tuple = True )[0 ]. tolist ()
160- if not valid_indices :
161- # Skip if no valid patches
155+ valid_indices = torch .nonzero (patch_valid , as_tuple = True )[0 ]
156+ num_valid = len ( valid_indices )
157+ if num_valid == 0 :
162158 return
163159
164- num_valid = len (valid_indices )
165160 count = random .randint (self .min_count , self .max_count )
166161 # Determine how many valid patches to erase from RE min/max count and area args
167- max_erase = max (1 , int (num_valid * count * self .max_area ))
162+ max_erase = min ( num_valid , max (1 , int (num_valid * count * self .max_area ) ))
168163 min_erase = max (1 , int (num_valid * count * self .min_area ))
169164 num_erase = random .randint (min_erase , max_erase )
170165
171166 # Randomly select valid patches to erase
172- indices_to_erase = random . sample ( valid_indices , min ( num_erase , num_valid ))
167+ erase_idx = valid_indices [ torch . randperm ( num_valid , device = patches . device )[: num_erase ]]
173168
174- random_value = None
175- if self .erase_mode == 'rand' :
176- random_value = torch .empty (patch_shape [- 1 ], dtype = dtype , device = self .device ).normal_ ()
169+ if self .unique_noise_per_patch and self .erase_mode == 'pixel' :
170+ # generate unique noise for the whole selection of patches
171+ fill_shape = (num_erase ,) + patch_shape
172+ else :
173+ fill_shape = patch_shape
177174
178- for idx in indices_to_erase :
179- patches [idx ].copy_ (self ._get_values (patch_shape , dtype = dtype , value = random_value ))
175+ patches [erase_idx ] = self ._get_values (fill_shape , dtype = dtype )
180176
181177 def _erase_region (
182178 self ,
@@ -195,20 +191,14 @@ def _erase_region(
195191 return
196192
197193 # Determine grid dimensions from coordinates
198- if patch_valid is not None :
199- valid_coord = patch_coord [patch_valid ]
200- if len (valid_coord ) == 0 :
201- return # No valid patches
202- max_y = valid_coord [:, 0 ].max ().item () + 1
203- max_x = valid_coord [:, 1 ].max ().item () + 1
204- else :
205- max_y = patch_coord [:, 0 ].max ().item () + 1
206- max_x = patch_coord [:, 1 ].max ().item () + 1
207-
194+ valid_coord = patch_coord [patch_valid ]
195+ if len (valid_coord ) == 0 :
196+ return # No valid patches
197+ max_y = valid_coord [:, 0 ].max ().item () + 1
198+ max_x = valid_coord [:, 1 ].max ().item () + 1
208199 grid_h , grid_w = max_y , max_x
209-
210- # Calculate total area
211200 total_area = grid_h * grid_w
201+ ys , xs = patch_coord [:, 0 ], patch_coord [:, 1 ]
212202
213203 count = random .randint (self .min_count , self .max_count )
214204 for _ in range (count ):
@@ -222,132 +212,33 @@ def _erase_region(
222212 h = int (round (math .sqrt (target_area * aspect_ratio )))
223213 w = int (round (math .sqrt (target_area / aspect_ratio )))
224214
225- # Ensure region fits within grid
226- if w <= grid_w and h <= grid_h :
227- # Select random top-left corner
228- top = random .randint (0 , grid_h - h )
229- left = random .randint (0 , grid_w - w )
230-
231- # Define region bounds
232- bottom = top + h
233- right = left + w
234-
235- # Create a single random value for all affected patches if using 'rand' mode
236- if self .erase_mode == 'rand' :
237- random_value = torch .empty (patch_shape [- 1 ], dtype = dtype , device = self .device ).normal_ ()
238- else :
239- random_value = None
240-
241- # Find and erase all patches that fall within the region
242- for i in range (len (patches )):
243- if patch_valid is None or patch_valid [i ]:
244- y , x = patch_coord [i ]
245- if top <= y < bottom and left <= x < right :
246- patches [i ] = self ._get_values (patch_shape , dtype = dtype , value = random_value )
247-
248- # Successfully applied erasing, exit the loop
249- break
250-
251- def _erase_subregion (
252- self ,
253- patches : torch .Tensor ,
254- patch_coord : torch .Tensor ,
255- patch_valid : torch .Tensor ,
256- patch_shape : torch .Size ,
257- patch_size : Tuple [int , int ],
258- dtype : torch .dtype = torch .float32 ,
259- ):
260- """Apply erasing by selecting rectangular regions ignoring patch boundaries.
215+ if h > grid_h or w > grid_w :
216+ continue # try again
261217
262- Matches or original RandomErasing implementation. Erases spatially contiguous rectangular
263- regions that are not aligned to patches (erase regions boundaries cut within patches).
218+ # Calculate region patch bounds
219+ top = random .randint (0 , grid_h - h )
220+ left = random .randint (0 , grid_w - w )
221+ bottom , right = top + h , left + w
264222
265- FIXME complexity probably not worth it, may remove.
266- """
267- if random .random () > self .erase_prob :
268- return
269-
270- # Get patch dimensions
271- patch_h , patch_w = patch_size
272- channels = patch_shape [- 1 ]
273-
274- # Determine grid dimensions in patch coordinates
275- if patch_valid is not None :
276- valid_coord = patch_coord [patch_valid ]
277- if len (valid_coord ) == 0 :
278- return # No valid patches
279- max_y = valid_coord [:, 0 ].max ().item () + 1
280- max_x = valid_coord [:, 1 ].max ().item () + 1
281- else :
282- max_y = patch_coord [:, 0 ].max ().item () + 1
283- max_x = patch_coord [:, 1 ].max ().item () + 1
284-
285- grid_h , grid_w = max_y , max_x
286-
287- # Calculate total area in pixel space
288- total_area = (grid_h * patch_h ) * (grid_w * patch_w )
223+ # Region test
224+ region_mask = (
225+ (ys >= top ) & (ys < bottom ) &
226+ (xs >= left ) & (xs < right ) &
227+ patch_valid
228+ )
229+ num_selected = int (region_mask .sum ().item ())
230+ if not num_selected :
231+ continue # no patch actually falls inside – try again
289232
290- count = random .randint (self .min_count , self .max_count )
291- for _ in range (count ):
292- # Try to select a valid region to erase (multiple attempts)
293- for attempt in range (10 ):
294- # Sample random area and aspect ratio
295- target_area = random .uniform (self .min_area , self .max_area ) * total_area
296- aspect_ratio = math .exp (random .uniform (* self .log_aspect_ratio ))
233+ if self .unique_noise_per_patch and self .erase_mode == 'pixel' :
234+ # generate unique noise for the whole region
235+ fill_shape = (num_selected ,) + patch_shape
236+ else :
237+ fill_shape = patch_shape
297238
298- # Calculate region height and width in pixel space
299- pixel_h = int (round (math .sqrt (target_area * aspect_ratio )))
300- pixel_w = int (round (math .sqrt (target_area / aspect_ratio )))
301-
302- # Ensure region fits within total pixel grid
303- if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h :
304- # Select random top-left corner in pixel space
305- pixel_top = random .randint (0 , grid_h * patch_h - pixel_h )
306- pixel_left = random .randint (0 , grid_w * patch_w - pixel_w )
307-
308- # Define region bounds in pixel space
309- pixel_bottom = pixel_top + pixel_h
310- pixel_right = pixel_left + pixel_w
311-
312- # Create a single random value for the entire region if using 'rand' mode
313- rand_value = None
314- if self .erase_mode == 'rand' :
315- rand_value = torch .empty (patch_shape [- 1 ], dtype = dtype , device = self .device ).normal_ ()
316-
317- # For each valid patch, determine if and how it overlaps with the erase region
318- for i in range (len (patches )):
319- if patch_valid is None or patch_valid [i ]:
320- # Convert patch coordinates to pixel space (top-left corner)
321- y , x = patch_coord [i ]
322- patch_pixel_top = y * patch_h
323- patch_pixel_left = x * patch_w
324- patch_pixel_bottom = patch_pixel_top + patch_h
325- patch_pixel_right = patch_pixel_left + patch_w
326-
327- # Check if this patch overlaps with the erase region
328- if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or
329- patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom ):
330-
331- # Calculate the overlap region in patch-local coordinates
332- local_top = max (0 , pixel_top - patch_pixel_top )
333- local_left = max (0 , pixel_left - patch_pixel_left )
334- local_bottom = min (patch_h , pixel_bottom - patch_pixel_top )
335- local_right = min (patch_w , pixel_right - patch_pixel_left )
336-
337- # Reshape the patch to [patch_h, patch_w, chans]
338- patch_data = patches [i ].reshape (patch_h , patch_w , channels )
339-
340- erase_shape = (local_bottom - local_top , local_right - local_left , channels )
341- erase_value = self ._get_values (erase_shape , dtype = dtype , value = rand_value )
342- patch_data [local_top :local_bottom , local_left :local_right , :] = erase_value
343-
344- # Flatten the patch back to [patch_h*patch_w, chans]
345- if len (patch_shape ) == 2 :
346- patch_data = patch_data .reshape (- 1 , channels )
347- patches [i ] = patch_data
348-
349- # Successfully applied erasing, exit the loop
350- break
239+ patches [region_mask ] = self ._get_values (fill_shape , dtype = dtype )
240+ # Successfully applied erasing, exit the loop
241+ break
351242
352243 def __call__ (
353244 self ,
@@ -369,18 +260,12 @@ def __call__(
369260 """
370261 if patches .ndim == 4 :
371262 batch_size , num_patches , patch_dim , channels = patches .shape
372- if self .patch_size is not None :
373- patch_size = self .patch_size
374- else :
375- patch_size = None
376263 elif patches .ndim == 5 :
377264 batch_size , num_patches , patch_h , patch_w , channels = patches .shape
378- patch_size = (patch_h , patch_w )
379265 else :
380266 assert False
381267 patch_shape = patches .shape [2 :]
382268 # patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
383- # patch_size ==> patch h, w (if available, must be avail for subregion mode)
384269
385270 # Create default valid mask if not provided
386271 if patch_valid is None :
@@ -399,6 +284,7 @@ def __call__(
399284 patch_valid [i ],
400285 )
401286 elif self .spatial_mode == 'patch' :
287+ # FIXME we could vectorize patch mode across batch, worth the effort?
402288 self ._erase_patches (
403289 patches [i ],
404290 patch_coord [i ],
@@ -414,15 +300,8 @@ def __call__(
414300 patch_shape ,
415301 patches .dtype
416302 )
417- elif self .spatial_mode == 'subregion' :
418- self ._erase_subregion (
419- patches [i ],
420- patch_coord [i ],
421- patch_valid [i ],
422- patch_shape ,
423- patch_size ,
424- patches .dtype
425- )
303+ else :
304+ assert False
426305
427306 return patches
428307
0 commit comments