Skip to content

Commit f001b15

Browse files
committed
NaFlex random erasing performance improvements: the per-patch Python loops were slow, so erasing is now vectorized. Remove subregion mode, as it is not going to be worth the complexity.
1 parent 7624389 commit f001b15

File tree

1 file changed

+48
-169
lines changed

1 file changed

+48
-169
lines changed

timm/data/naflex_random_erasing.py

Lines changed: 48 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,15 @@ class PatchRandomErasing:
1111
1212
Supports two modes:
1313
1. 'patch': Simple mode that erases randomly selected valid patches
14-
2. 'region': Erases spatial regions at patch granularity
15-
3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
16-
partially erasing patches that are on the boundary of the erased region
14+
2. 'region': Erases rectangular regions at patch granularity
1715
1816
Args:
1917
erase_prob: Probability that the Random Erasing operation will be performed.
2018
patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
2119
min_area: Minimum percentage of valid patches/area to erase.
2220
max_area: Maximum percentage of valid patches/area to erase.
23-
min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
24-
max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
21+
min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
22+
max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode).
2523
mode: Patch content mode, one of 'const', 'rand', or 'pixel'
2624
'const' - erase patch is constant color of 0 for all channels
2725
'rand' - erase patch has same random (normal) value across all elements
@@ -45,7 +43,6 @@ def __init__(
4543
mode: str = 'const',
4644
value: float = 0.,
4745
spatial_mode: str = 'region',
48-
patch_size: Optional[Union[int, Tuple[int, int]]] = 16,
4946
num_splits: int = 0,
5047
device: Union[str, torch.device] = 'cuda',
5148
):
@@ -66,14 +63,13 @@ def __init__(
6663

6764
# Strategy mode
6865
self.spatial_mode = spatial_mode
69-
70-
# Patch size (needed for subregion mode)
71-
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
66+
assert self.spatial_mode in ('patch', 'region')
7267

7368
# Value generation mode flags
7469
self.erase_mode = mode.lower()
7570
assert self.erase_mode in ('rand', 'pixel', 'const')
7671
self.const_value = value
72+
self.unique_noise_per_patch = True
7773

7874
def _get_values(
7975
self,
@@ -156,27 +152,27 @@ def _erase_patches(
156152
return
157153

158154
# Get indices of valid patches
159-
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
160-
if not valid_indices:
161-
# Skip if no valid patches
155+
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0]
156+
num_valid = len(valid_indices)
157+
if num_valid == 0:
162158
return
163159

164-
num_valid = len(valid_indices)
165160
count = random.randint(self.min_count, self.max_count)
166161
# Determine how many valid patches to erase from RE min/max count and area args
167-
max_erase = max(1, int(num_valid * count * self.max_area))
162+
max_erase = min(num_valid, max(1, int(num_valid * count * self.max_area)))
168163
min_erase = max(1, int(num_valid * count * self.min_area))
169164
num_erase = random.randint(min_erase, max_erase)
170165

171166
# Randomly select valid patches to erase
172-
indices_to_erase = random.sample(valid_indices, min(num_erase, num_valid))
167+
erase_idx = valid_indices[torch.randperm(num_valid, device=patches.device)[:num_erase]]
173168

174-
random_value = None
175-
if self.erase_mode == 'rand':
176-
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
169+
if self.unique_noise_per_patch and self.erase_mode == 'pixel':
170+
# generate unique noise for the whole selection of patches
171+
fill_shape = (num_erase,) + patch_shape
172+
else:
173+
fill_shape = patch_shape
177174

178-
for idx in indices_to_erase:
179-
patches[idx].copy_(self._get_values(patch_shape, dtype=dtype, value=random_value))
175+
patches[erase_idx] = self._get_values(fill_shape, dtype=dtype)
180176

181177
def _erase_region(
182178
self,
@@ -195,20 +191,14 @@ def _erase_region(
195191
return
196192

197193
# Determine grid dimensions from coordinates
198-
if patch_valid is not None:
199-
valid_coord = patch_coord[patch_valid]
200-
if len(valid_coord) == 0:
201-
return # No valid patches
202-
max_y = valid_coord[:, 0].max().item() + 1
203-
max_x = valid_coord[:, 1].max().item() + 1
204-
else:
205-
max_y = patch_coord[:, 0].max().item() + 1
206-
max_x = patch_coord[:, 1].max().item() + 1
207-
194+
valid_coord = patch_coord[patch_valid]
195+
if len(valid_coord) == 0:
196+
return # No valid patches
197+
max_y = valid_coord[:, 0].max().item() + 1
198+
max_x = valid_coord[:, 1].max().item() + 1
208199
grid_h, grid_w = max_y, max_x
209-
210-
# Calculate total area
211200
total_area = grid_h * grid_w
201+
ys, xs = patch_coord[:, 0], patch_coord[:, 1]
212202

213203
count = random.randint(self.min_count, self.max_count)
214204
for _ in range(count):
@@ -222,132 +212,33 @@ def _erase_region(
222212
h = int(round(math.sqrt(target_area * aspect_ratio)))
223213
w = int(round(math.sqrt(target_area / aspect_ratio)))
224214

225-
# Ensure region fits within grid
226-
if w <= grid_w and h <= grid_h:
227-
# Select random top-left corner
228-
top = random.randint(0, grid_h - h)
229-
left = random.randint(0, grid_w - w)
230-
231-
# Define region bounds
232-
bottom = top + h
233-
right = left + w
234-
235-
# Create a single random value for all affected patches if using 'rand' mode
236-
if self.erase_mode == 'rand':
237-
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
238-
else:
239-
random_value = None
240-
241-
# Find and erase all patches that fall within the region
242-
for i in range(len(patches)):
243-
if patch_valid is None or patch_valid[i]:
244-
y, x = patch_coord[i]
245-
if top <= y < bottom and left <= x < right:
246-
patches[i] = self._get_values(patch_shape, dtype=dtype, value=random_value)
247-
248-
# Successfully applied erasing, exit the loop
249-
break
250-
251-
def _erase_subregion(
252-
self,
253-
patches: torch.Tensor,
254-
patch_coord: torch.Tensor,
255-
patch_valid: torch.Tensor,
256-
patch_shape: torch.Size,
257-
patch_size: Tuple[int, int],
258-
dtype: torch.dtype = torch.float32,
259-
):
260-
"""Apply erasing by selecting rectangular regions ignoring patch boundaries.
215+
if h > grid_h or w > grid_w:
216+
continue # try again
261217

262-
Matches our original RandomErasing implementation. Erases spatially contiguous rectangular
263-
regions that are not aligned to patches (erase regions boundaries cut within patches).
218+
# Calculate region patch bounds
219+
top = random.randint(0, grid_h - h)
220+
left = random.randint(0, grid_w - w)
221+
bottom, right = top + h, left + w
264222

265-
FIXME complexity probably not worth it, may remove.
266-
"""
267-
if random.random() > self.erase_prob:
268-
return
269-
270-
# Get patch dimensions
271-
patch_h, patch_w = patch_size
272-
channels = patch_shape[-1]
273-
274-
# Determine grid dimensions in patch coordinates
275-
if patch_valid is not None:
276-
valid_coord = patch_coord[patch_valid]
277-
if len(valid_coord) == 0:
278-
return # No valid patches
279-
max_y = valid_coord[:, 0].max().item() + 1
280-
max_x = valid_coord[:, 1].max().item() + 1
281-
else:
282-
max_y = patch_coord[:, 0].max().item() + 1
283-
max_x = patch_coord[:, 1].max().item() + 1
284-
285-
grid_h, grid_w = max_y, max_x
286-
287-
# Calculate total area in pixel space
288-
total_area = (grid_h * patch_h) * (grid_w * patch_w)
223+
# Region test
224+
region_mask = (
225+
(ys >= top) & (ys < bottom) &
226+
(xs >= left) & (xs < right) &
227+
patch_valid
228+
)
229+
num_selected = int(region_mask.sum().item())
230+
if not num_selected:
231+
continue # no patch actually falls inside – try again
289232

290-
count = random.randint(self.min_count, self.max_count)
291-
for _ in range(count):
292-
# Try to select a valid region to erase (multiple attempts)
293-
for attempt in range(10):
294-
# Sample random area and aspect ratio
295-
target_area = random.uniform(self.min_area, self.max_area) * total_area
296-
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
233+
if self.unique_noise_per_patch and self.erase_mode == 'pixel':
234+
# generate unique noise for the whole region
235+
fill_shape = (num_selected,) + patch_shape
236+
else:
237+
fill_shape = patch_shape
297238

298-
# Calculate region height and width in pixel space
299-
pixel_h = int(round(math.sqrt(target_area * aspect_ratio)))
300-
pixel_w = int(round(math.sqrt(target_area / aspect_ratio)))
301-
302-
# Ensure region fits within total pixel grid
303-
if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h:
304-
# Select random top-left corner in pixel space
305-
pixel_top = random.randint(0, grid_h * patch_h - pixel_h)
306-
pixel_left = random.randint(0, grid_w * patch_w - pixel_w)
307-
308-
# Define region bounds in pixel space
309-
pixel_bottom = pixel_top + pixel_h
310-
pixel_right = pixel_left + pixel_w
311-
312-
# Create a single random value for the entire region if using 'rand' mode
313-
rand_value = None
314-
if self.erase_mode == 'rand':
315-
rand_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
316-
317-
# For each valid patch, determine if and how it overlaps with the erase region
318-
for i in range(len(patches)):
319-
if patch_valid is None or patch_valid[i]:
320-
# Convert patch coordinates to pixel space (top-left corner)
321-
y, x = patch_coord[i]
322-
patch_pixel_top = y * patch_h
323-
patch_pixel_left = x * patch_w
324-
patch_pixel_bottom = patch_pixel_top + patch_h
325-
patch_pixel_right = patch_pixel_left + patch_w
326-
327-
# Check if this patch overlaps with the erase region
328-
if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or
329-
patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom):
330-
331-
# Calculate the overlap region in patch-local coordinates
332-
local_top = max(0, pixel_top - patch_pixel_top)
333-
local_left = max(0, pixel_left - patch_pixel_left)
334-
local_bottom = min(patch_h, pixel_bottom - patch_pixel_top)
335-
local_right = min(patch_w, pixel_right - patch_pixel_left)
336-
337-
# Reshape the patch to [patch_h, patch_w, chans]
338-
patch_data = patches[i].reshape(patch_h, patch_w, channels)
339-
340-
erase_shape = (local_bottom - local_top, local_right - local_left, channels)
341-
erase_value = self._get_values(erase_shape, dtype=dtype, value=rand_value)
342-
patch_data[local_top:local_bottom, local_left:local_right, :] = erase_value
343-
344-
# Flatten the patch back to [patch_h*patch_w, chans]
345-
if len(patch_shape) == 2:
346-
patch_data = patch_data.reshape(-1, channels)
347-
patches[i] = patch_data
348-
349-
# Successfully applied erasing, exit the loop
350-
break
239+
patches[region_mask] = self._get_values(fill_shape, dtype=dtype)
240+
# Successfully applied erasing, exit the loop
241+
break
351242

352243
def __call__(
353244
self,
@@ -369,18 +260,12 @@ def __call__(
369260
"""
370261
if patches.ndim == 4:
371262
batch_size, num_patches, patch_dim, channels = patches.shape
372-
if self.patch_size is not None:
373-
patch_size = self.patch_size
374-
else:
375-
patch_size = None
376263
elif patches.ndim == 5:
377264
batch_size, num_patches, patch_h, patch_w, channels = patches.shape
378-
patch_size = (patch_h, patch_w)
379265
else:
380266
assert False
381267
patch_shape = patches.shape[2:]
382268
# patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
383-
# patch_size ==> patch h, w (if available, must be avail for subregion mode)
384269

385270
# Create default valid mask if not provided
386271
if patch_valid is None:
@@ -399,6 +284,7 @@ def __call__(
399284
patch_valid[i],
400285
)
401286
elif self.spatial_mode == 'patch':
287+
# FIXME we could vectorize patch mode across batch, worth the effort?
402288
self._erase_patches(
403289
patches[i],
404290
patch_coord[i],
@@ -414,15 +300,8 @@ def __call__(
414300
patch_shape,
415301
patches.dtype
416302
)
417-
elif self.spatial_mode == 'subregion':
418-
self._erase_subregion(
419-
patches[i],
420-
patch_coord[i],
421-
patch_valid[i],
422-
patch_shape,
423-
patch_size,
424-
patches.dtype
425-
)
303+
else:
304+
assert False
426305

427306
return patches
428307

0 commit comments

Comments
 (0)