Skip to content

Message Passing Module #516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 1, 2025
Merged

Message Passing Module #516

merged 8 commits into from
Jun 1, 2025

Conversation

dario-coscia
Copy link
Collaborator

@dario-coscia dario-coscia commented Mar 21, 2025

Description

This PR fixes #515.

Here a tentative RoadMap

Checklist

  • Code follows the project’s Code Style Guidelines
  • Tests have been added or updated
  • Documentation has been updated if necessary
  • Pull request is linked to an open issue

@dario-coscia dario-coscia added enhancement New feature or request pr-to-fix Label for PR that needs modification labels Mar 21, 2025
@dario-coscia dario-coscia linked an issue Mar 21, 2025 that may be closed by this pull request
Copy link
Contributor

github-actions bot commented Mar 21, 2025

badge

Code Coverage Summary

Filename                                                       Stmts    Miss  Cover    Missing
-----------------------------------------------------------  -------  ------  -------  ---------------------------------------------------------------------------------------------------------------------
__init__.py                                                        7       0  100.00%
graph.py                                                         114      11  90.35%   99-100, 112, 124, 126, 142, 144, 166, 169, 182, 271
label_tensor.py                                                  251      32  87.25%   81, 121, 144-148, 165, 177, 182, 188-193, 273, 280, 332, 334, 348, 444-447, 490, 537, 629, 649-651, 664-673, 688, 710
operator.py                                                       72       5  93.06%   250-268, 459
operators.py                                                       6       6  0.00%    3-12
plotter.py                                                         1       1  0.00%    3
trainer.py                                                        75       6  92.00%   195-204, 293, 314, 318, 322
utils.py                                                          60       8  86.67%   113, 150, 153, 156, 192-195
adaptive_function/__init__.py                                      3       0  100.00%
adaptive_function/adaptive_function.py                            55       0  100.00%
adaptive_function/adaptive_function_interface.py                  51       6  88.24%   98, 141, 148-151
adaptive_functions/__init__.py                                     6       6  0.00%    3-12
callback/__init__.py                                               5       0  100.00%
callback/linear_weight_update_callback.py                         28       1  96.43%   63
callback/optimizer_callback.py                                    22       1  95.45%   34
callback/processing_callback.py                                   49       5  89.80%   42-43, 73, 168, 171
callback/refinement/__init__.py                                    3       0  100.00%
callback/refinement/r3_refinement.py                              28       1  96.43%   88
callback/refinement/refinement_interface.py                       50       5  90.00%   32, 59, 67, 72, 78
callbacks/__init__.py                                              6       6  0.00%    3-12
condition/__init__.py                                              7       0  100.00%
condition/condition.py                                            35       8  77.14%   23, 127-128, 131-132, 135-136, 151
condition/condition_interface.py                                  37       4  89.19%   31, 76, 100, 122
condition/data_condition.py                                       26       1  96.15%   56
condition/domain_equation_condition.py                            19       0  100.00%
condition/input_equation_condition.py                             44       1  97.73%   129
condition/input_target_condition.py                               44       1  97.73%   125
data/__init__.py                                                   3       0  100.00%
data/data_module.py                                              201      22  89.05%   41-52, 132, 172, 193, 232, 313-317, 323-327, 399, 466, 546, 637, 639
data/dataset.py                                                   85       7  91.76%   42, 123-126, 256, 307
domain/__init__.py                                                10       0  100.00%
domain/cartesian.py                                              112      10  91.07%   37, 47, 75-76, 92, 97, 103, 246, 256, 264
domain/difference_domain.py                                       25       2  92.00%   54, 87
domain/domain_interface.py                                        20       5  75.00%   37-41
domain/ellipsoid.py                                              104      24  76.92%   52, 56, 127, 250-257, 269-282, 286-287, 290, 295
domain/exclusion_domain.py                                        28       1  96.43%   86
domain/intersection_domain.py                                     28       1  96.43%   85
domain/operation_interface.py                                     26       1  96.15%   88
domain/simplex.py                                                 72      14  80.56%   62, 207-225, 246-247, 251, 256
domain/union_domain.py                                            25       1  96.00%   43
equation/__init__.py                                               4       0  100.00%
equation/equation.py                                              11       0  100.00%
equation/equation_factory.py                                      24      10  58.33%   37, 62-75, 97-110, 132-145
equation/equation_interface.py                                     4       0  100.00%
equation/system_equation.py                                       22       0  100.00%
geometry/__init__.py                                               7       7  0.00%    3-15
loss/__init__.py                                                   7       0  100.00%
loss/loss_interface.py                                            17       2  88.24%   45, 51
loss/lp_loss.py                                                   15       0  100.00%
loss/ntk_weighting.py                                             26       0  100.00%
loss/power_loss.py                                                15       0  100.00%
loss/scalar_weighting.py                                          16       0  100.00%
loss/weighting_interface.py                                        6       0  100.00%
model/__init__.py                                                 10       0  100.00%
model/average_neural_operator.py                                  31       2  93.55%   73, 82
model/deeponet.py                                                 93      13  86.02%   187-190, 209, 240, 283, 293, 303, 313, 323, 333, 488, 498
model/feed_forward.py                                             89      11  87.64%   58, 195, 200, 278-292
model/fourier_neural_operator.py                                  78      10  87.18%   96-100, 110, 155-159, 218, 220, 242, 342
model/graph_neural_operator.py                                    40       2  95.00%   58, 60
model/kernel_neural_operator.py                                   34       6  82.35%   83-84, 103-104, 123-124
model/low_rank_neural_operator.py                                 27       2  92.59%   89, 98
model/multi_feed_forward.py                                       12       5  58.33%   25-31
model/spline.py                                                   89      37  58.43%   30, 41-66, 69, 128-132, 135, 159-177, 180
model/block/__init__.py                                           12       0  100.00%
model/block/average_neural_operator_block.py                      12       0  100.00%
model/block/convolution.py                                        64      13  79.69%   77, 81, 85, 91, 97, 111, 114, 151, 161, 171, 181, 191, 201
model/block/convolution_2d.py                                    146      27  81.51%   155, 162, 282, 314, 379-433, 456
model/block/embedding.py                                          48       7  85.42%   93, 143-146, 155, 168
model/block/fourier_block.py                                      31       0  100.00%
model/block/gno_block.py                                          22       4  81.82%   73-77, 87
model/block/integral.py                                           18       4  77.78%   22-25, 71
model/block/low_rank_block.py                                     24       0  100.00%
model/block/orthogonal.py                                         37       0  100.00%
model/block/pod_block.py                                          73      10  86.30%   55-58, 70, 83, 113, 148-153, 187, 212
model/block/rbf_block.py                                         179      25  86.03%   18, 42, 53, 64, 75, 86, 97, 223, 280, 282, 298, 301, 329, 335, 363, 367, 511-524
model/block/residual.py                                           46       0  100.00%
model/block/spectral.py                                           83       4  95.18%   132, 140, 262, 270
model/block/stride.py                                             28       7  75.00%   55, 58, 61, 67, 72-74
model/block/utils_convolution.py                                  22       3  86.36%   58-60
model/block/message_passing/__init__.py                            6       0  100.00%
model/block/message_passing/deep_tensor_network_block.py          21       0  100.00%
model/block/message_passing/en_equivariant_network_block.py       39       0  100.00%
model/block/message_passing/interaction_network_block.py          23       0  100.00%
model/block/message_passing/radial_field_network_block.py         20       0  100.00%
model/block/message_passing/schnet_block.py                       25       0  100.00%
model/layers/__init__.py                                           6       6  0.00%    3-12
optim/__init__.py                                                  5       0  100.00%
optim/optimizer_interface.py                                       7       0  100.00%
optim/scheduler_interface.py                                       7       0  100.00%
optim/torch_optimizer.py                                          14       0  100.00%
optim/torch_scheduler.py                                          19       2  89.47%   5-6
problem/__init__.py                                                6       0  100.00%
problem/abstract_problem.py                                      117      18  84.62%   39-40, 59-70, 115-120, 149, 161, 179, 253, 257, 286
problem/inverse_problem.py                                        22       0  100.00%
problem/parametric_problem.py                                      8       1  87.50%   29
problem/spatial_problem.py                                         8       0  100.00%
problem/time_dependent_problem.py                                  8       0  100.00%
problem/zoo/__init__.py                                            8       0  100.00%
problem/zoo/advection.py                                          33       7  78.79%   36-38, 52, 108-110
problem/zoo/allen_cahn.py                                         20       6  70.00%   20-22, 34-36
problem/zoo/diffusion_reaction.py                                 29       5  82.76%   94-104
problem/zoo/helmholtz.py                                          30       6  80.00%   36-42, 103-107
problem/zoo/inverse_poisson_2d_square.py                          31       0  100.00%
problem/zoo/poisson_2d_square.py                                  19       3  84.21%   65-70
problem/zoo/supervised_problem.py                                 11       0  100.00%
solver/__init__.py                                                 6       0  100.00%
solver/garom.py                                                  107       2  98.13%   129-130
solver/solver.py                                                 188      10  94.68%   192, 215, 287, 290-291, 350, 432, 515, 556, 562
solver/ensemble_solver/__init__.py                                 4       0  100.00%
solver/ensemble_solver/ensemble_pinn.py                           23       1  95.65%   104
solver/ensemble_solver/ensemble_solver_interface.py               27       0  100.00%
solver/ensemble_solver/ensemble_supervised.py                      9       0  100.00%
solver/physics_informed_solver/__init__.py                         8       0  100.00%
solver/physics_informed_solver/causal_pinn.py                     47       3  93.62%   157, 166-167
solver/physics_informed_solver/competitive_pinn.py                58       0  100.00%
solver/physics_informed_solver/gradient_pinn.py                   17       0  100.00%
solver/physics_informed_solver/pinn.py                            18       0  100.00%
solver/physics_informed_solver/pinn_interface.py                  47       1  97.87%   130
solver/physics_informed_solver/rba_pinn.py                        35       3  91.43%   155-158
solver/physics_informed_solver/self_adaptive_pinn.py              90       3  96.67%   315-318
solver/supervised_solver/__init__.py                               4       0  100.00%
solver/supervised_solver/reduced_order_model.py                   24       1  95.83%   137
solver/supervised_solver/supervised.py                             7       0  100.00%
solver/supervised_solver/supervised_solver_interface.py           25       0  100.00%
solvers/__init__.py                                                6       6  0.00%    3-12
solvers/pinns/__init__.py                                          6       6  0.00%    3-12
TOTAL                                                           4663     504  89.19%

