Skip to content

Commit a265526

Browse files
anna-grimanna-grim
andauthored
bug: fixed missing fragments in merge datasets (#561)
Co-authored-by: anna-grim <[email protected]>
1 parent 858a6df commit a265526

File tree

4 files changed

+30
-10
lines changed

4 files changed

+30
-10
lines changed

src/neuron_proofreader/merge_proofreading/merge_dataloading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import numpy as np
1515
import pandas as pd
1616

17-
TEST_BRAIN = "685221"
17+
TEST_BRAIN = "653159"
1818

1919

2020
# --- Load Skeletons ---

src/neuron_proofreader/merge_proofreading/merge_datasets.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
geometry_util,
3333
img_util,
3434
ml_util,
35-
swc_util,
3635
util,
3736
)
3837

@@ -229,16 +228,19 @@ def remove_nonindexed_fragments(self, idxs):
229228
other sites are removed.
230229
"""
231230
# Remove other fragments
232-
for idx in [i for i in self.merge_sites_df.index if i not in idxs]:
231+
visited = set()
232+
for i in [i for i in self.merge_sites_df.index if i not in idxs]:
233233
# Extract site info
234-
brain_id = self.merge_sites_df["brain_id"][idx]
235-
xyz = self.merge_sites_df["xyz"][idx]
234+
brain_id = self.merge_sites_df["brain_id"][i]
235+
segment_id = self.merge_sites_df["segment_id"][i]
236+
pair = (brain_id, segment_id)
236237

237238
# Find fragment containing site
238-
dist, node = self.graphs[brain_id].kdtree.query(xyz)
239-
if dist < 20 and node in self.graphs[brain_id]:
240-
nodes = self.graphs[brain_id].get_connected_nodes(node)
239+
if pair not in visited:
240+
nodes = self.graphs[brain_id].get_nodes_with_segment_id(segment_id)
241241
self.graphs[brain_id].remove_nodes(nodes, False)
242+
visited.add(pair)
243+
242244
self.remove_empty_graphs()
243245

244246
# Relabel nodes
@@ -254,7 +256,7 @@ def remove_empty_graphs(self):
254256
Removes graphs without any nodes.
255257
"""
256258
for brain_id in list(self.graphs.keys()):
257-
if len(self.graphs[brain_id]) == 0:
259+
if len(self.graphs[brain_id].nodes) == 0:
258260
del self.graphs[brain_id]
259261

260262
# --- Getters ---

src/neuron_proofreader/skeleton_graph.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,24 @@ def get_nodes_with_component_id(self, component_id):
376376
"""
377377
return set(np.where(self.node_component_id == component_id)[0])
378378

379+
def get_nodes_with_segment_id(self, segment_id):
380+
nodes = set()
381+
query_id = f"{segment_id}."
382+
for swc_id in self.get_swc_ids():
383+
segment_id = int(swc_id.replace(".0", ""))
384+
if segment_id == query_id:
385+
component_id = self.get_component_id_from_swc_id(swc_id)
386+
nodes = nodes.union(
387+
self.get_nodes_with_component_id(component_id)
388+
)
389+
return nodes
390+
391+
def get_component_id_from_swc_id(self, query_swc_id):
392+
for component_id, swc_id in self.component_id_to_swc_id.items():
393+
if query_swc_id == swc_id:
394+
return component_id
395+
raise ValueError(f"SWC ID={query_swc_id} not found")
396+
379397
def get_rooted_subgraph(self, root, radius):
380398
"""
381399
Gets a rooted subgraph with the given radius (in microns).

src/neuron_proofreader/utils/geometry_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from collections import defaultdict
12-
from scipy.interpolate import splprep, splev, UnivariateSpline
12+
from scipy.interpolate import UnivariateSpline
1313
from scipy.linalg import svd
1414
from scipy.spatial import distance
1515
from tqdm import tqdm

0 commit comments

Comments
 (0)