Skip to content

Commit ec49976

Browse files
committed
fix connection problem.zoo
1 parent 6b355b4 commit ec49976

File tree

6 files changed

+71
-36
lines changed

6 files changed

+71
-36
lines changed

pina/problem/zoo/inverse_poisson_2d_square.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,51 @@
11
"""Formulation of the inverse Poisson problem in a square domain."""
22

3+
import warnings
34
import requests
45
import torch
56
from io import BytesIO
7+
from requests.exceptions import RequestException
68
from ... import Condition
79
from ... import LabelTensor
810
from ...operator import laplacian
911
from ...domain import CartesianDomain
1012
from ...equation import Equation, FixedValue
1113
from ...problem import SpatialProblem, InverseProblem
14+
from ...utils import custom_warning_format
15+
16+
warnings.formatwarning = custom_warning_format
17+
warnings.filterwarnings("always", category=ResourceWarning)
18+
19+
20+
def _load_tensor_from_url(url, labels):
21+
"""
22+
Downloads a tensor file from a URL and wraps it in a LabelTensor.
23+
24+
This function fetches a `.pth` file containing tensor data, extracts it,
25+
and returns it as a LabelTensor using the specified labels. If the file
26+
cannot be retrieved (e.g., no internet connection), a warning is issued
27+
and None is returned.
28+
29+
:param str url: URL to the remote `.pth` tensor file.
30+
:param list[str] | tuple[str] labels: Labels for the resulting LabelTensor.
31+
:return: A LabelTensor object if successful, otherwise None.
32+
:rtype: LabelTensor | None
33+
"""
34+
try:
35+
response = requests.get(url)
36+
response.raise_for_status()
37+
tensor = torch.load(
38+
BytesIO(response.content), weights_only=False
39+
).tensor.detach()
40+
return LabelTensor(tensor, labels)
41+
except RequestException as e:
42+
print(
43+
"Could not download data for 'InversePoisson2DSquareProblem' "
44+
f"from '{url}'. "
45+
f"Reason: {e}. Skipping data loading.",
46+
ResourceWarning,
47+
)
48+
return None
1249

1350

1451
def laplace_equation(input_, output_, params_):
@@ -29,35 +66,26 @@ def laplace_equation(input_, output_, params_):
2966
return delta_u - force_term
3067

3168

