Skip to content

Commit fbc0382

Browse files
fix egnn + equivariance/invariance tests
Co-authored-by: Dario Coscia <[email protected]>
1 parent 2cd5a5e commit fbc0382

File tree

4 files changed

+147
-12
lines changed

4 files changed

+147
-12
lines changed

pina/model/block/message_passing/en_equivariant_network_block.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
class EnEquivariantNetworkBlock(MessagePassing):
1111
"""
1212
Implementation of the E(n) Equivariant Graph Neural Network block.
13-
1413
This block is used to perform message-passing between nodes and edges in a
1514
graph neural network, following the scheme proposed by Satorras et al. in
1615
2021. It serves as an inner block in a larger graph neural network
@@ -102,14 +101,24 @@ def __init__(
102101
)
103102

104103
# Layer for updating the node features
105-
self.update_net = FeedForward(
104+
self.update_feat_net = FeedForward(
106105
input_dimensions=node_feature_dim + pos_dim,
107106
output_dimensions=node_feature_dim,
108107
inner_size=hidden_dim,
109108
n_layers=n_update_layers,
110109
func=activation,
111110
)
112111

112+
# Layer for updating the node positions
113+
# The output dimension is set to 1 for equivariant updates
114+
self.update_pos_net = FeedForward(
115+
input_dimensions=pos_dim,
116+
output_dimensions=1,
117+
inner_size=hidden_dim,
118+
n_layers=n_update_layers,
119+
func=activation,
120+
)
121+
113122
def forward(self, x, pos, edge_index, edge_attr=None):
114123
"""
115124
Forward pass of the block, triggering the message-passing routine.
@@ -143,22 +152,62 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
143152
:param edge_attr: The edge attributes.
144153
:type edge_attr: torch.Tensor | LabelTensor
145154
:return: The message to be passed.
146-
:rtype: torch.Tensor
155+
:rtype: tuple(torch.Tensor, torch.Tensor)
147156
"""
148-
dist = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2
157+
# Compute the euclidean distance between the sender and recipient nodes
158+
diff = pos_i - pos_j
159+
dist = torch.norm(diff, dim=-1, keepdim=True) ** 2
160+
161+
# Compute the message input
149162
if edge_attr is None:
150163
input_ = torch.cat((x_i, x_j, dist), dim=-1)
151164
else:
152165
input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
153166

154-
return self.message_net(input_)
167+
# Compute the messages and their equivariant counterpart
168+
m_ij = self.message_net(input_)
169+
message = diff * self.update_pos_net(m_ij)
170+
171+
return message, m_ij
155172

156-
def update(self, message, x, pos, edge_index):
173+
def aggregate(self, inputs, index, ptr=None, dim_size=None):
174+
"""
175+
Aggregate the messages at the nodes during message passing.
176+
177+
This method receives a tuple of tensors corresponding to the messages
178+
to be aggregated. Both messages are aggregated separately according to
179+
the specified aggregation scheme.
180+
181+
:param tuple(torch.Tensor) inputs: Tuple containing two messages to
182+
aggregate.
183+
:param index: The indices of target nodes for each message. This tensor
184+
specifies which node each message is aggregated into.
185+
:type index: torch.Tensor | LabelTensor
186+
:param ptr: Optional tensor to specify the slices of messages for each
187+
node (used in some aggregation strategies). Default is None.
188+
:type ptr: torch.Tensor | LabelTensor
189+
:param int dim_size: Optional size of the output dimension, i.e.,
190+
number of nodes. Default is None.
191+
:return: Tuple of aggregated tensors corresponding to (aggregated
192+
messages for position updates, aggregated messages for feature
193+
updates).
194+
:rtype: tuple(torch.Tensor, torch.Tensor)
195+
"""
196+
# Unpack the messages from the inputs
197+
message, m_ij = inputs
198+
199+
# Aggregate messages as usual using self.aggr method
200+
agg_message = super().aggregate(message, index, ptr, dim_size)
201+
agg_m_ij = super().aggregate(m_ij, index, ptr, dim_size)
202+
203+
return agg_message, agg_m_ij
204+
205+
def update(self, aggregated_inputs, x, pos, edge_index):
157206
"""
158207
Update the node features and the node coordinates with the received
159208
messages.
160209
161-
:param torch.Tensor message: The message to be passed.
210+
:param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
162211
:param x: The node features.
163212
:type x: torch.Tensor | LabelTensor
164213
:param pos: The euclidean coordinates of the nodes.
@@ -167,10 +216,14 @@ def update(self, message, x, pos, edge_index):
167216
:return: The updated node features and node positions.
168217
:rtype: tuple(torch.Tensor, torch.Tensor)
169218
"""
170-
# Update the node features
171-
x = self.update_net(torch.cat((x, message), dim=-1))
219+
# aggregated_inputs is tuple (agg_message, agg_m_ij)
220+
agg_message, agg_m_ij = aggregated_inputs
221+
222+
# Update node features with aggregated m_ij
223+
x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))
224+
225+
# Degree for normalization of position updates
226+
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
227+
pos = pos + agg_message / c
172228

173-
# Update the node positions
174-
c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
175-
pos = pos + message / c
176229
return x, pos

tests/test_messagepassing/test_equivariant_network_block.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,38 @@ def test_backward(edge_feature_dim):
128128
loss.backward()
129129
assert x.grad.shape == x.shape
130130
assert pos.grad.shape == pos.shape
131+
132+
133+
def test_equivariance():
134+
135+
# Graph to be fully connected and undirected
136+
edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
137+
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
138+
139+
# Random rotation (det(rotation) should be 1)
140+
rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
141+
if torch.det(rotation) < 0:
142+
rotation[:, 0] *= -1
143+
144+
# Random translation
145+
translation = torch.rand(1, pos.shape[-1])
146+
147+
model = EnEquivariantNetworkBlock(
148+
node_feature_dim=x.shape[1],
149+
edge_feature_dim=0,
150+
pos_dim=pos.shape[1],
151+
hidden_dim=64,
152+
n_message_layers=2,
153+
n_update_layers=2,
154+
).eval()
155+
156+
h1, pos1 = model(edge_index=edge_index, x=x, pos=pos)
157+
h2, pos2 = model(
158+
edge_index=edge_index, x=x, pos=pos @ rotation.T + translation
159+
)
160+
161+
# Transform model output
162+
pos1_transformed = (pos1 @ rotation.T) + translation
163+
164+
assert torch.allclose(pos2, pos1_transformed, atol=1e-5)
165+
assert torch.allclose(h1, h2, atol=1e-5)

tests/test_messagepassing/test_radial_field_network_block.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,28 @@ def test_backward():
6565
loss = torch.mean(output_)
6666
loss.backward()
6767
assert x.grad.shape == x.shape
68+
69+
70+
def test_equivariance():
71+
72+
# Graph to be fully connected and undirected
73+
edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
74+
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
75+
76+
# Random rotation (det(rotation) should be 1)
77+
rotation = torch.linalg.qr(torch.rand(x.shape[-1], x.shape[-1])).Q
78+
if torch.det(rotation) < 0:
79+
rotation[:, 0] *= -1
80+
81+
# Random translation
82+
translation = torch.rand(1, x.shape[-1])
83+
84+
model = RadialFieldNetworkBlock(node_feature_dim=x.shape[1]).eval()
85+
86+
pos1 = model(edge_index=edge_index, x=x)
87+
pos2 = model(edge_index=edge_index, x=x @ rotation.T + translation)
88+
89+
# Transform model output
90+
pos1_transformed = (pos1 @ rotation.T) + translation
91+
92+
assert torch.allclose(pos2, pos1_transformed, atol=1e-5)

tests/test_messagepassing/test_schnet_block.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,25 @@ def test_backward():
7171
loss = torch.mean(output_)
7272
loss.backward()
7373
assert x.grad.shape == x.shape
74+
75+
76+
def test_invariance():
77+
78+
# Graph to be fully connected and undirected
79+
edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
80+
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
81+
82+
# Random rotation (det(rotation) should be 1)
83+
rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
84+
if torch.det(rotation) < 0:
85+
rotation[:, 0] *= -1
86+
87+
# Random translation
88+
translation = torch.rand(1, pos.shape[-1])
89+
90+
model = SchnetBlock(node_feature_dim=x.shape[1]).eval()
91+
92+
out1 = model(edge_index=edge_index, x=x, pos=pos)
93+
out2 = model(edge_index=edge_index, x=x, pos=pos @ rotation.T + translation)
94+
95+
assert torch.allclose(out1, out2, atol=1e-5)

0 commit comments

Comments
 (0)