scheduler.h
// Copyright 2020-present pytorch-cpp Authors
#pragma once

#include <torch/torch.h>
#include <vector>
#include <algorithm>

namespace scheduler {
// Maps an optimizer type to its corresponding options type,
// e.g. torch::optim::Adam -> torch::optim::AdamOptions.
template<typename TOptimizer>
struct OptimizerOptionsMap {
};

template<>
struct OptimizerOptionsMap<torch::optim::Adam> {
    using type = torch::optim::AdamOptions;
};

template<>
struct OptimizerOptionsMap<torch::optim::Adagrad> {
    using type = torch::optim::AdagradOptions;
};

template<>
struct OptimizerOptionsMap<torch::optim::LBFGS> {
    using type = torch::optim::LBFGSOptions;
};

template<>
struct OptimizerOptionsMap<torch::optim::RMSprop> {
    using type = torch::optim::RMSpropOptions;
};

template<>
struct OptimizerOptionsMap<torch::optim::SGD> {
    using type = torch::optim::SGDOptions;
};
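
// Example (not part of the original header): the trait above can be extended for
// other optimizers in the same way. This sketch assumes a LibTorch version that
// ships torch::optim::AdamW and torch::optim::AdamWOptions.
template<>
struct OptimizerOptionsMap<torch::optim::AdamW> {
    using type = torch::optim::AdamWOptions;
};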

/**
 * Learning rate scheduler base.
 *
 * Based on the Python implementation at
 * https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py.
 * @tparam TOptimizer Optimizer type
 */
template<typename TOptimizer>
class LRScheduler {
 public:
    explicit LRScheduler(TOptimizer& optimizer, int64_t last_epoch = -1)
        : optimizer_(optimizer), last_epoch_(last_epoch), base_lrs(get_current_lr()) {}

    // Computes the learning rate for each parameter group for the current epoch.
    virtual std::vector<double> get_lr() = 0;

    // Advances the epoch counter and writes the new learning rates
    // into the optimizer's parameter groups.
    void step() {
        ++last_epoch_;

        const auto values = get_lr();
        auto& param_groups = optimizer_.param_groups();

        for (decltype(param_groups.size()) i = 0; i != param_groups.size(); ++i) {
            dynamic_cast<typename OptimizerOptionsMap<TOptimizer>::type&>(param_groups[i].options()).lr(values[i]);
        }
    }

    virtual ~LRScheduler() = default;

 protected:
    TOptimizer& optimizer_;
    int64_t last_epoch_;
    std::vector<double> base_lrs;

    // Reads the current learning rate of every parameter group from the optimizer.
    std::vector<double> get_current_lr() {
        std::vector<double> lrs;
        lrs.reserve(optimizer_.param_groups().size());

        for (auto& param_group : optimizer_.param_groups()) {
            lrs.push_back(dynamic_cast<typename
                OptimizerOptionsMap<TOptimizer>::type&>(param_group.options()).lr());
        }

        return lrs;
    }
};

/**
 * Step learning rate scheduler.
 *
 * Decays the learning rate of each parameter group by gamma every
 * step_size epochs.
 *
 * Based on the Python implementation at
 * https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py.
 * @tparam TOptimizer Optimizer type
 */
template<typename TOptimizer>
class StepLR : public LRScheduler<TOptimizer> {
 public:
    StepLR(TOptimizer& optimizer, int64_t step_size, double gamma = 0.1, int64_t last_epoch = -1)
        : LRScheduler<TOptimizer>(optimizer, last_epoch), step_size_(step_size), gamma_(gamma) {}

    std::vector<double> get_lr() override {
        auto new_lr = this->get_current_lr();

        // Multiply every group's learning rate by gamma once every step_size epochs.
        if (this->last_epoch_ != 0 && (this->last_epoch_ % step_size_ == 0)) {
            std::transform(new_lr.cbegin(), new_lr.cend(), new_lr.begin(),
                [gamma_ = gamma_](auto value) { return value * gamma_; });
        }

        return new_lr;
    }

 private:
    int64_t step_size_;
    double gamma_;
};
} // namespace scheduler
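
A minimal usage sketch, not part of the header itself: it pairs StepLR with an SGD optimizer and decays the learning rate by a factor of 0.1 every 30 epochs. The model, the learning rate, and the epoch count are illustrative placeholders.

// usage_example.cpp (illustrative sketch)
#include <torch/torch.h>
#include "scheduler.h"

int main() {
    torch::nn::Linear net(10, 1);  // placeholder model
    torch::optim::SGD optimizer(net->parameters(), torch::optim::SGDOptions(0.1));

    // Multiply the learning rate by gamma (default 0.1) every 30 epochs.
    scheduler::StepLR<torch::optim::SGD> scheduler(optimizer, /*step_size=*/30);

    for (int64_t epoch = 0; epoch != 90; ++epoch) {
        // ... forward pass, loss.backward(), optimizer.step() ...
        scheduler.step();  // update the learning rate of every parameter group
    }
}

Note that scheduler.step() is called once per epoch, after the optimizer updates for that epoch, mirroring the recommended ordering of the Python lr_scheduler API.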