Skip to content

Commit 0c3a1ec

Browse files
anna-grimanna-grim
andauthored
Bug val dataset (#558)
* bug: val dataset length * bug: save mistake mips multimodal * bug: query error * refactor: mkdir * refactor: updated merge inference --------- Co-authored-by: anna-grim <[email protected]>
1 parent 02de6b0 commit 0c3a1ec

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

src/neuron_proofreader/merge_proofreading/merge_inference.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import numpy as np
2020
import torch
2121

22-
from neuron_proofreader import exp
22+
from neuron_proofreader.machine_learning.point_cloud_models import (
23+
subgraph_to_point_cloud,
24+
)
2325
from neuron_proofreader.utils import img_util, ml_util, util
2426

2527

@@ -35,6 +37,7 @@ def __init__(
3537
anisotropy=(1.0, 1.0, 1.0),
3638
batch_size=32,
3739
device="cuda",
40+
is_multimodal=False,
3841
min_size=0,
3942
prefetch=128,
4043
remove_detected_sites=False,
@@ -62,6 +65,7 @@ def __init__(
6265
patch_shape,
6366
anisotropy=anisotropy,
6467
batch_size=batch_size,
68+
is_multimodal=is_multimodal,
6569
min_size=min_size,
6670
prefetch=prefetch,
6771
step_size=step_size,
@@ -173,6 +177,7 @@ def __init__(
173177
patch_shape,
174178
anisotropy=(1.0, 1.0, 1.0),
175179
batch_size=16,
180+
is_multimodal=False,
176181
min_size=0,
177182
prefetch=128,
178183
step_size=10,
@@ -186,14 +191,15 @@ def __init__(
186191
self.batch_size = batch_size
187192
self.distance_traversed = 0
188193
self.graph = graph
194+
self.is_multimodal = is_multimodal
189195
self.min_size = min_size
190196
self.patch_shape = patch_shape
191197
self.prefetch = prefetch
192198
self.step_size = step_size
193199
self.subgraph_radius = subgraph_radius
194200

195201
# Image reader
196-
self.img_reader = img_util.init_reader(img_path)
202+
self.img_reader = img_util.TensorStoreReader(img_path)
197203

198204
# --- Core routines ---
199205
def __iter__(self):
@@ -222,7 +228,14 @@ def submit_thread():
222228
# Process completed thread
223229
nodes, patch_centers = pending.pop(thread)
224230
img, offset = thread.result()
225-
yield self.get_multimodal_batch(img, offset, patch_centers, nodes)
231+
if self.is_multimodal:
232+
yield self.get_multimodal_batch(
233+
img, offset, patch_centers, nodes
234+
)
235+
else:
236+
yield self.get_batch(
237+
img, offset, patch_centers, nodes
238+
)
226239

227240
# Continue submitting threads
228241
submit_thread()
@@ -338,10 +351,10 @@ def get_multimodal_batch(self, img, offset, patch_centers, nodes):
338351
patches[i, 1, ...] = label_mask[s]
339352

340353
subgraph = self.graph.get_rooted_subgraph(node, self.subgraph_radius)
341-
point_clouds[i] = exp.subgraph_to_point_cloud(subgraph)
354+
point_clouds[i] = subgraph_to_point_cloud(subgraph)
342355

343356
# Compile batch dictionary
344-
batch = exp.TensorDict({
357+
batch = ml_util.TensorDict({
345358
"img": ml_util.to_tensor(patches),
346359
"point_cloud": ml_util.to_tensor(point_clouds)
347360
})

src/neuron_proofreader/utils/util.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,15 @@ def listdir(path, extension=None):
4040
4141
Returns
4242
-------
43-
List[str]
43+
filenames : List[str]
4444
Filenames in directory with extension "extension" if provided.
4545
Otherwise, list of all files in directory.
4646
"""
47-
if extension is None:
48-
return [f for f in os.listdir(path)]
47+
filenames = [f for f in os.listdir(path) if not f.startswith(".")]
48+
if extension:
49+
return [f for f in filenames if f.endswith(extension)]
4950
else:
50-
return [f for f in os.listdir(path) if f.endswith(extension)]
51+
return filenames
5152

5253

5354
def list_files_in_zip(zip_content):
@@ -138,8 +139,7 @@ def mkdir(path, delete=False):
138139
if delete:
139140
rmdir(path)
140141

141-
if not os.path.exists(path):
142-
os.mkdir(path)
142+
os.makedirs(path, exist_ok=True)
143143

144144

145145
def rmdir(path):

0 commit comments

Comments
 (0)