Skip to content

Commit 12291b0

Browse files
authored
Self-loops management in KNNGraph and RadiusGraph (mathLab#522)
* Add self-loop option to RadiusGraph and KNNGraph
1 parent f48da47 commit 12291b0

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

pina/graph.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
from torch_geometric.data import Data, Batch
55
from torch_geometric.utils import to_undirected
6+
from torch_geometric.utils.loop import remove_self_loops
67
from .label_tensor import LabelTensor
78
from .utils import check_consistency, is_function
89

@@ -209,6 +210,7 @@ def __new__(
209210
x=None,
210211
edge_attr=False,
211212
custom_edge_func=None,
213+
loop=True,
212214
**kwargs,
213215
):
214216
"""
@@ -224,18 +226,19 @@ def __new__(
224226
:param x: Optional tensor of node features of shape ``(N, F)``, where
225227
``F`` is the number of features per node.
226228
:type x: torch.Tensor | LabelTensor, optional
227-
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
228-
, where ``F`` is the number of features per edge.
229-
:type edge_attr: torch.Tensor, optional
229+
:param bool edge_attr: Whether to compute the edge attributes.
230230
:param custom_edge_func: A custom function to compute edge attributes.
231231
If provided, overrides ``edge_attr``.
232232
:type custom_edge_func: Callable, optional
233+
:param bool loop: Whether to include self-loops.
233234
:param kwargs: Additional keyword arguments passed to the
234235
:class:`~pina.graph.Graph` class constructor.
235236
:return: A :class:`~pina.graph.Graph` instance constructed using the
236237
provided information.
237238
:rtype: Graph
238239
"""
240+
if not loop:
241+
edge_index = remove_self_loops(edge_index)[0]
239242
edge_attr = cls._create_edge_attr(
240243
pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
241244
)
@@ -374,11 +377,8 @@ def compute_knn_graph(points, neighbours):
374377
representing the edge indices of the graph.
375378
:rtype: torch.Tensor
376379
"""
377-
378380
dist = torch.cdist(points, points, p=2)
379-
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
380-
:, 1:
381-
]
381+
knn_indices = torch.topk(dist, k=neighbours, largest=False).indices
382382
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
383383
col = knn_indices.flatten()
384384
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)

tests/test_graph.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ def test_build_graph(x, pos):
6767
),
6868
],
6969
)
70-
def test_build_radius_graph(x, pos):
71-
graph = RadiusGraph(x=x, pos=pos, radius=0.5)
70+
@pytest.mark.parametrize("loop", [True, False])
71+
def test_build_radius_graph(x, pos, loop):
72+
graph = RadiusGraph(x=x, pos=pos, radius=0.5, loop=loop)
7273
assert hasattr(graph, "x")
7374
assert hasattr(graph, "pos")
7475
assert hasattr(graph, "edge_index")
@@ -84,6 +85,15 @@ def test_build_radius_graph(x, pos):
8485
assert graph.pos.labels == pos.labels
8586
else:
8687
assert isinstance(graph.pos, torch.Tensor)
88+
if not loop:
89+
assert (
90+
len(
91+
torch.nonzero(
92+
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
93+
)[0]
94+
)
95+
== 0
96+
) # Detect self loops
8797

8898

8999
@pytest.mark.parametrize(
@@ -168,8 +178,9 @@ def test_build_radius_graph_custom_edge_attr(x, pos):
168178
),
169179
],
170180
)
171-
def test_build_knn_graph(x, pos):
172-
graph = KNNGraph(x=x, pos=pos, neighbours=2)
181+
@pytest.mark.parametrize("loop", [True, False])
182+
def test_build_knn_graph(x, pos, loop):
183+
graph = KNNGraph(x=x, pos=pos, neighbours=2, loop=loop)
173184
assert hasattr(graph, "x")
174185
assert hasattr(graph, "pos")
175186
assert hasattr(graph, "edge_index")
@@ -186,6 +197,15 @@ def test_build_knn_graph(x, pos):
186197
else:
187198
assert isinstance(graph.pos, torch.Tensor)
188199
assert graph.edge_attr is None
200+
self_loops = len(
201+
torch.nonzero(
202+
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
203+
)[0]
204+
)
205+
if loop:
206+
assert self_loops != 0
207+
else:
208+
assert self_loops == 0
189209

190210

191211
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)