Skip to content

Commit 2a130df

Browse files
anna-grimanna-grim
andauthored
Feat merge detection (#441)
* refactor: memory issue in spline fit * bug: branch resampling: * refactor: simplified irreducible extractoin * graph-level node radii and xyz * refactor: removed node swc_id attr * refactor: minor upds --------- Co-authored-by: anna-grim <[email protected]>
1 parent e5a08f1 commit 2a130df

File tree

5 files changed

+49
-74
lines changed

5 files changed

+49
-74
lines changed

src/deep_neurographs/fragments_graph.py

Lines changed: 9 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
2727
3. Add Irreducibles
2828
to do...
29-
3029
"""
3130

3231
from concurrent.futures import as_completed, ThreadPoolExecutor
@@ -41,10 +40,7 @@
4140

4241
from deep_neurographs import proposal_generation
4342
from deep_neurographs.utils import (
44-
geometry_util as geometry,
45-
graph_util as gutil,
46-
swc_util,
47-
util,
43+
geometry_util as geometry, graph_util as gutil, util,
4844
)
4945
from deep_neurographs.machine_learning import groundtruth_generation
5046

@@ -106,7 +102,6 @@ def __init__(
106102
Returns
107103
-------
108104
None
109-
110105
"""
111106
# Call parent class
112107
super(FragmentsGraph, self).__init__()
@@ -153,7 +148,6 @@ def load_fragments(self, fragments_pointer):
153148
Returns
154149
-------
155150
None
156-
157151
"""
158152
# Extract irreducible components from SWC files
159153
irreducibles = self.graph_loader.run(fragments_pointer)
@@ -188,7 +182,6 @@ def add_irreducibles(self, irreducibles, component_id):
188182
Returns
189183
-------
190184
None
191-
192185
"""
193186
# SWC ID
194187
self.component_id_to_swc_id[component_id] = irreducibles["swc_id"]
@@ -203,7 +196,7 @@ def add_irreducibles(self, irreducibles, component_id):
203196

204197
def _add_nodes(self, node_dict, component_id):
205198
"""
206-
Adds nodes to the graph from a dictionary of node attributes and
199+
Adds nodes to the graph from a dictionary of node attributes and
207200
returns a mapping from original node IDs to the new graph node IDs.
208201
209202
Parameters
@@ -247,7 +240,6 @@ def _add_edge(self, edge, attrs):
247240
Returns
248241
-------
249242
None
250-
251243
"""
252244
i, j = tuple(edge)
253245
self.add_edge(i, j, radius=attrs["radius"], xyz=attrs["xyz"])
@@ -293,7 +285,6 @@ def init_kdtree(self, node_type=None):
293285
Returns
294286
-------
295287
None
296-
297288
"""
298289
if node_type == "leaf":
299290
self.leaf_kdtree = self.get_kdtree(node_type=node_type)
@@ -315,7 +306,6 @@ def get_kdtree(self, node_type=None):
315306
-------
316307
KDTree
317308
KD-Tree generated from xyz coordinates across all nodes and edges.
318-
319309
"""
320310
# Get xyz coordinates
321311
if node_type == "leaf":
@@ -344,7 +334,6 @@ def query_kdtree(self, xyz, d, node_type=None):
344334
generator[tuple]
345335
Generator that generates the xyz coordinates cooresponding to all
346336
nodes within a distance of "d" from "xyz".
347-
348337
"""
349338
if node_type == "leaf":
350339
return geometry.query_ball(self.leaf_kdtree, xyz, d)
@@ -387,7 +376,6 @@ def generate_proposals(
387376
Returns
388377
-------
389378
None
390-
391379
"""
392380
# Initializations
393381
self.reset_proposals()
@@ -425,7 +413,6 @@ def reset_proposals(self):
425413
Returns
426414
-------
427415
None
428-
429416
"""
430417
self.proposals = set()
431418
for i in self.nodes:
@@ -461,7 +448,6 @@ def add_proposal(self, i, j):
461448
Returns
462449
-------
463450
None
464-
465451
"""
466452
proposal = frozenset({i, j})
467453
self.nodes[i]["proposals"].add(j)
@@ -475,12 +461,11 @@ def remove_proposal(self, proposal):
475461
Parameters
476462
----------
477463
proposal : Frozenset[int]
478-
Pair of node ids corresponding to a proposal.
464+
Pair of node IDs corresponding to a proposal.
479465
480466
Returns
481467
-------
482468
None
483-
484469
"""
485470
i, j = tuple(proposal)
486471
self.nodes[i]["proposals"].remove(j)
@@ -495,14 +480,13 @@ def is_single_proposal(self, proposal):
495480
Parameters
496481
----------
497482
proposal : Frozenset[int]
498-
Pair of node ids corresponding to a proposal.
483+
Pair of node IDs corresponding to a proposal.
499484
500485
Returns
501486
-------
502487
bool
503488
Indiciation of "proposal" is the only proposal generated for the
504-
corresponding nodes.
505-
489+
corresponding nodes.
506490
"""
507491
i, j = tuple(proposal)
508492
single_i = len(self.nodes[i]["proposals"]) == 1
@@ -529,7 +513,6 @@ def is_valid_proposal(self, leaf, i, complex_bool):
529513
-------
530514
bool
531515
Indication of whether proposal is valid.
532-
533516
"""
534517
if i is not None:
535518
skip_soma = self.is_soma(i) and self.is_soma(leaf)
@@ -549,9 +532,8 @@ def list_proposals(self):
549532
550533
Returns
551534
-------
552-
list
535+
List[Frozenset[int]]
553536
Proposals.
554-
555537
"""
556538
return list(self.proposals)
557539

@@ -568,13 +550,12 @@ def n_proposals(self):
568550
-------
569551
int
570552
Number of proposals in the graph.
571-
572553
"""
573554
return len(self.proposals)
574555

575556
def is_simple(self, proposal):
576557
"""
577-
Determines whether both nodes in a proposal are leafs.
558+
Checks if both nodes in a proposal are leafs.
578559
579560
Parameters
580561
----------
@@ -585,7 +566,6 @@ def is_simple(self, proposal):
585566
-------
586567
bool
587568
Indication of whether both nodes in a proposal are leafs.
588-
589569
"""
590570
i, j = tuple(proposal)
591571
return True if self.degree[i] == 1 and self.degree[j] == 1 else False
@@ -618,7 +598,6 @@ def proposal_attr(self, proposal, key):
618598
-------
619599
numpy.ndarray
620600
Attributes of nodes in "proposal".
621-
622601
"""
623602
i, j = tuple(proposal)
624603
if key == "xyz":
@@ -735,7 +714,6 @@ def n_nearby_leafs(self, proposal, radius):
735714
int
736715
Number of nearby leaf nodes within a specified radius from
737716
a proposal.
738-
739717
"""
740718
xyz = self.proposal_midpoint(proposal)
741719
return len(self.query_kdtree(xyz, radius, "leaf")) - 1
@@ -756,7 +734,6 @@ def dist(self, i, j):
756734
-------
757735
float
758736
Euclidean distance between nodes "i" and "j".
759-
760737
"""
761738
return geometry.dist(self.node_xyz[i], self.node_xyz[j])
762739

@@ -795,7 +772,6 @@ def edge_attr(self, i, key="xyz", ignore=False):
795772
List[numpy.ndarray]
796773
Edge attribute specified by "key" for all edges connected to the
797774
given node.
798-
799775
"""
800776
attrs = list()
801777
for j in self.neighbors(i):
@@ -854,7 +830,6 @@ def get_leafs(self):
854830
-------
855831
List[int]
856832
Leaf nodes in graph.
857-
858833
"""
859834
return [i for i in self.nodes if self.degree[i] == 1]
860835

@@ -864,8 +839,8 @@ def is_soma(self, i):
864839
865840
Parameters
866841
----------
867-
node_or_swc : str
868-
node or swc id to be checked.
842+
i : str
843+
Node ID.
869844
870845
Returns
871846
-------
@@ -1020,7 +995,3 @@ def avg_radius(radii_list):
1020995
end = max(min(16, len(radii) - 1), 1)
1021996
avg += np.mean(radii[0:end]) / len(radii_list)
1022997
return avg
1023-
1024-
1025-
def reformat(segment_id):
1026-
return int(segment_id.split(".")[0])

src/deep_neurographs/proposal_generation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def trim_endpoints_at_proposal(fragments_graph, proposal, max_length):
313313
trim_to_idx(fragments_graph, i, idx_i)
314314
trim_to_idx(fragments_graph, j, idx_j)
315315

316+
316317
def find_closest_pair(pts1, pts2):
317318
best_dist, best_idxs = np.inf, (0, 0)
318319
i, length1 = -1, 0
@@ -412,7 +413,7 @@ def compute_dot(branch1, branch2, idx1, idx2):
412413
"""
413414
# Initializations
414415
midpoint = geometry.midpoint(branch1[idx1], branch2[idx2])
415-
b1 = branch1 - midpoint
416+
b1 = branch1 - midpoint
416417
b2 = branch2 - midpoint
417418

418419
# Main

src/deep_neurographs/utils/geometry_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import networkx as nx
1616
import numpy as np
1717

18-
from deep_neurographs.utils import graph_util as gutil, img_util
18+
from deep_neurographs.utils import img_util
1919

2020

2121
# --- Directionals ---

src/deep_neurographs/utils/graph_util.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from collections import defaultdict, deque
2828
from concurrent.futures import (
29-
as_completed, ProcessPoolExecutor, ThreadPoolExecutor,
29+
as_completed, ProcessPoolExecutor, ThreadPoolExecutor
3030
)
3131
from random import sample
3232
from scipy.spatial import KDTree
@@ -219,32 +219,36 @@ def extract(self, swc_dict):
219219
Dictionary that each contains the components of an irreducible
220220
subgraph.
221221
"""
222-
graph = self.to_graph(swc_dict)
223-
irreducibles = deque()
224-
high_risk_cnt = 0
225-
if self.satifies_path_length_condition(graph):
226-
# Check for soma merges
227-
if len(graph.graph["soma_nodes"]) > 1:
228-
self.remove_soma_merges(graph)
229-
230-
# Check for high risk merges
231-
if self.remove_high_risk_merges_bool:
232-
high_risk_cnt = self.remove_high_risk_merges(graph)
222+
try:
223+
graph = self.to_graph(swc_dict)
224+
irreducibles = deque()
225+
high_risk_cnt = 0
226+
if self.satifies_path_length_condition(graph):
227+
# Check for soma merges
228+
if len(graph.graph["soma_nodes"]) > 1:
229+
self.remove_soma_merges(graph)
230+
231+
# Check for high risk merges
232+
if self.remove_high_risk_merges_bool:
233+
high_risk_cnt = self.remove_high_risk_merges(graph)
234+
235+
# Extract irreducibles
236+
i = 0
237+
leafs = set(get_leafs(graph))
238+
while leafs:
239+
# Extract for connected component
240+
leaf = util.sample_once(leafs)
241+
irreducibles_i, visited = self.get_irreducibles(graph, leaf)
242+
leafs -= visited
233243

234-
# Extract irreducibles
235-
i = 0
236-
leafs = set(get_leafs(graph))
237-
while leafs:
238-
# Extract for connected component
239-
leaf = util.sample_once(leafs)
240-
irreducibles_i, visited = self.get_irreducibles(graph, leaf)
241-
leafs -= visited
242-
243-
# Store results
244-
if irreducibles_i:
245-
irreducibles_i["swc_id"] = f"{graph.graph['segment_id']}.{i}"
246-
irreducibles.append(irreducibles_i)
247-
i += 1
244+
# Store results
245+
if irreducibles_i:
246+
swc_id = f"{graph.graph['segment_id']}.{i}"
247+
irreducibles_i["swc_id"] = swc_id
248+
irreducibles.append(irreducibles_i)
249+
i += 1
250+
except Exception as e:
251+
print("Exception:", e)
248252
return irreducibles, high_risk_cnt
249253

250254
def get_irreducibles(self, graph, source):
@@ -345,13 +349,13 @@ def remove_high_risk_merges(self, graph, max_dist=7):
345349
"""
346350
high_risk_cnt = 0
347351
nodes = set()
348-
branchings = [i for i in graph.nodes if graph.degree[i] > 2]
352+
branchings = set([i for i in graph.nodes if graph.degree[i] > 2])
349353
while branchings:
350354
# Initializations
351355
root = branchings.pop()
352-
hit_branching = False
356+
hit_branchings = set()
353357
queue = [(root, 0)]
354-
visited = set({root})
358+
visited = set(queue)
355359

356360
# Check if close to soma
357361
soma_dist = self.dist_from_soma(graph.graph["xyz"][root])
@@ -363,7 +367,7 @@ def remove_high_risk_merges(self, graph, max_dist=7):
363367
# Visit node
364368
i, dist_i = queue.pop()
365369
if graph.degree[i] > 2 and i != root:
366-
hit_branching = True
370+
hit_branchings.add(i)
367371

368372
# Update queue
369373
for j in graph.neighbors(i):
@@ -373,10 +377,10 @@ def remove_high_risk_merges(self, graph, max_dist=7):
373377
visited.add(j)
374378

375379
# Determine whether to remove visited nodes
376-
if hit_branching or graph.degree(root) > 3:
380+
if hit_branchings or graph.degree(root) > 3:
377381
nodes = nodes.union(visited)
378-
high_risk_cnt += 0.5
379-
382+
high_risk_cnt += 1
383+
branchings -= hit_branchings
380384
graph.remove_nodes_from(nodes)
381385
return high_risk_cnt
382386

src/deep_neurographs/visualization.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from plotly.subplots import make_subplots
1212

1313
import networkx as nx
14-
import numpy as np
1514
import plotly.colors as plc
1615
import plotly.graph_objects as go
1716

0 commit comments

Comments
 (0)