10
10
class EnEquivariantNetworkBlock (MessagePassing ):
11
11
"""
12
12
Implementation of the E(n) Equivariant Graph Neural Network block.
13
-
14
13
This block is used to perform message-passing between nodes and edges in a
15
14
graph neural network, following the scheme proposed by Satorras et al. in
16
15
2021. It serves as an inner block in a larger graph neural network
@@ -102,14 +101,24 @@ def __init__(
102
101
)
103
102
104
103
# Layer for updating the node features
105
- self .update_net = FeedForward (
104
+ self .update_feat_net = FeedForward (
106
105
input_dimensions = node_feature_dim + pos_dim ,
107
106
output_dimensions = node_feature_dim ,
108
107
inner_size = hidden_dim ,
109
108
n_layers = n_update_layers ,
110
109
func = activation ,
111
110
)
112
111
112
+ # Layer for updating the node positions
113
+ # The output dimension is set to 1 for equivariant updates
114
+ self .update_pos_net = FeedForward (
115
+ input_dimensions = pos_dim ,
116
+ output_dimensions = 1 ,
117
+ inner_size = hidden_dim ,
118
+ n_layers = n_update_layers ,
119
+ func = activation ,
120
+ )
121
+
113
122
def forward (self , x , pos , edge_index , edge_attr = None ):
114
123
"""
115
124
Forward pass of the block, triggering the message-passing routine.
@@ -143,22 +152,62 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
143
152
:param edge_attr: The edge attributes.
144
153
:type edge_attr: torch.Tensor | LabelTensor
145
154
:return: The message to be passed.
146
- :rtype: torch.Tensor
155
+ :rtype: tuple( torch.Tensor, torch.Tensor)
147
156
"""
148
- dist = torch .norm (pos_i - pos_j , dim = - 1 , keepdim = True ) ** 2
157
+ # Compute the euclidean distance between the sender and recipient nodes
158
+ diff = pos_i - pos_j
159
+ dist = torch .norm (diff , dim = - 1 , keepdim = True ) ** 2
160
+
161
+ # Compute the message input
149
162
if edge_attr is None :
150
163
input_ = torch .cat ((x_i , x_j , dist ), dim = - 1 )
151
164
else :
152
165
input_ = torch .cat ((x_i , x_j , dist , edge_attr ), dim = - 1 )
153
166
154
- return self .message_net (input_ )
167
+ # Compute the messages and their equivariant counterpart
168
+ m_ij = self .message_net (input_ )
169
+ message = diff * self .update_pos_net (m_ij )
170
+
171
+ return message , m_ij
155
172
156
- def update (self , message , x , pos , edge_index ):
173
+ def aggregate (self , inputs , index , ptr = None , dim_size = None ):
174
+ """
175
+ Aggregate the messages at the nodes during message passing.
176
+
177
+ This method receives a tuple of tensors corresponding to the messages
178
+ to be aggregated. Both messages are aggregated separately according to
179
+ the specified aggregation scheme.
180
+
181
+ :param tuple(torch.Tensor) inputs: Tuple containing two messages to
182
+ aggregate.
183
+ :param index: The indices of target nodes for each message. This tensor
184
+ specifies which node each message is aggregated into.
185
+ :type index: torch.Tensor | LabelTensor
186
+ :param ptr: Optional tensor to specify the slices of messages for each
187
+ node (used in some aggregation strategies). Default is None.
188
+ :type ptr: torch.Tensor | LabelTensor
189
+ :param int dim_size: Optional size of the output dimension, i.e.,
190
+ number of nodes. Default is None.
191
+ :return: Tuple of aggregated tensors corresponding to (aggregated
192
+ messages for position updates, aggregated messages for feature
193
+ updates).
194
+ :rtype: tuple(torch.Tensor, torch.Tensor)
195
+ """
196
+ # Unpack the messages from the inputs
197
+ message , m_ij = inputs
198
+
199
+ # Aggregate messages as usual using self.aggr method
200
+ agg_message = super ().aggregate (message , index , ptr , dim_size )
201
+ agg_m_ij = super ().aggregate (m_ij , index , ptr , dim_size )
202
+
203
+ return agg_message , agg_m_ij
204
+
205
+ def update (self , aggregated_inputs , x , pos , edge_index ):
157
206
"""
158
207
Update the node features and the node coordinates with the received
159
208
messages.
160
209
161
- :param torch.Tensor message : The message to be passed.
210
+ :param tuple( torch.Tensor) aggregated_inputs : The messages to be passed.
162
211
:param x: The node features.
163
212
:type x: torch.Tensor | LabelTensor
164
213
:param pos: The euclidean coordinates of the nodes.
@@ -167,10 +216,14 @@ def update(self, message, x, pos, edge_index):
167
216
:return: The updated node features and node positions.
168
217
:rtype: tuple(torch.Tensor, torch.Tensor)
169
218
"""
170
- # Update the node features
171
- x = self .update_net (torch .cat ((x , message ), dim = - 1 ))
219
+ # aggregated_inputs is tuple (agg_message, agg_m_ij)
220
+ agg_message , agg_m_ij = aggregated_inputs
221
+
222
+ # Update node features with aggregated m_ij
223
+ x = self .update_feat_net (torch .cat ((x , agg_m_ij ), dim = - 1 ))
224
+
225
+ # Degree for normalization of position updates
226
+ c = degree (edge_index [1 ], pos .shape [0 ]).unsqueeze (- 1 ).clamp (min = 1 )
227
+ pos = pos + agg_message / c
172
228
173
- # Update the node positions
174
- c = degree (edge_index [0 ], pos .shape [0 ]).unsqueeze (- 1 )
175
- pos = pos + message / c
176
229
return x , pos
0 commit comments