Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 9 additions & 38 deletions src/deep_neurographs/fragments_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
3. Add Irreducibles
to do...
"""

from concurrent.futures import as_completed, ThreadPoolExecutor
Expand All @@ -41,10 +40,7 @@

from deep_neurographs import proposal_generation
from deep_neurographs.utils import (
geometry_util as geometry,
graph_util as gutil,
swc_util,
util,
geometry_util as geometry, graph_util as gutil, util,
)
from deep_neurographs.machine_learning import groundtruth_generation

Expand Down Expand Up @@ -106,7 +102,6 @@ def __init__(
Returns
-------
None
"""
# Call parent class
super(FragmentsGraph, self).__init__()
Expand Down Expand Up @@ -153,7 +148,6 @@ def load_fragments(self, fragments_pointer):
Returns
-------
None
"""
# Extract irreducible components from SWC files
irreducibles = self.graph_loader.run(fragments_pointer)
Expand Down Expand Up @@ -188,7 +182,6 @@ def add_irreducibles(self, irreducibles, component_id):
Returns
-------
None
"""
# SWC ID
self.component_id_to_swc_id[component_id] = irreducibles["swc_id"]
Expand All @@ -203,7 +196,7 @@ def add_irreducibles(self, irreducibles, component_id):

def _add_nodes(self, node_dict, component_id):
"""
Adds nodes to the graph from a dictionary of node attributes and
Adds nodes to the graph from a dictionary of node attributes and
returns a mapping from original node IDs to the new graph node IDs.
Parameters
Expand Down Expand Up @@ -247,7 +240,6 @@ def _add_edge(self, edge, attrs):
Returns
-------
None
"""
i, j = tuple(edge)
self.add_edge(i, j, radius=attrs["radius"], xyz=attrs["xyz"])
Expand Down Expand Up @@ -293,7 +285,6 @@ def init_kdtree(self, node_type=None):
Returns
-------
None
"""
if node_type == "leaf":
self.leaf_kdtree = self.get_kdtree(node_type=node_type)
Expand All @@ -315,7 +306,6 @@ def get_kdtree(self, node_type=None):
-------
KDTree
KD-Tree generated from xyz coordinates across all nodes and edges.
"""
# Get xyz coordinates
if node_type == "leaf":
Expand Down Expand Up @@ -344,7 +334,6 @@ def query_kdtree(self, xyz, d, node_type=None):
generator[tuple]
Generator that generates the xyz coordinates cooresponding to all
nodes within a distance of "d" from "xyz".
"""
if node_type == "leaf":
return geometry.query_ball(self.leaf_kdtree, xyz, d)
Expand Down Expand Up @@ -387,7 +376,6 @@ def generate_proposals(
Returns
-------
None
"""
# Initializations
self.reset_proposals()
Expand Down Expand Up @@ -425,7 +413,6 @@ def reset_proposals(self):
Returns
-------
None
"""
self.proposals = set()
for i in self.nodes:
Expand Down Expand Up @@ -461,7 +448,6 @@ def add_proposal(self, i, j):
Returns
-------
None
"""
proposal = frozenset({i, j})
self.nodes[i]["proposals"].add(j)
Expand All @@ -475,12 +461,11 @@ def remove_proposal(self, proposal):
Parameters
----------
proposal : Frozenset[int]
Pair of node ids corresponding to a proposal.
Pair of node IDs corresponding to a proposal.
Returns
-------
None
"""
i, j = tuple(proposal)
self.nodes[i]["proposals"].remove(j)
Expand All @@ -495,14 +480,13 @@ def is_single_proposal(self, proposal):
Parameters
----------
proposal : Frozenset[int]
Pair of node ids corresponding to a proposal.
Pair of node IDs corresponding to a proposal.
Returns
-------
bool
Indiciation of "proposal" is the only proposal generated for the
corresponding nodes.
corresponding nodes.
"""
i, j = tuple(proposal)
single_i = len(self.nodes[i]["proposals"]) == 1
Expand All @@ -529,7 +513,6 @@ def is_valid_proposal(self, leaf, i, complex_bool):
-------
bool
Indication of whether proposal is valid.
"""
if i is not None:
skip_soma = self.is_soma(i) and self.is_soma(leaf)
Expand All @@ -549,9 +532,8 @@ def list_proposals(self):
Returns
-------
list
List[Frozenset[int]]
Proposals.
"""
return list(self.proposals)

