Skip to content

Commit bc7d224

Browse files
committed
Update tests; still need handling for radamw on older PyTorch, and may need to back off the basic-test LR for mars.
1 parent 7d3146b commit bc7d224

File tree

1 file changed

+34
-21
lines changed

1 file changed

+34
-21
lines changed

tests/test_optim.py

Lines changed: 34 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from copy import deepcopy
1010

1111
import torch
12+
from IPython.testing.decorators import skip_win32
1213
from torch.testing._internal.common_utils import TestCase
1314
from torch.nn import Parameter
1415

@@ -299,27 +300,39 @@ def test_optim_factory(optimizer):
299300
opt_info = get_optimizer_info(optimizer)
300301
assert isinstance(opt_info, OptimInfo)
301302

302-
if not opt_info.second_order: # basic tests don't support second order right now
303-
# test basic cases that don't need specific tuning via factory test
304-
_test_basic_cases(
305-
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
306-
)
307-
_test_basic_cases(
308-
lambda weight, bias: create_optimizer_v2(
309-
_build_params_dict(weight, bias, lr=1e-2),
310-
optimizer,
311-
lr=1e-3)
312-
)
313-
_test_basic_cases(
314-
lambda weight, bias: create_optimizer_v2(
315-
_build_params_dict_single(weight, bias, lr=1e-2),
316-
optimizer,
317-
lr=1e-3)
318-
)
319-
_test_basic_cases(
320-
lambda weight, bias: create_optimizer_v2(
321-
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
322-
)
303+
lr = (1e-3, 1e-2, 1e-2, 1e-2)
304+
if optimizer in ('mars',):
305+
lr = (1e-3, 1e-3, 1e-3, 1e-3)
306+
307+
try:
308+
if not opt_info.second_order: # basic tests don't support second order right now
309+
# test basic cases that don't need specific tuning via factory test
310+
_test_basic_cases(
311+
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=lr[0])
312+
)
313+
_test_basic_cases(
314+
lambda weight, bias: create_optimizer_v2(
315+
_build_params_dict(weight, bias, lr=lr[1]),
316+
optimizer,
317+
lr=lr[1] / 10)
318+
)
319+
_test_basic_cases(
320+
lambda weight, bias: create_optimizer_v2(
321+
_build_params_dict_single(weight, bias, lr=lr[2]),
322+
optimizer,
323+
lr=lr[2] / 10)
324+
)
325+
_test_basic_cases(
326+
lambda weight, bias: create_optimizer_v2(
327+
_build_params_dict_single(weight, bias, lr=lr[3]),
328+
optimizer)
329+
)
330+
except TypeError as e:
331+
if 'radamw' in optimizer:
332+
pytest.skip("Expected for 'radamw' (decoupled decay) to fail in older PyTorch versions.")
333+
else:
334+
raise e
335+
323336

324337

325338
#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])

0 commit comments

Comments (0)