@@ -345,15 +345,15 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
345
345
OptimInfo (
346
346
name = 'sgd' ,
347
347
opt_class = optim .SGD ,
348
- description = 'Stochastic Gradient Descent with Nesterov momentum (default) ' ,
348
+ description = 'torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum' ,
349
349
has_eps = False ,
350
350
has_momentum = True ,
351
351
defaults = {'nesterov' : True }
352
352
),
353
353
OptimInfo (
354
354
name = 'momentum' ,
355
355
opt_class = optim .SGD ,
356
- description = 'Stochastic Gradient Descent with classical momentum' ,
356
+ description = 'torch.Optim Stochastic Gradient Descent (SGD) with classical momentum' ,
357
357
has_eps = False ,
358
358
has_momentum = True ,
359
359
defaults = {'nesterov' : False }
@@ -798,7 +798,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
798
798
799
799
def get_optimizer_class (
800
800
name : str ,
801
- bind_defaults : bool = False ,
801
+ bind_defaults : bool = True ,
802
802
) -> Union [Type [optim .Optimizer ], OptimizerCallable ]:
803
803
"""Get optimizer class by name with option to bind default arguments.
804
804
@@ -821,17 +821,14 @@ def get_optimizer_class(
821
821
ValueError: If optimizer name is not found in registry
822
822
823
823
Examples:
824
- >>> # Get raw optimizer class
825
- >>> Adam = get_optimizer_class('adam')
826
- >>> opt = Adam(model.parameters(), lr=1e-3)
827
-
828
- >>> # Get optimizer with defaults bound
829
- >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
830
- >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
831
-
832
824
>>> # Get SGD with nesterov momentum default
833
- >>> SGD = get_optimizer_class('sgd', bind_defaults=True ) # nesterov=True bound
825
+ >>> SGD = get_optimizer_class('sgd') # nesterov=True bound
834
826
>>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
827
+
828
+ >>> # Get raw optimizer class
829
+ >>> SGD = get_optimizer_class('sgd', bind_defaults=False)
830
+ >>> opt = SGD(model.parameters(), lr=1e-3, momentum=0.9)
831
+
835
832
"""
836
833
return default_registry .get_optimizer_class (name , bind_defaults = bind_defaults )
837
834
0 commit comments