Expand All @@ -568,13 +550,12 @@ def n_proposals(self):
-------
int
Number of proposals in the graph.
"""
return len(self.proposals)

def is_simple(self, proposal):
"""
Determines whether both nodes in a proposal are leafs.
Checks if both nodes in a proposal are leafs.
Parameters
----------
Expand All @@ -585,7 +566,6 @@ def is_simple(self, proposal):
-------
bool
Indication of whether both nodes in a proposal are leafs.
"""
i, j = tuple(proposal)
return True if self.degree[i] == 1 and self.degree[j] == 1 else False
Expand Down Expand Up @@ -618,7 +598,6 @@ def proposal_attr(self, proposal, key):
-------
numpy.ndarray
Attributes of nodes in "proposal".
"""
i, j = tuple(proposal)
if key == "xyz":
Expand Down Expand Up @@ -735,7 +714,6 @@ def n_nearby_leafs(self, proposal, radius):
int
Number of nearby leaf nodes within a specified radius from
a proposal.
"""
xyz = self.proposal_midpoint(proposal)
return len(self.query_kdtree(xyz, radius, "leaf")) - 1
Expand All @@ -756,7 +734,6 @@ def dist(self, i, j):
-------
float
Euclidean distance between nodes "i" and "j".
"""
return geometry.dist(self.node_xyz[i], self.node_xyz[j])

Expand Down Expand Up @@ -795,7 +772,6 @@ def edge_attr(self, i, key="xyz", ignore=False):
List[numpy.ndarray]
Edge attribute specified by "key" for all edges connected to the
given node.
"""
attrs = list()
for j in self.neighbors(i):
Expand Down Expand Up @@ -854,7 +830,6 @@ def get_leafs(self):
-------
List[int]
Leaf nodes in graph.
"""
return [i for i in self.nodes if self.degree[i] == 1]

Expand All @@ -864,8 +839,8 @@ def is_soma(self, i):
Parameters
----------
node_or_swc : str
node or swc id to be checked.
i : str
Node ID.
Returns
-------
Expand Down Expand Up @@ -1020,7 +995,3 @@ def avg_radius(radii_list):
end = max(min(16, len(radii) - 1), 1)
avg += np.mean(radii[0:end]) / len(radii_list)
return avg


def reformat(segment_id):
return int(segment_id.split(".")[0])
3 changes: 2 additions & 1 deletion src/deep_neurographs/proposal_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def trim_endpoints_at_proposal(fragments_graph, proposal, max_length):
trim_to_idx(fragments_graph, i, idx_i)
trim_to_idx(fragments_graph, j, idx_j)


def find_closest_pair(pts1, pts2):
best_dist, best_idxs = np.inf, (0, 0)
i, length1 = -1, 0
Expand Down Expand Up @@ -412,7 +413,7 @@ def compute_dot(branch1, branch2, idx1, idx2):
"""
# Initializations
midpoint = geometry.midpoint(branch1[idx1], branch2[idx2])
b1 = branch1 - midpoint
b1 = branch1 - midpoint
b2 = branch2 - midpoint

# Main
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/geometry_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import networkx as nx
import numpy as np

from deep_neurographs.utils import graph_util as gutil, img_util
from deep_neurographs.utils import img_util


# --- Directionals ---
Expand Down
70 changes: 37 additions & 33 deletions src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from collections import defaultdict, deque
from concurrent.futures import (
as_completed, ProcessPoolExecutor, ThreadPoolExecutor,
as_completed, ProcessPoolExecutor, ThreadPoolExecutor
)
from random import sample
from scipy.spatial import KDTree
Expand Down Expand Up @@ -219,32 +219,36 @@ def extract(self, swc_dict):
Dictionary that each contains the components of an irreducible
subgraph.
"""
graph = self.to_graph(swc_dict)
irreducibles = deque()
high_risk_cnt = 0
if self.satifies_path_length_condition(graph):
# Check for soma merges
if len(graph.graph["soma_nodes"]) > 1:
self.remove_soma_merges(graph)

