Skip to content

Commit 2bb8319

Browse files
FilippoOlivodario-coscia
authored andcommitted
Fixes
1 parent 2885f9a commit 2bb8319

File tree

3 files changed

+14
-25
lines changed

3 files changed

+14
-25
lines changed

pina/data/data_module.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,6 @@ def __init__(
331331
self.pin_memory = pin_memory
332332

333333
# Collect data
334-
# collector = Collector(problem)
335-
# collector.store_fixed_data()
336-
# collector.store_sample_domains()
337334
problem.collect_data()
338335

339336
# Check if the splits are correct

pina/problem/abstract_problem.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,33 +38,25 @@ def __init__(self):
3838
self.domains[cond_name] = cond.domain
3939
cond.domain = cond_name
4040

41-
self._collect_data = {}
41+
self._collected_data = {}
4242

4343
@property
4444
def collected_data(self):
4545
"""
4646
Return the collected data from the problem's conditions.
4747
48-
:return: The collected data.
48+
:return: The collected data. Keys are condition names, and values are
49+
dictionaries containing the input points and the corresponding
50+
equations or target points.
4951
:rtype: dict
5052
"""
51-
if not self._collect_data:
53+
if not self._collected_data:
5254
raise RuntimeError(
5355
"You have to call collect_data() before accessing the data."
5456
)
55-
return self._collect_data
56-
57-
@collected_data.setter
58-
def collected_data(self, data):
59-
"""
60-
Set the collected data from the problem's conditions.
61-
62-
:param dict data: The collected data.
63-
"""
64-
self._collect_data = data
57+
return self._collected_data
6558

6659
# back compatibility 0.1
67-
6860
@property
6961
def input_pts(self):
7062
"""
@@ -75,11 +67,12 @@ def input_pts(self):
7567
:rtype: dict
7668
"""
7769
to_return = {}
78-
for cond_name, cond in self.conditions.items():
79-
if hasattr(cond, "input"):
80-
to_return[cond_name] = cond.input
81-
elif hasattr(cond, "domain"):
82-
to_return[cond_name] = self._discretised_domains[cond.domain]
70+
if self._collected_data is None:
71+
raise RuntimeError(
72+
"You have to call collect_data() before accessing the data."
73+
)
74+
for cond_name, data in self._collected_data.items():
75+
to_return[cond_name] = data["input"]
8376
return to_return
8477

8578
@property
@@ -332,4 +325,4 @@ def collect_data(self):
332325
keys = condition.__slots__
333326
values = [getattr(condition, name) for name in keys]
334327
data[condition_name] = dict(zip(keys, values))
335-
self.collected_data = data
328+
self._collected_data = data

tests/test_problem.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def test_aggregate_data():
114114
def test_wrong_aggregate_data():
115115
poisson_problem = Poisson()
116116
poisson_problem.discretise_domain(0, "random", domains=["D"])
117-
with pytest.raises(RuntimeError):
118-
poisson_problem.collected_data()
117+
assert not poisson_problem._collected_data
119118
with pytest.raises(RuntimeError):
120119
poisson_problem.collect_data()

0 commit comments

Comments
 (0)