Skip to content

Commit d0161f3

Browse files
committed
Small optim factory tweak. default bind_defaults=True for get_optimizer_class
1 parent ef062ee commit d0161f3

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

timm/optim/_optim_factory.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -345,15 +345,15 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
345345
OptimInfo(
346346
name='sgd',
347347
opt_class=optim.SGD,
348-
description='Stochastic Gradient Descent with Nesterov momentum (default)',
348+
description='torch.optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
349349
has_eps=False,
350350
has_momentum=True,
351351
defaults={'nesterov': True}
352352
),
353353
OptimInfo(
354354
name='momentum',
355355
opt_class=optim.SGD,
356-
description='Stochastic Gradient Descent with classical momentum',
356+
description='torch.optim Stochastic Gradient Descent (SGD) with classical momentum',
357357
has_eps=False,
358358
has_momentum=True,
359359
defaults={'nesterov': False}
@@ -798,7 +798,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
798798

799799
def get_optimizer_class(
800800
name: str,
801-
bind_defaults: bool = False,
801+
bind_defaults: bool = True,
802802
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
803803
"""Get optimizer class by name with option to bind default arguments.
804804
@@ -821,17 +821,14 @@ def get_optimizer_class(
821821
ValueError: If optimizer name is not found in registry
822822
823823
Examples:
824-
>>> # Get raw optimizer class
825-
>>> Adam = get_optimizer_class('adam')
826-
>>> opt = Adam(model.parameters(), lr=1e-3)
827-
828-
>>> # Get optimizer with defaults bound
829-
>>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
830-
>>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
831-
832824
>>> # Get SGD with nesterov momentum default
833-
>>> SGD = get_optimizer_class('sgd', bind_defaults=True) # nesterov=True bound
825+
>>> SGD = get_optimizer_class('sgd') # nesterov=True bound
834826
>>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
827+
828+
>>> # Get raw optimizer class
829+
>>> SGD = get_optimizer_class('sgd', bind_defaults=False)
830+
>>> opt = SGD(model.parameters(), lr=1e-3, momentum=0.9)
831+
835832
"""
836833
return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
837834

0 commit comments

Comments
 (0)