# Check for high risk merges
if self.remove_high_risk_merges_bool:
high_risk_cnt = self.remove_high_risk_merges(graph)
try:
graph = self.to_graph(swc_dict)
irreducibles = deque()
high_risk_cnt = 0
if self.satifies_path_length_condition(graph):
# Check for soma merges
if len(graph.graph["soma_nodes"]) > 1:
self.remove_soma_merges(graph)

# Check for high risk merges
if self.remove_high_risk_merges_bool:
high_risk_cnt = self.remove_high_risk_merges(graph)

# Extract irreducibles
i = 0
leafs = set(get_leafs(graph))
while leafs:
# Extract for connected component
leaf = util.sample_once(leafs)
irreducibles_i, visited = self.get_irreducibles(graph, leaf)
leafs -= visited

# Extract irreducibles
i = 0
leafs = set(get_leafs(graph))
while leafs:
# Extract for connected component
leaf = util.sample_once(leafs)
irreducibles_i, visited = self.get_irreducibles(graph, leaf)
leafs -= visited

# Store results
if irreducibles_i:
irreducibles_i["swc_id"] = f"{graph.graph['segment_id']}.{i}"
irreducibles.append(irreducibles_i)
i += 1
# Store results
if irreducibles_i:
swc_id = f"{graph.graph['segment_id']}.{i}"
irreducibles_i["swc_id"] = swc_id
irreducibles.append(irreducibles_i)
i += 1
except Exception as e:
print("Exception:", e)
return irreducibles, high_risk_cnt

def get_irreducibles(self, graph, source):
Expand Down Expand Up @@ -345,13 +349,13 @@ def remove_high_risk_merges(self, graph, max_dist=7):
"""
high_risk_cnt = 0
nodes = set()
branchings = [i for i in graph.nodes if graph.degree[i] > 2]
branchings = set([i for i in graph.nodes if graph.degree[i] > 2])
while branchings:
# Initializations
root = branchings.pop()
hit_branching = False
hit_branchings = set()
queue = [(root, 0)]
visited = set({root})
visited = set(queue)

# Check if close to soma
soma_dist = self.dist_from_soma(graph.graph["xyz"][root])
Expand All @@ -363,7 +367,7 @@ def remove_high_risk_merges(self, graph, max_dist=7):
# Visit node
i, dist_i = queue.pop()
if graph.degree[i] > 2 and i != root:
hit_branching = True
hit_branchings.add(i)

# Update queue
for j in graph.neighbors(i):
Expand All @@ -373,10 +377,10 @@ def remove_high_risk_merges(self, graph, max_dist=7):
visited.add(j)

# Determine whether to remove visited nodes
if hit_branching or graph.degree(root) > 3:
if hit_branchings or graph.degree(root) > 3:
nodes = nodes.union(visited)
high_risk_cnt += 0.5

high_risk_cnt += 1
branchings -= hit_branchings
graph.remove_nodes_from(nodes)
return high_risk_cnt

Expand Down
1 change: 0 additions & 1 deletion src/deep_neurographs/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from plotly.subplots import make_subplots

import networkx as nx
import numpy as np
import plotly.colors as plc
import plotly.graph_objects as go

Expand Down
Loading