Skip to content

Commit fc20d04

Browse files
anna-grimanna-grim
andauthored
Feat merge inference (#465)
* feat: graph traversal and batch generation * refactor: threaded merge inference * feat: implemeneted traversal step size --------- Co-authored-by: anna-grim <[email protected]>
1 parent b7c6e30 commit fc20d04

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/deep_neurographs/merge_proofreading/inference.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
self.anisotropy = anisotropy
104104
self.batch_size = batch_size
105105
self.prefetch = prefetch
106-
self.traversal_step = traversal_step # not implemented
106+
self.traversal_step = traversal_step
107107

108108
# Image reader
109109
self.img_reader = img_util.init_reader(img_path)
@@ -187,6 +187,7 @@ def _generate_batch_metadata_for_component(self, root):
187187
# Check if starting new batch
188188
if len(patch_centers) == 0:
189189
root = i
190+
last_node = i
190191
node_ids.append(i)
191192
patch_centers.append(self.get_voxel(i))
192193
visited.add(i)
@@ -201,11 +202,15 @@ def _generate_batch_metadata_for_component(self, root):
201202

202203
# Visit j
203204
if j not in visited:
204-
node_ids.append(j)
205-
patch_centers.append(self.get_voxel(j))
206205
visited.add(j)
207-
if len(patch_centers) == 1:
208-
root = j
206+
is_next = self.graph.dist(last_node, j) >= self.traversal_step
207+
is_branching = self.graph.degree[j] >= 3
208+
if is_next or is_branching:
209+
last_node = j
210+
node_ids.append(j)
211+
patch_centers.append(self.get_voxel(j))
212+
if len(patch_centers) == 1:
213+
root = j
209214

210215
# Yield any remaining nodes after the loop
211216
if patch_centers:
@@ -216,6 +221,10 @@ def get_batch(self, superchunk, patch_centers, node_ids):
216221
for i, center in enumerate(patch_centers):
217222
s = img_util.get_slices(center, self.patch_shape)
218223
batch[i, 0, ...] = superchunk[s]
224+
superchunk[tuple(center)] = 1.5
225+
from tifffile import imwrite
226+
imwrite("superchunk.tif", superchunk)
227+
print("superchunk.shape:", superchunk.shape)
219228
return node_ids, torch.tensor(batch, dtype=torch.float)
220229

221230
# --- Helpers ---

0 commit comments

Comments
 (0)