Skip to content

Commit 383a92a

Browse files
fixed doc
1 parent 104c639 commit 383a92a

File tree

1 file changed

+77
-14
lines changed

1 file changed

+77
-14
lines changed

pina/model/block/message_passing/en_equivariant_network_block.py

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,29 @@
88

99

1010
class EnEquivariantNetworkBlock(MessagePassing):
11+
"""
12+
Implementation of the E(n) Equivariant Graph Neural Network block.
13+
This block is used to perform message-passing between nodes and edges in a
14+
graph neural network, following the scheme proposed by Satorras et al. in
15+
2021. It serves as an inner block in a larger graph neural network
16+
architecture.
17+
The message between two nodes connected by an edge is computed by applying a
18+
linear transformation to the sender node features and the edge features,
19+
together with the squared euclidean distance between the sender and
20+
recipient node positions, followed by a non-linear activation function.
21+
Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
22+
min, max, or product).
23+
The update step is performed by applying another MLP to the concatenation of
24+
the incoming messages and the node features. Here, also the node
25+
positions are updated by adding the incoming messages divided by the
26+
degree of the recipient node.
27+
.. seealso::
28+
**Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
29+
(2021). *E(n) Equivariant Graph Neural Networks.*
30+
In International Conference on Machine Learning.
31+
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
32+
"""
33+
1134
def __init__(
1235
self,
1336
node_feature_dim,
@@ -21,15 +44,49 @@ def __init__(
2144
node_dim=-2,
2245
flow="source_to_target",
2346
):
47+
"""
48+
Initialization of the :class:`EnEquivariantNetworkBlock` class.
49+
:param int node_feature_dim: The dimension of the node features.
50+
:param int edge_feature_dim: The dimension of the edge features.
51+
:param int pos_dim: The dimension of the position features.
52+
:param int hidden_dim: The dimension of the hidden features.
53+
Default is 64.
54+
:param int n_message_layers: The number of layers in the message
55+
network. Default is 2.
56+
:param int n_update_layers: The number of layers in the update network.
57+
Default is 2.
58+
:param torch.nn.Module activation: The activation function.
59+
Default is :class:`torch.nn.SiLU`.
60+
:param str aggr: The aggregation scheme to use for message passing.
61+
Available options are "add", "mean", "min", "max", "mul".
62+
See :class:`torch_geometric.nn.MessagePassing` for more details.
63+
Default is "add".
64+
:param int node_dim: The axis along which to propagate. Default is -2.
65+
:param str flow: The direction of message passing. Available options
66+
are "source_to_target" and "target_to_source".
67+
The "source_to_target" flow means that messages are sent from
68+
the source node to the target node, while the "target_to_source"
69+
flow means that messages are sent from the target node to the
70+
source node. See :class:`torch_geometric.nn.MessagePassing` for more
71+
details. Default is "source_to_target".
72+
:raises AssertionError: If `node_feature_dim` is not a positive integer.
73+
:raises AssertionError: If `edge_feature_dim` is a negative integer.
74+
:raises AssertionError: If `pos_dim` is not a positive integer.
75+
:raises AssertionError: If `hidden_dim` is not a positive integer.
76+
:raises AssertionError: If `n_message_layers` is not a positive integer.
77+
:raises AssertionError: If `n_update_layers` is not a positive integer.
78+
"""
2479
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
2580

81+
# Check values
2682
check_positive_integer(node_feature_dim, strict=True)
2783
check_positive_integer(edge_feature_dim, strict=False)
2884
check_positive_integer(pos_dim, strict=True)
2985
check_positive_integer(hidden_dim, strict=True)
3086
check_positive_integer(n_message_layers, strict=True)
3187
check_positive_integer(n_update_layers, strict=True)
3288

89+
# Layer for computing the message
3390
self.message_net = FeedForward(
3491
input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
3592
output_dimensions=pos_dim,
@@ -38,6 +95,7 @@ def __init__(
3895
func=activation,
3996
)
4097

98+
# Layer for updating the node features
4199
self.update_feat_net = FeedForward(
42100
input_dimensions=node_feature_dim + pos_dim,
43101
output_dimensions=node_feature_dim,
@@ -46,6 +104,8 @@ def __init__(
46104
func=activation,
47105
)
48106

107+
# Layer for updating the node positions
108+
# The output dimension is set to 1 for equivariant updates
49109
self.update_pos_net = FeedForward(
50110
input_dimensions=pos_dim,
51111
output_dimensions=1,
@@ -87,18 +147,21 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
87147
:param edge_attr: The edge attributes.
88148
:type edge_attr: torch.Tensor | LabelTensor
89149
:return: The message to be passed.
90-
:rtype: torch.Tensor
150+
:rtype: tuple(torch.Tensor, torch.Tensor)
91151
"""
152+
# Compute the euclidean distance between the sender and recipient nodes
92153
diff = pos_i - pos_j
93154
dist = torch.norm(diff, dim=-1, keepdim=True) ** 2
94155

156+
# Compute the message input
95157
if edge_attr is None:
96158
input_ = torch.cat((x_i, x_j, dist), dim=-1)
97159
else:
98160
input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
99161

100-
m_ij = self.message_net(input_) # message features
101-
message = diff * self.update_pos_net(m_ij) # equivariant message
162+
# Compute the messages and their equivariant counterpart
163+
m_ij = self.message_net(input_)
164+
message = diff * self.update_pos_net(m_ij)
102165

103166
return message, m_ij
104167

@@ -112,20 +175,20 @@ def aggregate(self, inputs, index, ptr=None, dim_size=None):
112175
113176
:param tuple(torch.Tensor) inputs: Tuple containing two messages to
114177
aggregate.
115-
:param torch.Tensor | LabelTensor index: The indices of target nodes
116-
for each message. This tensor specifies which node each message
117-
is aggregated into.
118-
:param torch.Tensor | LabelTensor ptr: Optional tensor to specify
119-
the slices of messages for each node (used in some aggregation
120-
strategies).
178+
:param index: The indices of target nodes for each message. This tensor
179+
specifies which node each message is aggregated into.
180+
:type index: torch.Tensor | LabelTensor
181+
:param ptr: Optional tensor to specify the slices of messages for each
182+
node (used in some aggregation strategies). Default is None.
183+
:type ptr: torch.Tensor | LabelTensor
121184
:param int dim_size: Optional size of the output dimension, i.e.,
122-
number of nodes.
123-
:return: Tuple of aggregated tensors corresponding to
124-
(aggregated messages for position updates, aggregated messages for
125-
feature updates).
185+
number of nodes. Default is None.
186+
:return: Tuple of aggregated tensors corresponding to (aggregated
187+
messages for position updates, aggregated messages for feature
188+
updates).
126189
:rtype: tuple(torch.Tensor, torch.Tensor)
127190
"""
128-
# inputs is tuple (message, m_ij), we want to aggregate separately
191+
# Unpack the messages from the inputs
129192
message, m_ij = inputs
130193

131194
# Aggregate messages as usual using self.aggr method

0 commit comments

Comments
 (0)