from torch.optim import lr_scheduler
- from functools import partial
from omegaconf.dictconfig import DictConfig
import logging
from src.utils.config import merge_omega_conf


log = logging.getLogger(__name__)

+
def repr(self, scheduler_params={}):
    return "{}({})".format(self.__class__.__name__, scheduler_params)


- class SchedulerWrapper():
+ class LRScheduler:
    def __init__(self, scheduler, scheduler_params):
        self._scheduler = scheduler
        self._scheduler_params = scheduler_params
@@ -22,7 +22,7 @@ def scheduler(self):

    @property
    def scheduler_opt(self):
-         return self._scheduler._scheduler_opt
+         return self._scheduler._scheduler_opt

    def __repr__(self):
        return "{}({})".format(self._scheduler.__class__.__name__, self._scheduler_params)
@@ -36,6 +36,7 @@ def state_dict(self):
    def load_state_dict(self, state_dict):
        self._scheduler.load_state_dict(state_dict)

+
def instantiate_scheduler(optimizer, scheduler_opt):
    """Return a learning rate scheduler
    Parameters:
@@ -45,37 +46,15 @@ def instantiate_scheduler(optimizer, scheduler_opt):
        opt.params contains the scheduler_params to construct the scheduler
        See https://pytorch.org/docs/stable/optim.html for more details.
    """
-     base_lr = optimizer.defaults['lr']
+
+     scheduler_cls_name = getattr(scheduler_opt, "class")
+     scheduler_cls = getattr(lr_scheduler, scheduler_cls_name)
    scheduler_params = scheduler_opt.params
-     if scheduler_opt.lr_policy == 'lambda_rule':
-         if scheduler_opt.rule == "step_decay":
-             lr_lambda = lambda e: max(
-                 scheduler_params.lr_decay ** (e // scheduler_params.decay_step),
-                 scheduler_params.lr_clip / base_lr,
-             )
-         elif scheduler_opt.rule == "exponential_decay":
-             lr_lambda = lambda e: max(
-                 eval(scheduler_params.gamma) ** (e / scheduler_params.decay_step),
-                 scheduler_params.lr_clip / base_lr,
-             )
-         else:
-             raise NotImplementedError
-         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
-
-     elif scheduler_opt.lr_policy == 'step':
-         scheduler = lr_scheduler.StepLR(optimizer, **scheduler_params)
-
-     elif scheduler_opt.lr_policy == 'plateau':
-         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_params)
-         scheduler_params = merge_omega_conf(scheduler_params, {"metric_name": scheduler_opt.metric_name})
-         setattr(scheduler, "metric_name", scheduler_opt.metric_name)
-
-     elif scheduler_opt.lr_policy == 'cosine':
-         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_params)
-     else:
-         return NotImplementedError('learning rate policy [%s] is not implemented', scheduler_opt.lr_policy)
-
+
+     if scheduler_cls_name.lower() == "ReduceLROnPlateau".lower():
+         raise NotImplementedError("This scheduler is not fully supported yet")
+
+     scheduler = scheduler_cls(optimizer, **scheduler_params)
    # used to re_create the scheduler
    setattr(scheduler, "_scheduler_opt", scheduler_opt)
-
-     return SchedulerWrapper(scheduler, scheduler_params)
+     return LRScheduler(scheduler, scheduler_params)
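
Usage sketch (not part of this commit): the snippet below shows the config shape the new instantiate_scheduler reads, i.e. an OmegaConf node with a "class" key naming a class from torch.optim.lr_scheduler and a "params" block of constructor kwargs. The import path, the ExponentialLR choice and the gamma value are illustrative assumptions; only the behaviour visible in the diff above is taken from the commit.

import torch
from omegaconf import OmegaConf
from src.core.schedulers import instantiate_scheduler  # hypothetical import path for the module changed above

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# "class"  -> a scheduler class name from torch.optim.lr_scheduler
#             (ReduceLROnPlateau now raises NotImplementedError)
# "params" -> kwargs forwarded to that scheduler's constructor
scheduler_opt = OmegaConf.create({"class": "ExponentialLR", "params": {"gamma": 0.95}})

scheduler = instantiate_scheduler(optimizer, scheduler_opt)
print(scheduler)                # e.g. ExponentialLR({'gamma': 0.95}), via LRScheduler.__repr__
print(scheduler.scheduler_opt)  # the original config node, re-attached so the scheduler can be re-created

for _ in range(3):
    optimizer.step()
    scheduler.scheduler.step()  # assumes the `scheduler` accessor seen in the hunk context is a property

state = scheduler.state_dict()  # presumably proxied to the wrapped torch scheduler
scheduler.load_state_dict(state)

Resolving the class by name removes the old per-policy if/elif chain, so any torch.optim.lr_scheduler class can be configured without touching this function; ReduceLROnPlateau is rejected for now, presumably because its step() needs a validation metric that the wrapper does not plumb through yet (the old metric_name handling was dropped).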