@@ -8,6 +8,29 @@
 
 
 class EnEquivariantNetworkBlock(MessagePassing):
+    """
+    Implementation of the E(n) Equivariant Graph Neural Network block.
+    This block is used to perform message passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Satorras et al. in
+    2021. It serves as an inner block in a larger graph neural network
+    architecture.
+    The message between two nodes connected by an edge is computed by applying
+    an MLP to the sender and recipient node features and the edge features,
+    together with the squared Euclidean distance between the sender and
+    recipient node positions.
+    Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
+    min, max, or product).
+    The update step is performed by applying another MLP to the concatenation
+    of the incoming messages and the node features. The node positions are
+    also updated by adding the incoming messages divided by the degree of the
+    recipient node.
+    .. seealso::
+        **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
+        (2021). *E(n) Equivariant Graph Neural Networks.*
+        In International Conference on Machine Learning.
+        DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
+    """
+
     def __init__(
         self,
         node_feature_dim,
@@ -21,15 +44,49 @@ def __init__(
         node_dim=-2,
         flow="source_to_target",
     ):
+        """
+        Initialization of the :class:`EnEquivariantNetworkBlock` class.
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param int pos_dim: The dimension of the position features.
+        :param int hidden_dim: The dimension of the hidden features.
+            Default is 64.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update network.
+            Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.SiLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises AssertionError: If `node_feature_dim` is not a positive integer.
+        :raises AssertionError: If `edge_feature_dim` is a negative integer.
+        :raises AssertionError: If `pos_dim` is not a positive integer.
+        :raises AssertionError: If `hidden_dim` is not a positive integer.
+        :raises AssertionError: If `n_message_layers` is not a positive integer.
+        :raises AssertionError: If `n_update_layers` is not a positive integer.
+        """
         super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
 
+        # Check values
         check_positive_integer(node_feature_dim, strict=True)
         check_positive_integer(edge_feature_dim, strict=False)
         check_positive_integer(pos_dim, strict=True)
         check_positive_integer(hidden_dim, strict=True)
         check_positive_integer(n_message_layers, strict=True)
         check_positive_integer(n_update_layers, strict=True)
 
+        # Layer for computing the message
         self.message_net = FeedForward(
             input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
             output_dimensions=pos_dim,
@@ -38,6 +95,7 @@ def __init__(
             func=activation,
         )
 
+        # Layer for updating the node features
         self.update_feat_net = FeedForward(
             input_dimensions=node_feature_dim + pos_dim,
             output_dimensions=node_feature_dim,
@@ -46,6 +104,8 @@ def __init__(
             func=activation,
         )
 
+        # Layer for updating the node positions
+        # The output dimension is set to 1 for equivariant updates
         self.update_pos_net = FeedForward(
             input_dimensions=pos_dim,
             output_dimensions=1,
@@ -87,18 +147,21 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
         :param edge_attr: The edge attributes.
         :type edge_attr: torch.Tensor | LabelTensor
         :return: The message to be passed.
-        :rtype: torch.Tensor
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
+        # Squared Euclidean distance between sender and recipient positions
         diff = pos_i - pos_j
         dist = torch.norm(diff, dim=-1, keepdim=True) ** 2
 
+        # Compute the message input
         if edge_attr is None:
             input_ = torch.cat((x_i, x_j, dist), dim=-1)
         else:
             input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
 
-        m_ij = self.message_net(input_)  # message features
-        message = diff * self.update_pos_net(m_ij)  # equivariant message
+        # Compute the messages and their equivariant counterpart
+        m_ij = self.message_net(input_)
+        message = diff * self.update_pos_net(m_ij)
 
         return message, m_ij
 
@@ -112,20 +175,20 @@ def aggregate(self, inputs, index, ptr=None, dim_size=None):
 
         :param tuple(torch.Tensor) inputs: Tuple containing two messages to
             aggregate.
-        :param torch.Tensor | LabelTensor index: The indices of target nodes
-            for each message. This tensor specifies which node each message
-            is aggregated into.
-        :param torch.Tensor | LabelTensor ptr: Optional tensor to specify
-            the slices of messages for each node (used in some aggregation
-            strategies).
+        :param index: The indices of target nodes for each message. This tensor
+            specifies which node each message is aggregated into.
+        :type index: torch.Tensor | LabelTensor
+        :param ptr: Optional tensor to specify the slices of messages for each
+            node (used in some aggregation strategies). Default is None.
+        :type ptr: torch.Tensor | LabelTensor
         :param int dim_size: Optional size of the output dimension, i.e.,
-            number of nodes.
-        :return: Tuple of aggregated tensors corresponding to
-            (aggregated messages for position updates, aggregated messages for
-            feature updates).
+            number of nodes. Default is None.
+        :return: Tuple of aggregated tensors corresponding to (aggregated
+            messages for position updates, aggregated messages for feature
+            updates).
         :rtype: tuple(torch.Tensor, torch.Tensor)
         """
-        # inputs is tuple (message, m_ij), we want to aggregate separately
+        # Unpack the messages from the inputs
        message, m_ij = inputs
 
        # Aggregate messages as usual using self.aggr method
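
In the paper's notation, the scheme described in the class docstring can be summarized as below. This is an editorial summary, not part of the commit: phi_e, phi_x, and phi_h play the roles of message_net, update_pos_net, and update_feat_net, the sums stand for the default "add" aggregation, and the 1/deg(i) factor follows the docstring's description (the original paper instead uses a constant normalization C = 1/(M-1)).

\begin{align}
  m_{ij} &= \phi_e\!\left(h_i,\, h_j,\, \lVert x_i - x_j \rVert^2,\, a_{ij}\right) \\
  x_i' &= x_i + \frac{1}{\deg(i)} \sum_{j \in \mathcal{N}(i)} \left(x_i - x_j\right) \phi_x\!\left(m_{ij}\right) \\
  h_i' &= \phi_h\!\left(h_i,\, \textstyle\sum_{j \in \mathcal{N}(i)} m_{ij}\right)
\end{align}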
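
A minimal usage sketch of the block, for illustration only. The import path and the forward call are assumptions (the forward method is not shown in this diff); the constructor arguments match the documented __init__ signature above.

import torch

# Hypothetical import path; adjust to wherever the block is exported.
from pina.model.block import EnEquivariantNetworkBlock

# Toy graph: 4 nodes with 3-dim features, 2-dim positions, 1-dim edge features.
x = torch.rand(4, 3)
pos = torch.rand(4, 2)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_attr = torch.rand(edge_index.shape[1], 1)

# Constructor arguments as documented in the diff above.
block = EnEquivariantNetworkBlock(
    node_feature_dim=3,
    edge_feature_dim=1,
    pos_dim=2,
    hidden_dim=64,
    n_message_layers=2,
    n_update_layers=2,
    aggr="add",
)

# Assumed forward signature: returns updated node features and positions.
h_new, pos_new = block(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr)
print(h_new.shape, pos_new.shape)  # expected: torch.Size([4, 3]) torch.Size([4, 2])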