Results for commit: fbc0382

Minimum allowed coverage is 80.123%

♻️ This comment has been updated with latest results

@dario-coscia
Copy link
Collaborator Author

Hi @AleDinve @GiovanniCanali ! How is it going with this?

@GiovanniCanali
Copy link
Collaborator

Hi @AleDinve @GiovanniCanali ! How is it going with this?

Hi @dario-coscia, I need to fix some minor issues with InteractionNetwork, and then I will fix EGNN. Also, tests will be implemented. @AleDinve agreed to take care of the remaining classes.

@AleDinve
Copy link
Collaborator

Hi @AleDinve @GiovanniCanali ! How is it going with this?

Hi @dario-coscia, I need to fix some minor issues with InteractionNetwork, and then I will fix EGNN. Also, tests will be implemented. @AleDinve agreed to take care of the remaining classes.

Yes, I confirm, I will have a tentative implementation of the classes assigned to me by the weekend.

Copy link
Collaborator Author

@dario-coscia dario-coscia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good @GiovanniCanali and @AleDinve ! I made few comments on the implementation of the various blocks.

I think we should think about inserting inside utils.py a simple function that checks integer types and values. For example (very minimalistic):

def check_values(value, positive=True, strict=True):
   if positive and strict:
       assert value >= 0
   .....

