|
9 | 9 | from copy import deepcopy
|
10 | 10 |
|
11 | 11 | import torch
|
| 12 | +from IPython.testing.decorators import skip_win32 |
12 | 13 | from torch.testing._internal.common_utils import TestCase
|
13 | 14 | from torch.nn import Parameter
|
14 | 15 |
|
@@ -299,27 +300,39 @@ def test_optim_factory(optimizer):
|
299 | 300 | opt_info = get_optimizer_info(optimizer)
|
300 | 301 | assert isinstance(opt_info, OptimInfo)
|
301 | 302 |
|
302 |
| - if not opt_info.second_order: # basic tests don't support second order right now |
303 |
| - # test basic cases that don't need specific tuning via factory test |
304 |
| - _test_basic_cases( |
305 |
| - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) |
306 |
| - ) |
307 |
| - _test_basic_cases( |
308 |
| - lambda weight, bias: create_optimizer_v2( |
309 |
| - _build_params_dict(weight, bias, lr=1e-2), |
310 |
| - optimizer, |
311 |
| - lr=1e-3) |
312 |
| - ) |
313 |
| - _test_basic_cases( |
314 |
| - lambda weight, bias: create_optimizer_v2( |
315 |
| - _build_params_dict_single(weight, bias, lr=1e-2), |
316 |
| - optimizer, |
317 |
| - lr=1e-3) |
318 |
| - ) |
319 |
| - _test_basic_cases( |
320 |
| - lambda weight, bias: create_optimizer_v2( |
321 |
| - _build_params_dict_single(weight, bias, lr=1e-2), optimizer) |
322 |
| - ) |
| 303 | + lr = (1e-3, 1e-2, 1e-2, 1e-2) |
| 304 | + if optimizer in ('mars',): |
| 305 | + lr = (1e-3, 1e-3, 1e-3, 1e-3) |
| 306 | + |
| 307 | + try: |
| 308 | + if not opt_info.second_order: # basic tests don't support second order right now |
| 309 | + # test basic cases that don't need specific tuning via factory test |
| 310 | + _test_basic_cases( |
| 311 | + lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=lr[0]) |
| 312 | + ) |
| 313 | + _test_basic_cases( |
| 314 | + lambda weight, bias: create_optimizer_v2( |
| 315 | + _build_params_dict(weight, bias, lr=lr[1]), |
| 316 | + optimizer, |
| 317 | + lr=lr[1] / 10) |
| 318 | + ) |
| 319 | + _test_basic_cases( |
| 320 | + lambda weight, bias: create_optimizer_v2( |
| 321 | + _build_params_dict_single(weight, bias, lr=lr[2]), |
| 322 | + optimizer, |
| 323 | + lr=lr[2] / 10) |
| 324 | + ) |
| 325 | + _test_basic_cases( |
| 326 | + lambda weight, bias: create_optimizer_v2( |
| 327 | + _build_params_dict_single(weight, bias, lr=lr[3]), |
| 328 | + optimizer) |
| 329 | + ) |
| 330 | + except TypeError as e: |
| 331 | + if 'radamw' in optimizer: |
| 332 | + pytest.skip("Expected for 'radamw' (decoupled decay) to fail in older PyTorch versions.") |
| 333 | + else: |
| 334 | + raise e |
| 335 | + |
323 | 336 |
|
324 | 337 |
|
325 | 338 | #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
|
|
0 commit comments