
Commit 6b2186b
Adding a guide to training loops and callbacks. (#612)
1 parent e2a0d98 commit 6b2186b

2 files changed: +213 -0 lines changed

docs/site/_book.yaml (+2)

    @@ -30,6 +30,8 @@ upper_tabs:
       - heading: "Machine learning models"
       - title: Datasets
         path: /swift/guide/datasets
    +  - title: Training loop
    +    path: /swift/guide/training_loop
       - title: Model checkpoints
         path: /swift/guide/checkpoints
       - title: Model summaries

docs/site/guide/training_loop.md (+211)

# Training loop

When training a machine learning model, it's common to have a loop where training data is ingested
(or generated), batches are run through a model, gradients are obtained, and the model is updated
via an optimizer. While you can write a training loop of your own for each training application,
Swift for TensorFlow provides an experimental training loop abstraction that may simplify this
process.
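
Written by hand, such a loop often looks something like the following sketch (here `epochCount`,
`trainingBatches`, and the `data`/`label` properties of each batch are placeholders whose exact
form depends on your dataset and model):

```swift
import TensorFlow

for epoch in 1...epochCount {
  for batch in trainingBatches {
    // Differentiate the loss with respect to the model's parameters.
    let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
      softmaxCrossEntropy(logits: model(batch.data), labels: batch.label)
    }
    // Apply the gradients through the optimizer to update the model.
    optimizer.update(&model, along: gradients)
    print("Epoch \(epoch), loss: \(loss)")
  }
}
```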

The [`TrainingLoop`](https://github.com/tensorflow/swift-models/tree/main/TrainingLoop) module
within [the models repository](https://github.com/tensorflow/swift-models) contains the current
version of this experimental generalized training loop. It is structured in such a way as to
integrate with dataset wrappers that conform to the Epochs API for easy data ingestion, and to
automate the interaction of models, datasets, and optimizers with accelerator backends to achieve
optimal performance. Heavy customization of the training process can be achieved through the use
of callbacks.

Most image-based examples in the model repository have been converted to use this training loop
abstraction, as well as the supervised text model training examples. However, the training loop may
not be appropriate in its current design for all machine learning models.

The implementation of Swift for TensorFlow's generalized training loop is heavily influenced by
[fastai's Learner](https://docs.fast.ai/learner.html). For more on their design, please refer to
["fastai: A Layered API for Deep Learning"](https://arxiv.org/abs/2002.04688) and Sylvain Gugger's
presentation
["Fast.ai - An infinitely customizable training loop"](https://www.youtube.com/watch?v=roc-dOSeehM).

## Usage

The [ResNet-CIFAR10](https://github.com/tensorflow/swift-models/tree/main/Examples/ResNet-CIFAR10)
example provides a good demonstration of how to use this training loop in practice. First, import
the module:

```swift
import TrainingLoop
```

then choose an accelerator backend by setting up a `Device`. In this case, we'll select the X10
XLA-based backend and use the first available accelerator:

```swift
let device = Device.defaultXLA
```

The next step is to configure the dataset, model, and optimizer to use with your training loop:

```swift
let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)
```

and then set up the training loop:

```swift
var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])
```

The training loop assumes that the dataset you're using conforms to the Epochs API, and allows you
to specify which splits within the dataset to use for training and validation. Any loss function
can be used once placed into a compatible wrapper, as `softmaxCrossEntropy` is
[here](https://github.com/tensorflow/swift-models/blob/main/TrainingLoop/LossFunctions.swift).

The current metrics that can be captured include:

- `loss`
- `accuracy`
- `top5Accuracy`
- `matthewsCorrelationCoefficient`
- `perplexity`
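
For example, to track top-5 accuracy alongside accuracy, the `metrics:` argument in the
`TrainingLoop` above could be written as:

```swift
metrics: [.accuracy, .top5Accuracy]
```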

Finally, to perform training, you call the following:

```swift
try! trainingLoop.fit(&model, epochs: 10, on: device)
```

This will train the model for 10 epochs using the accelerator backend we specified. Statistics will
be displayed during training to the console using an animated prompt.

## Callbacks

Customization of this generalized training loop occurs via the use of callbacks. These callbacks can
be hooked into various points within the loop.

Several built-in callbacks provide functionality that can be added to any training loop. These
include:

- Logging statistics to comma-separated-value (CSV) files
- Adjusting the learning rate according to a custom schedule
- Monitoring and graphing training progress via TensorBoard

In addition to these, you can create your own custom callbacks to add a range of additional
functionality to a standard training loop.

### CSV logging

The [`CSVLogger`](https://github.com/tensorflow/swift-models/blob/main/TrainingLoop/Callbacks/CSVLogger.swift)
class encapsulates a callback that will write out training statistics in a comma-separated-value
format to a file of your choosing. This file will start with columns labeled `epoch`, `batch`, and
whatever metrics you have enabled within your training loop. One row will then be written for each
batch, with the current values of those columns.

To add CSV logging to your training loop, add something like the following to an array of callbacks
provided to the `callbacks:` parameter for your `TrainingLoop`:

```swift
try! CSVLogger(path: "file.csv").log
```
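
Reusing the dataset, model, and optimizer from the Usage section above, one way this might look is
the following sketch (the exact position of the `callbacks:` parameter in the `TrainingLoop`
initializer may differ):

```swift
var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy],
  callbacks: [try! CSVLogger(path: "file.csv").log])
```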

As an example, the [`LeNet-MNIST` sample](https://github.com/tensorflow/swift-models/blob/main/Examples/LeNet-MNIST/main.swift#L52)
uses this within its training loop.

### Learning rate schedules

It's common when training a model to change the learning rate provided to an optimizer during the
training process. This can be as simple as a linear decrease over time, or as complex as warmup and
decline cycles described by complicated functions.

The [`learningRateScheduler`](https://github.com/tensorflow/swift-models/blob/main/TrainingLoop/Callbacks/LearningRateScheduler/LearningRateScheduler.swift)
callback provides the means of describing learning rate schedules composed of different segments,
each with their own distinct shape. This is accomplished by defining a
[`LearningRateSchedule`](https://github.com/tensorflow/swift-models/blob/main/TrainingLoop/Callbacks/LearningRateScheduler/LearningRateSchedule.swift)
composed of `ScheduleSegment`s that each have a `Shape` defined by a function, an initial learning
rate, and a final learning rate.

For example, the [BERT-CoLA sample](https://github.com/tensorflow/swift-models/blob/main/Examples/BERT-CoLA/main.swift)
uses a linear increase in the learning rate during a warmup period and a linear decrease after that.
To do this, the learning rate schedule callback is defined as follows:

```swift
learningRateScheduler(
  schedule: makeSchedule(
    [
      ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
      ScheduleSegment(shape: linear, endRate: 0)
    ]
  )
)
```

The two `ScheduleSegment`s define a learning rate that starts at 0 and increases linearly to
`peakLearningRate` over a series of 10 discrete steps, then starts at the final learning rate from
the previous step and decreases linearly to 0 by the end of the training process.

### TensorBoard integration

[TensorBoard](https://www.tensorflow.org/tensorboard) is a powerful visualization tool for
monitoring model training, analyzing training when completed, or comparing training runs. Swift for
TensorFlow supports TensorBoard visualization through the use of the
[`TensorBoard`](https://github.com/tensorflow/swift-models/tree/main/TensorBoard) module in the
models repository, which provides callbacks that log training metrics.

The [GPT2-WikiText2](https://github.com/tensorflow/swift-models/tree/main/Examples/GPT2-WikiText2)
sample illustrates how to add TensorBoard logging to your model training. First, import the
`TensorBoard` module. Then it's as simple as adding `tensorBoardStatisticsLogger()` to your
`TrainingLoop`'s `callbacks:` array.
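
For example, reusing the earlier CIFAR-10 setup, a sketch of this might look as follows (again, the
exact position of the `callbacks:` parameter may differ):

```swift
import TensorBoard
import TrainingLoop

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy],
  callbacks: [tensorBoardStatisticsLogger()])
```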

By default, that will log each training run within a `run/tensorboard/stats` directory. To view this
within TensorBoard, run

```sh
tensorboard --logdir ./run/tensorboard/stats
```

and TensorBoard should start a local server where you can view your training metrics. Training and
validation results should be shown separately, and each run has a unique timestamp to allow for
easy comparison between multiple runs of the same model.

The design of the Swift for TensorFlow TensorBoard integration was inspired by
[tensorboardX](https://github.com/lanpa/tensorboardX). The TensorBoard callbacks directly create the
appropriate event and summary protocol buffers and write them within a log file during training.

### Custom callbacks

In addition to the built-in callbacks described above, you have the ability to customize the
function of training loops by creating your own callbacks. These callbacks are functions that
have a signature similar to the following:

```swift
func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
  if event == .updateStart {
    ...
  }
}
```

The training loop and associated state are passed in as the first parameter. The current part of
the loop that the callback is responding to is provided via `event`. The training loop event has
one of the following states, each corresponding to a different point in the loop's life cycle:

- `fitStart`
- `fitEnd`
- `epochStart`
- `epochEnd`
- `trainingStart`
- `trainingEnd`
- `validationStart`
- `validationEnd`
- `batchStart`
- `batchEnd`
- `updateStart`
- `inferencePredictionEnd`

Your callback function can choose to activate its logic on any combination of the above states, which
allows for extracting data from or otherwise controlling the training loop in many ways.
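
For instance, a minimal custom callback that reacts to epoch boundaries might look like the
following sketch (the function name `epochLogger` is just an illustration):

```swift
// A sketch of a custom callback that prints a message at the start and end of each epoch.
func epochLogger<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
  switch event {
  case .epochStart:
    print("Starting a new epoch")
  case .epochEnd:
    print("Finished an epoch")
  default:
    break
  }
}
```

Like the built-in callbacks, such a function would then be added to the `callbacks:` array of your
`TrainingLoop`.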
