
Commit 4110851: tests and minor fixes
1 parent 1525549 commit 4110851
12 files changed: +453 -19 lines changed
pina/model/block/message_passing/deep_tensor_network_block.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@ def __init__(
         flow="source_to_target",
     ):
         """
-        Initialization of the :class:`DeepTensorNetworkBlocklock` class.
+        Initialization of the :class:`DeepTensorNetworkBlock` class.
 
         :param int node_feature_dim: The dimension of the node features.
         :param int edge_feature_dim: The dimension of the edge features.
@@ -68,7 +68,7 @@ def __init__(
         check_positive_integer(edge_feature_dim, strict=True)
 
         # Activation function
-        self.activation = activation
+        self.activation = activation()
 
         # Layer for processing node features
         self.node_layer = torch.nn.Linear(
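The second hunk stores an instantiated activation rather than the activation class itself. A minimal sketch of the difference, assuming the block receives the activation as a class such as torch.nn.Tanh (the actual default is not visible in this diff):

import torch

# Hypothetical illustration: `activation` stands for whatever class the block
# receives; torch.nn.Tanh is only an example.
activation = torch.nn.Tanh

stored_class = activation      # old behaviour: the class itself was stored
stored_module = activation()   # new behaviour: an instantiated module is stored

t = torch.randn(4, 3)
out = stored_module(t)         # callable module, applies tanh elementwise
assert out.shape == t.shape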

pina/model/block/message_passing/egnn_block.py

Lines changed: 9 additions & 5 deletions
@@ -76,7 +76,7 @@ def __init__(
             source node. See :class:`torch_geometric.nn.MessagePassing` for more
             details. Default is "source_to_target".
         :raises AssertionError: If `node_feature_dim` is not a positive integer.
-        :raises AssertionError: If `edge_feature_dim` is not a positive integer.
+        :raises AssertionError: If `edge_feature_dim` is a negative integer.
         :raises AssertionError: If `pos_dim` is not a positive integer.
         :raises AssertionError: If `hidden_dim` is not a positive integer.
         :raises AssertionError: If `n_message_layers` is not a positive integer.
@@ -86,7 +86,7 @@ def __init__(
 
         # Check values
         check_positive_integer(node_feature_dim, strict=True)
-        check_positive_integer(edge_feature_dim, strict=True)
+        check_positive_integer(edge_feature_dim, strict=False)
         check_positive_integer(pos_dim, strict=True)
         check_positive_integer(hidden_dim, strict=True)
         check_positive_integer(n_message_layers, strict=True)
@@ -110,7 +110,7 @@ def __init__(
             func=activation,
         )
 
-    def forward(self, x, pos, edge_index, edge_attr):
+    def forward(self, x, pos, edge_index, edge_attr=None):
         """
         Forward pass of the block, triggering the message-passing routine.
 
@@ -146,7 +146,11 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
         :rtype: torch.Tensor
         """
         dist = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2
-        input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
+        if edge_attr is None:
+            input_ = torch.cat((x_i, x_j, dist), dim=-1)
+        else:
+            input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
+
         return self.message_net(input_)
 
     def update(self, message, x, pos, edge_index):
@@ -169,4 +173,4 @@ def update(self, message, x, pos, edge_index):
         # Update the node positions
         c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
         pos = pos + message / c
-        return pos, x
+        return x, pos
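Taken together, these changes make edge features optional and flip the return order to (features, positions). A usage sketch consistent with the new EnEquivariantNetworkBlock tests added below; constructor defaults not shown in this diff are assumed:

import torch
from pina.model.block.message_passing import EnEquivariantNetworkBlock

x = torch.rand(10, 4)
pos = torch.rand(10, 3)
edge_index = torch.randint(0, 10, (2, 20))

# edge_feature_dim=0 is now accepted (strict=False check) and edge_attr may be
# omitted thanks to the edge_attr=None default in forward().
block = EnEquivariantNetworkBlock(node_feature_dim=4, edge_feature_dim=0, pos_dim=3)

# update() now returns the node features first and the positions second.
new_x, new_pos = block(x=x, pos=pos, edge_index=edge_index)
assert new_x.shape == x.shape
assert new_pos.shape == pos.shape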

pina/model/block/message_passing/interaction_network_block.py

Lines changed: 2 additions & 2 deletions
@@ -99,8 +99,8 @@ def __init__(
         # Update network
         self.update_net = FeedForward(
             input_dimensions=node_feature_dim + hidden_dim,
-            output_dimensions=hidden_dim,
-            inner_size=node_feature_dim,
+            output_dimensions=node_feature_dim,
+            inner_size=hidden_dim,
             n_layers=n_update_layers,
             func=activation,
         )
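The swap matters for the output shape: the update network consumes the node features concatenated with the aggregated hidden-dim message and must project back to the node feature size. A shape-only sketch with plain torch tensors (not the PINA API), using hidden_dim=64 and node_feature_dim=3 as in the new tests:

import torch

node_feature_dim, hidden_dim, n_nodes = 3, 64, 10
x = torch.rand(n_nodes, node_feature_dim)
aggregated_message = torch.rand(n_nodes, hidden_dim)   # output of the message net

# The update net sees [x | message] and must return node_feature_dim features,
# which is what output_dimensions=node_feature_dim now guarantees.
update_in = torch.cat((x, aggregated_message), dim=-1)  # shape (10, 67)
update_net = torch.nn.Linear(node_feature_dim + hidden_dim, node_feature_dim)
assert update_net(update_in).shape == x.shape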

pina/model/block/message_passing/radial_field_network_block.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def forward(self, x, edge_index):
         """
         return self.propagate(edge_index=edge_index, x=x)
 
-    def message(self, x_j, x_i):
+    def message(self, x_i, x_j):
         """
         Compute the message to be passed between nodes and edges.
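The reordering to (x_i, x_j) is a consistency fix: torch_geometric binds message() arguments by their _i/_j name suffixes (target and source node features, respectively), not by position. A minimal self-contained sketch of that behaviour, independent of the PINA block:

import torch
from torch_geometric.nn import MessagePassing


class TinyBlock(MessagePassing):
    def __init__(self):
        super().__init__(aggr="add", flow="source_to_target")

    def forward(self, x, edge_index):
        return self.propagate(edge_index=edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i: target (central) node features, x_j: source node features,
        # resolved by name suffix regardless of parameter order.
        return x_j - x_i


x = torch.rand(5, 2)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
assert TinyBlock()(x, edge_index).shape == x.shape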

pina/model/block/message_passing/schnet_block.py

Lines changed: 1 addition & 2 deletions
@@ -120,8 +120,7 @@ def forward(self, x, pos, edge_index):
 
         :param x: The node features.
         :type x: torch.Tensor | LabelTensor
-        :param torch.Tensor edge_index: The edge indices. In the original formulation,
-            the messages are aggregated from all nodes, not only from the neighbours.
+        :param torch.Tensor edge_index: The edge indices.
         :return: The updated node features.
         :rtype: torch.Tensor
         """

pina/utils.py

Lines changed: 1 addition & 1 deletion
@@ -202,7 +202,7 @@ def check_positive_integer(value, strict=True):
     :param int value: The value to check.
     :param bool strict: If True, the value must be strictly positive.
         Default is True.
-    :raises ValueError: If the value is not a positive integer.
+    :raises AssertionError: If the value is not a positive integer.
     """
     if strict:
         assert (
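Since the check is assert-based, the corrected docstring matches the runtime behaviour the new tests rely on. A usage sketch, assuming check_positive_integer is importable from pina.utils as the file path above suggests:

import pytest
from pina.utils import check_positive_integer

check_positive_integer(3, strict=True)         # passes silently

with pytest.raises(AssertionError):
    check_positive_integer(-1, strict=True)    # fails the assert, not a ValueError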
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import pytest
+import torch
+from pina.model.block.message_passing import DeepTensorNetworkBlock
+
+# Data for testing
+x = torch.rand(10, 3)
+edge_index = torch.randint(0, 10, (2, 20))
+edge_attr = torch.randn(20, 2)
+
+
+@pytest.mark.parametrize("node_feature_dim", [1, 3])
+@pytest.mark.parametrize("edge_feature_dim", [3, 5])
+def test_constructor(node_feature_dim, edge_feature_dim):
+
+    DeepTensorNetworkBlock(
+        node_feature_dim=node_feature_dim,
+        edge_feature_dim=edge_feature_dim,
+    )
+
+    # Should fail if node_feature_dim is negative
+    with pytest.raises(AssertionError):
+        DeepTensorNetworkBlock(
+            node_feature_dim=-1, edge_feature_dim=edge_feature_dim
+        )
+
+    # Should fail if edge_feature_dim is negative
+    with pytest.raises(AssertionError):
+        DeepTensorNetworkBlock(
+            node_feature_dim=node_feature_dim, edge_feature_dim=-1
+        )
+
+
+def test_forward():
+
+    model = DeepTensorNetworkBlock(
+        node_feature_dim=x.shape[1],
+        edge_feature_dim=edge_attr.shape[1],
+    )
+
+    output_ = model(edge_index=edge_index, x=x, edge_attr=edge_attr)
+    assert output_.shape == x.shape
+
+
+def test_backward():
+
+    model = DeepTensorNetworkBlock(
+        node_feature_dim=x.shape[1],
+        edge_feature_dim=edge_attr.shape[1],
+    )
+
+    output_ = model(
+        edge_index=edge_index,
+        x=x.requires_grad_(),
+        edge_attr=edge_attr.requires_grad_(),
+    )
+
+    loss = torch.mean(output_)
+    loss.backward()
+    assert x.grad.shape == x.shape
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+import pytest
+import torch
+from pina.model.block.message_passing import EnEquivariantNetworkBlock
+
+# Data for testing
+x = torch.rand(10, 4)
+pos = torch.rand(10, 3)
+edge_index = torch.randint(0, 10, (2, 20))
+edge_attr = torch.randn(20, 2)
+
+
+@pytest.mark.parametrize("node_feature_dim", [1, 3])
+@pytest.mark.parametrize("edge_feature_dim", [0, 2])
+@pytest.mark.parametrize("pos_dim", [2, 3])
+def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
+
+    EnEquivariantNetworkBlock(
+        node_feature_dim=node_feature_dim,
+        edge_feature_dim=edge_feature_dim,
+        pos_dim=pos_dim,
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+    )
+
+    # Should fail if node_feature_dim is negative
+    with pytest.raises(AssertionError):
+        EnEquivariantNetworkBlock(
+            node_feature_dim=-1,
+            edge_feature_dim=edge_feature_dim,
+            pos_dim=pos_dim,
+        )
+
+    # Should fail if edge_feature_dim is negative
+    with pytest.raises(AssertionError):
+        EnEquivariantNetworkBlock(
+            node_feature_dim=node_feature_dim,
+            edge_feature_dim=-1,
+            pos_dim=pos_dim,
+        )
+
+    # Should fail if pos_dim is negative
+    with pytest.raises(AssertionError):
+        EnEquivariantNetworkBlock(
+            node_feature_dim=node_feature_dim,
+            edge_feature_dim=edge_feature_dim,
+            pos_dim=-1,
+        )
+
+    # Should fail if hidden_dim is negative
+    with pytest.raises(AssertionError):
+        EnEquivariantNetworkBlock(
+            node_feature_dim=node_feature_dim,
+            edge_feature_dim=edge_feature_dim,
+            pos_dim=pos_dim,
+            hidden_dim=-1,
+        )
+
+    # Should fail if n_message_layers is negative
+    with pytest.raises(AssertionError):
+        EnEquivariantNetworkBlock(
+            node_feature_dim=node_feature_dim,
+            edge_feature_dim=edge_feature_dim,
+            pos_dim=pos_dim,
+            n_message_layers=-1,
+        )
+
+    # Should fail if n_update_layers is negative
+    with pytest.raises(AssertionError):
+        EnEquivariantNetworkBlock(
+            node_feature_dim=node_feature_dim,
+            edge_feature_dim=edge_feature_dim,
+            pos_dim=pos_dim,
+            n_update_layers=-1,
+        )
+
+
+@pytest.mark.parametrize("edge_feature_dim", [0, 2])
+def test_forward(edge_feature_dim):
+
+    model = EnEquivariantNetworkBlock(
+        node_feature_dim=x.shape[1],
+        edge_feature_dim=edge_feature_dim,
+        pos_dim=pos.shape[1],
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+    )
+
+    if edge_feature_dim == 0:
+        output_ = model(edge_index=edge_index, x=x, pos=pos)
+    else:
+        output_ = model(
+            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
+        )
+
+    assert output_[0].shape == x.shape
+    assert output_[1].shape == pos.shape
+
+
+@pytest.mark.parametrize("edge_feature_dim", [0, 2])
+def test_backward(edge_feature_dim):
+
+    model = EnEquivariantNetworkBlock(
+        node_feature_dim=x.shape[1],
+        edge_feature_dim=edge_feature_dim,
+        pos_dim=pos.shape[1],
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+    )
+
+    if edge_feature_dim == 0:
+        output_ = model(
+            edge_index=edge_index,
+            x=x.requires_grad_(),
+            pos=pos.requires_grad_(),
+        )
+    else:
+        output_ = model(
+            edge_index=edge_index,
+            x=x.requires_grad_(),
+            pos=pos.requires_grad_(),
+            edge_attr=edge_attr.requires_grad_(),
+        )
+
+    loss = torch.mean(output_[0])
+    loss.backward()
+    assert x.grad.shape == x.shape
+    assert pos.grad.shape == pos.shape
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+import pytest
+import torch
+from pina.model.block.message_passing import InteractionNetworkBlock
+
+# Data for testing
+x = torch.rand(10, 3)
+edge_index = torch.randint(0, 10, (2, 20))
+edge_attr = torch.randn(20, 2)
+
+
+@pytest.mark.parametrize("node_feature_dim", [1, 3])
+@pytest.mark.parametrize("edge_feature_dim", [0, 2])
+def test_constructor(node_feature_dim, edge_feature_dim):
+
+    InteractionNetworkBlock(
+        node_feature_dim=node_feature_dim,
+        edge_feature_dim=edge_feature_dim,
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+    )
+
+    # Should fail if node_feature_dim is negative
+    with pytest.raises(AssertionError):
+        InteractionNetworkBlock(node_feature_dim=-1)
+
+    # Should fail if edge_feature_dim is negative
+    with pytest.raises(AssertionError):
+        InteractionNetworkBlock(node_feature_dim=3, edge_feature_dim=-1)
+
+    # Should fail if hidden_dim is negative
+    with pytest.raises(AssertionError):
+        InteractionNetworkBlock(node_feature_dim=3, hidden_dim=-1)
+
+    # Should fail if n_message_layers is negative
+    with pytest.raises(AssertionError):
+        InteractionNetworkBlock(node_feature_dim=3, n_message_layers=-1)
+
+    # Should fail if n_update_layers is negative
+    with pytest.raises(AssertionError):
+        InteractionNetworkBlock(node_feature_dim=3, n_update_layers=-1)
+
+
+@pytest.mark.parametrize("edge_feature_dim", [0, 2])
+def test_forward(edge_feature_dim):
+
+    model = InteractionNetworkBlock(
+        node_feature_dim=x.shape[1],
+        edge_feature_dim=edge_feature_dim,
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+    )
+
+    if edge_feature_dim == 0:
+        output_ = model(edge_index=edge_index, x=x)
+    else:
+        output_ = model(edge_index=edge_index, x=x, edge_attr=edge_attr)
+    assert output_.shape == x.shape
+
+
+@pytest.mark.parametrize("edge_feature_dim", [0, 2])
+def test_backward(edge_feature_dim):
+
+    model = InteractionNetworkBlock(
+        node_feature_dim=x.shape[1],
+        edge_feature_dim=edge_feature_dim,
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+    )
+
+    if edge_feature_dim == 0:
+        output_ = model(edge_index=edge_index, x=x.requires_grad_())
+    else:
+        output_ = model(
+            edge_index=edge_index,
+            x=x.requires_grad_(),
+            edge_attr=edge_attr.requires_grad_(),
+        )
+
+    loss = torch.mean(output_)
+    loss.backward()
+    assert x.grad.shape == x.shape