32-
# URL of the file
33-
url = "https://github.com/mathLab/PINA/raw/refs/heads/master/tutorials/tutorial7/data/pts_0.5_0.5"
34-
# Download the file
35-
response = requests.get(url)
36-
response.raise_for_status()
37-
file_like_object = BytesIO(response.content)
38-
# Set the data
39-
input_data = LabelTensor(
40-
torch.load(file_like_object, weights_only=False).tensor.detach(),
41-
["x", "y", "mu1", "mu2"],
69+
# loading data
70+
input_url = (
71+
"https://github.com/mathLab/PINA/raw/refs/heads/master"
72+
"/tutorials/tutorial7/data/pts_0.5_0.5"
4273
)
43-
44-
# URL of the file
45-
url = "https://github.com/mathLab/PINA/raw/refs/heads/master/tutorials/tutorial7/data/pinn_solution_0.5_0.5"
46-
# Download the file
47-
response = requests.get(url)
48-
response.raise_for_status()
49-
file_like_object = BytesIO(response.content)
50-
# Set the data
51-
output_data = LabelTensor(
52-
torch.load(file_like_object, weights_only=False).tensor.detach(), ["u"]
74+
output_url = (
75+
"https://github.com/mathLab/PINA/raw/refs/heads/master"
76+
"/tutorials/tutorial7/data/pinn_solution_0.5_0.5"
5377
)
78+
input_data = _load_tensor_from_url(input_url, ["x", "y", "mu1", "mu2"])
79+
output_data = _load_tensor_from_url(output_url, ["u"])
5480

5581

5682
class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
5783
r"""
5884
Implementation of the inverse 2-dimensional Poisson problem in the square
5985
domain :math:`[0, 1] \times [0, 1]`,
6086
with unknown parameter domain :math:`[-1, 1] \times [-1, 1]`.
87+
The `"data"` condition is added only if the required files are
88+
downloaded successfully.
6189
6290
:Example:
6391
>>> problem = InversePoisson2DSquareProblem()
@@ -83,5 +111,7 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
83111
"g3": Condition(domain="g3", equation=FixedValue(0.0)),
84112
"g4": Condition(domain="g4", equation=FixedValue(0.0)),
85113
"D": Condition(domain="D", equation=Equation(laplace_equation)),
86-
"data": Condition(input=input_data, target=output_data),
87114
}
115+
116+
if input_data is not None and input_data is not None:
117+
conditions["data"] = Condition(input=input_data, target=output_data)

tests/test_solver/test_competitive_pinn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
inverse_problem.discretise_domain(10)
2525

2626
# reduce the number of data points to speed up testing
27-
data_condition = inverse_problem.conditions["data"]
28-
data_condition.input = data_condition.input[:10]
29-
data_condition.target = data_condition.target[:10]
27+
if hasattr(inverse_problem.conditions, "data"):
28+
data_condition = inverse_problem.conditions["data"]
29+
data_condition.input = data_condition.input[:10]
30+
data_condition.target = data_condition.target[:10]
3031

3132
# add input-output condition to test supervised learning
3233
input_pts = torch.rand(10, len(problem.input_variables))

tests/test_solver/test_gradient_pinn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ class DummyTimeProblem(TimeDependentProblem):
3535
inverse_problem.discretise_domain(10)
3636

3737
# reduce the number of data points to speed up testing
38-
data_condition = inverse_problem.conditions["data"]
39-
data_condition.input = data_condition.input[:10]
40-
data_condition.target = data_condition.target[:10]
38+
if hasattr(inverse_problem.conditions, "data"):
39+
data_condition = inverse_problem.conditions["data"]
40+
data_condition.input = data_condition.input[:10]
41+
data_condition.target = data_condition.target[:10]
4142

4243
# add input-output condition to test supervised learning
4344
input_pts = torch.rand(10, len(problem.input_variables))

tests/test_solver/test_pinn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
inverse_problem.discretise_domain(10)
2525

2626
# reduce the number of data points to speed up testing
27-
data_condition = inverse_problem.conditions["data"]
28-
data_condition.input = data_condition.input[:10]
29-
data_condition.target = data_condition.target[:10]
27+
if hasattr(inverse_problem.conditions, "data"):
28+
data_condition = inverse_problem.conditions["data"]
29+
data_condition.input = data_condition.input[:10]
30+
data_condition.target = data_condition.target[:10]
3031

3132
# add input-output condition to test supervised learning
3233
input_pts = torch.rand(10, len(problem.input_variables))

tests/test_solver/test_rba_pinn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
inverse_problem.discretise_domain(10)
2424

2525
# reduce the number of data points to speed up testing
26-
data_condition = inverse_problem.conditions["data"]
27-
data_condition.input = data_condition.input[:10]
28-
data_condition.target = data_condition.target[:10]
26+
if hasattr(inverse_problem.conditions, "data"):
27+
data_condition = inverse_problem.conditions["data"]
28+
data_condition.input = data_condition.input[:10]
29+
data_condition.target = data_condition.target[:10]
2930

3031
# add input-output condition to test supervised learning
3132
input_pts = torch.rand(10, len(problem.input_variables))

tests/test_solver/test_self_adaptive_pinn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
inverse_problem.discretise_domain(10)
2525

2626
# reduce the number of data points to speed up testing
27-
data_condition = inverse_problem.conditions["data"]
28-
data_condition.input = data_condition.input[:10]
29-
data_condition.target = data_condition.target[:10]
27+
if hasattr(inverse_problem.conditions, "data"):
28+
data_condition = inverse_problem.conditions["data"]
29+
data_condition.input = data_condition.input[:10]
30+
data_condition.target = data_condition.target[:10]
3031

3132
# add input-output condition to test supervised learning
3233
input_pts = torch.rand(10, len(problem.input_variables))

0 commit comments

Comments
 (0)