This repository was archived by the owner on Feb 13, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy pathLearningRateScheduler.swift
76 lines (63 loc) · 2.44 KB
/
LearningRateScheduler.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import TensorFlow
/// Returns a TrainingLoop callback that will change the learning rate according to `schedule`.
///
/// The callback only acts on `.batchStart` events, and does nothing while
/// `Context.local.learningPhase` is `.inference`. On each such event it computes the 0-based
/// global step `batchIndex + epochIndex * batchCount` and assigns the scheduled rate for that
/// step to `loop.optimizer.learningRate`.
///
/// - Parameters:
///   - schedule: Builds a `LearningRateSchedule` from the total number of training steps,
///     `batchCount * epochCount`, read from the loop on the first training batch.
///   - biasCorrectionBeta: Optional `(beta1, beta2)` pair. When present, the scheduled rate is
///     multiplied by the Adam bias-correction factor `sqrt(1 - beta2^t) / (1 - beta1^t)`,
///     where `t = step + 1` is the 1-based step number.
/// - Returns: A callback to register with a training loop.
/// - Note: The loop's `batchIndex`, `epochIndex`, `batchCount` and `epochCount` are
///   force-unwrapped. The `Float` rate is force-cast to `L.Opt.Scalar`, which traps unless
///   that type is `Float`.
public func learningRateScheduler<L: TrainingLoopProtocol>(
  schedule: @escaping (Int) -> LearningRateSchedule,
  biasCorrectionBeta: (Float, Float)? = nil
) -> TrainingLoopCallback<L> {
  // Built on the first training batch, once the loop knows its batch and epoch counts, then
  // reused. This assumes `schedule` returns the same schedule for the same total step count.
  var builtSchedule: LearningRateSchedule? = nil

  return { (loop, event) throws -> Void in
    if event != .batchStart || Context.local.learningPhase == .inference { return }

    let lrs: LearningRateSchedule
    if let cached = builtSchedule {
      lrs = cached
    } else {
      lrs = schedule(loop.batchCount! * loop.epochCount!)
      builtSchedule = lrs
    }

    let step = loop.batchIndex! + loop.epochIndex! * loop.batchCount!
    var learningRate = lrs(step)
    if let beta = biasCorrectionBeta {
      // Adam's timestep is 1-based. Using the 0-based `step` directly would make both the
      // numerator and the denominator zero (NaN) on the very first batch.
      let t = Float(step + 1)
      learningRate *= sqrt(1 - pow(beta.1, t)) / (1 - pow(beta.0, t))
    }
    // `as!` traps unless the optimizer's scalar type is `Float`.
    loop.optimizer.learningRate = learningRate as! L.Opt.Scalar
  }
}
/// One piece of a piecewise learning-rate schedule.
///
/// An array of these is turned into a full schedule by `makeSchedule(_:)`.
public struct ScheduleSegment {
  /// How the learning rate evolves across this segment.
  public var shape: Shape
  /// Learning rate at the beginning of the segment, or `nil` if not specified.
  public var startRate: Float?
  /// Learning rate at the end of the segment.
  public var endRate: Float
  /// Number of steps this segment spans, or `nil` if not specified.
  public var stepCount: Int?

  /// Creates a segment with the given `shape` that goes from `startRate` to `endRate`
  /// over `stepCount` steps.
  ///
  /// - Parameters:
  ///   - shape: How the learning rate evolves across the segment.
  ///   - startRate: Starting learning rate; defaults to `nil`.
  ///   - endRate: Ending learning rate.
  ///   - stepCount: Length of the segment in steps; defaults to `nil`.
  public init(shape: Shape, startRate: Float? = nil, endRate: Float, stepCount: Int? = nil) {
    self.stepCount = stepCount
    self.endRate = endRate
    self.startRate = startRate
    self.shape = shape
  }
}
/// Returns a function that builds a `LearningRateSchedule` for a given total step count.
/// The function is constructed from the array of segments `schedules`.
///
/// - The schedule's initial rate is the first segment's `startRate`, or `0` when that is `nil`.
/// - Every segment except the last must specify `stepCount`.
/// - If the last segment leaves `stepCount` as `nil`, it is given
///   `totalStepCount - lastEndStep` steps. Here `lastEndStep` is the sum of `stepCount - 1`
///   over all preceding segments.
///
/// - Parameter schedules: The segments, in order. Must not be empty.
/// - Returns: A closure mapping the total number of training steps to a schedule.
/// - Precondition: `schedules` is non-empty. This is checked immediately.
/// - Precondition: Every segment but the last has a `stepCount`. This is checked when the
///   returned closure is called.
public func makeSchedule(_ schedules: [ScheduleSegment]) -> (Int) -> LearningRateSchedule {
  precondition(!schedules.isEmpty, "makeSchedule requires at least one ScheduleSegment")
  let lastIndex = schedules.count - 1

  return { (totalStepCount: Int) -> LearningRateSchedule in
    var lrs = LearningRateSchedule(startRate: schedules[0].startRate ?? 0)
    // NOTE(review): each non-final segment advances `lastEndStep` by `stepCount - 1`,
    // presumably because adjacent segments share a boundary step in
    // `LearningRateSchedule.appendSegment`. Confirm against that type.
    var lastEndStep = 0
    for (i, segment) in schedules.enumerated() {
      let stepCount: Int
      if i < lastIndex {
        guard let count = segment.stepCount else {
          preconditionFailure("Only the last ScheduleSegment may omit stepCount")
        }
        stepCount = count
        lastEndStep += stepCount - 1
      } else {
        // The final segment may omit its length and absorb the remaining steps.
        stepCount = segment.stepCount ?? (totalStepCount - lastEndStep)
      }
      lrs.appendSegment(
        stepCount: stepCount, shape: segment.shape, startRate: segment.startRate,
        endRate: segment.endRate)
    }
    return lrs
  }
}