this would reduce a lot of lines of code inside the blocks.

@GiovanniCanali GiovanniCanali force-pushed the messagepassing branch 3 times, most recently from 5b7a708 to 1525549 Compare May 29, 2025 21:28
@GiovanniCanali GiovanniCanali self-requested a review May 30, 2025 09:02
@GiovanniCanali GiovanniCanali mentioned this pull request May 30, 2025
4 tasks
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds implementations of various message-passing neural network blocks, a new utility for validating integer inputs, and corresponding tests.

  • Implements five blocks: InteractionNetworkBlock, DeepTensorNetworkBlock, EnEquivariantNetworkBlock, RadialFieldNetworkBlock, SchnetBlock.
  • Introduces check_positive_integer in pina/utils.py and adds tests in tests/test_utils.py.
  • Expands the test suite under tests/test_messagepassing to cover constructor, forward, and backward behavior for each block.

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
tests/test_utils.py Reorganized imports; added test_check_positive_integer.
tests/test_messagepassing/test_schnet_block.py Added constructor, forward, and backward tests for SchnetBlock.
tests/test_messagepassing/test_radial_field_network_block.py Added constructor, forward, and backward tests for RadialFieldNetworkBlock.
tests/test_messagepassing/test_interaction_network_block.py Added constructor, forward, and backward tests for InteractionNetworkBlock.
tests/test_messagepassing/test_equivariant_network_block.py Added constructor, forward, and backward tests for EnEquivariantNetworkBlock.
tests/test_messagepassing/test_deep_tensor_network_block.py Added constructor, forward, and backward tests for DeepTensorNetworkBlock.
pina/utils.py Added check_positive_integer utility.
pina/model/block/message_passing/schnet_block.py Implemented SchnetBlock.
pina/model/block/message_passing/radial_field_network_block.py Implemented RadialFieldNetworkBlock.
pina/model/block/message_passing/interaction_network_block.py Implemented InteractionNetworkBlock.
pina/model/block/message_passing/egnn_block.py Implemented EnEquivariantNetworkBlock.
pina/model/block/message_passing/deep_tensor_network_block.py Implemented DeepTensorNetworkBlock.
pina/model/block/message_passing/init.py Updated __all__ to include new blocks.

