Skip to content

Commit 02de6b0

Browse files
anna-grim
and authored
Refactor more augmentation (#557)
* refactor: more aggressive image augmentation * refactor: more aggressive image augmentation --------- Co-authored-by: anna-grim <[email protected]>
1 parent 574b6f1 commit 02de6b0

File tree

3 files changed

+12
-35
lines changed

3 files changed

+12
-35
lines changed

src/neuron_proofreader/machine_learning/augmentation.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,11 @@ def __call__(self, patches):
8484
patches : numpy.ndarray
8585
Image with the shape (2, H, W, D), where "patches[0, ...]" is from
8686
the input image and "patches[1, ...]" is from the segmentation.
87-
88-
Returns
89-
-------
90-
patches : numpy.ndarray
91-
Flipped 3D image and segmentation patch.
9287
"""
9388
for axis in self.axes:
9489
if random.random() > 0.5:
9590
patches[0, ...] = np.flip(patches[0, ...], axis=axis)
9691
patches[1, ...] = np.flip(patches[1, ...], axis=axis)
97-
return patches
9892

9993

10094
class RandomRotation3D:
@@ -125,18 +119,12 @@ def __call__(self, patches):
125119
patches : numpy.ndarray
126120
Image with the shape (2, H, W, D), where "patches[0, ...]" is from
127121
the input image and "patches[1, ...]" is from the segmentation.
128-
129-
Returns
130-
-------
131-
patches : numpy.ndarray
132-
Rotated 3D image and segmentation patch.
133122
"""
134123
for axes in self.axes:
135124
if random.random() < 0.5:
136125
angle = random.uniform(*self.angles)
137126
patches[0, ...] = rotate3d(patches[0, ...], angle, axes)
138127
patches[1, ...] = rotate3d(patches[1, ...], angle, axes, True)
139-
return patches
140128

141129

142130
class RandomScale3D:
@@ -216,30 +204,25 @@ def __call__(self, img_patch):
216204
----------
217205
img_patch : numpy.ndarray
218206
Image to which contrast will be added.
219-
220-
Returns
221-
-------
222-
numpy.ndarray
223-
Contrasted 3D image.
224207
"""
225208
factor = random.uniform(*self.factor_range)
226-
return np.clip(img_patch * factor, 0, 1)
209+
img_patch = np.clip(img_patch * factor, 0, 1)
227210

228211

229212
class RandomNoise3D:
230213
"""
231214
Adds random Gaussian noise to a 3D image.
232215
"""
233216

234-
def __init__(self, max_std=0.2):
217+
def __init__(self, max_std=0.3):
235218
"""
236219
Initializes a RandomNoise3D transformer.
237220
238221
Parameters
239222
----------
240223
max_std : float, optional
241224
Maximum standard deviation of the Gaussian noise distribution.
242-
Default is 0.16.
225+
Default is 0.3.
243226
"""
244227
self.max_std = max_std
245228

@@ -251,16 +234,10 @@ def __call__(self, img_patch):
251234
----------
252235
img_patch : np.ndarray
253236
Image to which noise will be added.
254-
255-
Returns
256-
-------
257-
img_patch : numpy.ndarray
258-
Noisy 3D image.
259237
"""
260238
std = self.max_std * random.random()
261-
noise = np.random.uniform(-std, std, img_patch.shape)
262-
img_patch += noise
263-
return np.clip(img_patch, 0, 1)
239+
img_patch += np.random.uniform(-std, std, img_patch.shape)
240+
img_patch = np.clip(img_patch, 0, 1)
264241

265242

266243
# --- Helpers ---

src/neuron_proofreader/merge_proofreading/merge_datasets.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,10 @@ def get_img_patch(self, brain_id, center):
446446
-------
447447
img_patch : numpy.ndarray
448448
Extracted image patch, which has been normalized and clipped to a
449-
maximum value of 300.
449+
maximum value of 400.
450450
"""
451451
img_patch = self.img_readers[brain_id].read(center, self.patch_shape)
452-
img_patch = img_util.normalize(np.minimum(img_patch, 300))
452+
img_patch = img_util.normalize(np.minimum(img_patch, 400))
453453
return img_patch
454454

455455
def get_segment_mask(self, subgraph):
@@ -953,11 +953,11 @@ def _load_multimodal_batch(self, batch_idxs):
953953
# Store results
954954
patches = np.zeros((len(batch_idxs),) + self.patches_shape)
955955
labels = np.zeros((len(batch_idxs), 1))
956-
point_clouds = np.zeros((len(batch_idxs), 3, 3600))
956+
point_clouds = np.zeros((len(batch_idxs), 3600, 3))
957957
for thread in as_completed(pending.keys()):
958958
i = pending.pop(thread)
959959
patches[i], subgraph, labels[i] = thread.result()
960-
point_clouds[i] = subgraph_to_point_cloud(subgraph) #.T
960+
point_clouds[i] = subgraph_to_point_cloud(subgraph).T
961961

962962
# Set batch dictionary
963963
batch = ml_util.TensorDict(

src/neuron_proofreader/merge_proofreading/merge_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def get_batch(self, img, offset, patch_centers, nodes):
319319
batch = np.empty((len(patch_centers), 2,) + self.patch_shape)
320320
for i, center in enumerate(patch_centers):
321321
s = img_util.get_slices(center, self.patch_shape)
322-
batch[i, 0, ...] = img_util.normalize(np.minimum(img[s], 300))
322+
batch[i, 0, ...] = img_util.normalize(np.minimum(img[s], 400))
323323
batch[i, 1, ...] = label_mask[s]
324324
return nodes, torch.tensor(batch, dtype=torch.float)
325325

@@ -331,10 +331,10 @@ def get_multimodal_batch(self, img, offset, patch_centers, nodes):
331331

332332
# Populate batch array
333333
patches = np.empty((batch_size, 2,) + self.patch_shape)
334-
point_clouds = np.empty((batch_size, 3, 3200), dtype=np.float32)
334+
point_clouds = np.empty((batch_size, 3, 3600), dtype=np.float32)
335335
for i, (center, node) in enumerate(zip(patch_centers, nodes)):
336336
s = img_util.get_slices(center, self.patch_shape)
337-
patches[i, 0, ...] = img_util.normalize(np.minimum(img[s], 300))
337+
patches[i, 0, ...] = img_util.normalize(np.minimum(img[s], 400))
338338
patches[i, 1, ...] = label_mask[s]
339339

340340
subgraph = self.graph.get_rooted_subgraph(node, self.subgraph_radius)

0 commit comments

Comments (0)