@@ -103,7 +103,7 @@ def __init__(
103103 self .anisotropy = anisotropy
104104 self .batch_size = batch_size
105105 self .prefetch = prefetch
106- self .traversal_step = traversal_step # not implemented
106+ self .traversal_step = traversal_step
107107
108108 # Image reader
109109 self .img_reader = img_util .init_reader (img_path )
@@ -187,6 +187,7 @@ def _generate_batch_metadata_for_component(self, root):
187187 # Check if starting new batch
188188 if len (patch_centers ) == 0 :
189189 root = i
190+ last_node = i
190191 node_ids .append (i )
191192 patch_centers .append (self .get_voxel (i ))
192193 visited .add (i )
@@ -201,11 +202,15 @@ def _generate_batch_metadata_for_component(self, root):
201202
202203 # Visit j
203204 if j not in visited :
204- node_ids .append (j )
205- patch_centers .append (self .get_voxel (j ))
206205 visited .add (j )
207- if len (patch_centers ) == 1 :
208- root = j
206+ is_next = self .graph .dist (last_node , j ) >= self .traversal_step
207+ is_branching = self .graph .degree [j ] >= 3
208+ if is_next or is_branching :
209+ last_node = j
210+ node_ids .append (j )
211+ patch_centers .append (self .get_voxel (j ))
212+ if len (patch_centers ) == 1 :
213+ root = j
209214
210215 # Yield any remaining nodes after the loop
211216 if patch_centers :
@@ -216,6 +221,10 @@ def get_batch(self, superchunk, patch_centers, node_ids):
216221 for i , center in enumerate (patch_centers ):
217222 s = img_util .get_slices (center , self .patch_shape )
218223 batch [i , 0 , ...] = superchunk [s ]
224+ superchunk [tuple (center )] = 1.5
225+ from tifffile import imwrite
226+ imwrite ("superchunk.tif" , superchunk )
227+ print ("superchunk.shape:" , superchunk .shape )
219228 return node_ids , torch .tensor (batch , dtype = torch .float )
220229
221230 # --- Helpers ---
0 commit comments