@dario-coscia dario-coscia added pr-to-fix Label for PR that needs modification and removed pr-to-review Label for PR that are ready to been reviewed labels May 30, 2025
@GiovanniCanali GiovanniCanali added pr-to-review Label for PR that are ready to been reviewed and removed pr-to-fix Label for PR that needs modification labels May 30, 2025
Copy link
Collaborator Author

@dario-coscia dario-coscia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice overall! I think the PR is almost ready.

I left few comments because I find some differences in EGNN and ShNet implementation wrt the paper.

Also I wanted to ask, should we test invariance and Equiv. in tests? I think we should

@GiovanniCanali
Copy link
Collaborator

@dario-coscia loops and hyperlinks have been fixed.

There are two last things left:

  • fix the forward of the Equivariant Network. I will take care of it asap.
  • implement equivariance tests for the very same model.

@GiovanniCanali GiovanniCanali added pr-to-fix Label for PR that needs modification pr-to-review Label for PR that are ready to been reviewed and removed pr-to-review Label for PR that are ready to been reviewed pr-to-fix Label for PR that needs modification labels May 30, 2025
@GiovanniCanali
Copy link
Collaborator

Fixed EGNN and added equivariance test. Ready for final review @dario-coscia.

Copy link
Collaborator Author

@dario-coscia dario-coscia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me everything is great now!

I made a simple fix in EGNN to avoid self.message... to be more pyg compliant. I also added invariance tests for SchNEt and Equivariance for RBF following what @GiovanniCanali did for EGNN!

@AleDinve
Copy link
Collaborator

AleDinve commented Jun 1, 2025 via email

@dario-coscia dario-coscia force-pushed the messagepassing branch 2 times, most recently from a3e7f9f to 104c639 Compare June 1, 2025 11:44
@GiovanniCanali GiovanniCanali self-requested a review June 1, 2025 11:50
@dario-coscia dario-coscia merged commit 2101d79 into dev Jun 1, 2025
18 checks passed
dario-coscia added a commit that referenced this pull request Jun 13, 2025
* Fix adaptive refinement (#571)


---------

Co-authored-by: Dario Coscia <[email protected]>

* Remove collector

* Fixes

* Fixes

* rm unnecessary comment

* fix advection (#581)

* Fix tutorial .html link (#580)

* fix problem data collection for v0.1 (#584)

* Message Passing Module (#516)

* add deep tensor network block

* add interaction network block

* add radial field network block

* add schnet block

* add equivariant network block

* fix + tests + doc files

* fix egnn + equivariance/invariance tests

Co-authored-by: Dario Coscia <[email protected]>

---------

Co-authored-by: giovanni <[email protected]>
Co-authored-by: AleDinve <[email protected]>

* add type checker (#527)

---------

Co-authored-by: Filippo Olivo <[email protected]>
Co-authored-by: Giovanni Canali <[email protected]>
Co-authored-by: giovanni <[email protected]>
Co-authored-by: AleDinve <[email protected]>
@GiovanniCanali GiovanniCanali deleted the messagepassing branch June 14, 2025 09:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request pr-to-review Label for PR that are ready to been reviewed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Message Passing Module
3 participants