Skip to content

Commit a1b0660

Browse files
committed
update pydens
1 parent 7c95083 commit a1b0660

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

pydens/model_torch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.autograd import grad
1010
from tqdm import tqdm
1111

12-
from .batchflow.batchflow.models.torch.layers import ConvBlock # pylint: disable=import-error
12+
from batchflow.models.torch.layers import Block # pylint: disable=import-error
1313

1414

1515
current_model = ContextVar("current_model")
@@ -165,7 +165,7 @@ def __init__(self, ndims, initial_condition=None, boundary_condition=None, domai
165165

166166
# Assemble conv-block.
167167
fake_inputs = torch.rand((2, self.total), dtype=torch.float32)
168-
self.conv_block = ConvBlock(inputs=fake_inputs, **kwargs)
168+
self.conv_block = Block(inputs=fake_inputs, **kwargs)
169169

170170
def forward(self, xs):
171171
u = self.conv_block(xs)
@@ -260,7 +260,7 @@ def pde(f, x, e):
260260
go into the `model`-class.
261261
262262
Note:
263-
`ConvBlockModel` is based on `ConvBlock` from framework
263+
`ConvBlockModel` is based on `Block` from framework
264264
"`BatchFlow <https://github.com/analysiscenter/batchflow>`_".
265265
266266
constraints : sequence or callable

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
'tqdm>=4.19.7',
3232
'scipy>=0.19.1',
3333
'scikit-image>=0.13.1',
34+
'batchflow>=0.8.0',
3435
'numba>=0.42',
3536
],
3637
classifiers=[

0 commit comments

Comments
 (0)