# Training loop

When training a machine learning model, it's common to have a loop where training data is ingested
(or generated), batches are run through a model, gradients are obtained, and the model is updated
via an optimizer. While you can write a training loop of your own for each training application,
Swift for TensorFlow provides an experimental training loop abstraction that may simplify this
process.

The [`TrainingLoop`](https://github.com/tensorflow/swift-models/tree/main/TrainingLoop) module
within [the models repository](https://github.com/tensorflow/swift-models) contains the current
version of this experimental generalized training loop. It is structured in such a way as to
integrate with dataset wrappers that conform to the Epochs API for easy data ingestion, and to
automate the interaction of models, datasets, and optimizers with accelerator backends to achieve
optimal performance. Heavy customization of the training process can be achieved through the use
of callbacks.

Most image-based examples in the models repository, as well as the supervised text model training
examples, have been converted to use this training loop abstraction. However, the training loop may
not be appropriate in its current design for all machine learning models.

The implementation of Swift for TensorFlow's generalized training loop is heavily influenced by
[fastai's Learner](https://docs.fast.ai/learner.html). For more on their design, please refer to
["fastai: A Layered API for Deep Learning"](https://arxiv.org/abs/2002.04688) and Sylvain Gugger's
presentation
["Fast.ai - An infinitely customizable training loop"](https://www.youtube.com/watch?v=roc-dOSeehM).

## Usage

The [ResNet-CIFAR10](https://github.com/tensorflow/swift-models/tree/main/Examples/ResNet-CIFAR10)
example provides a good demonstration of how to use this training loop in practice. First, import
the module:

```swift
import TrainingLoop
```

then choose an accelerator backend by setting up a `Device`. In this case, we'll select the X10
XLA-based backend and use the first available accelerator:

```swift
let device = Device.defaultXLA
```
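
If you'd prefer not to use X10, an eager-mode backend can be selected instead. This is a minimal
alternative sketch, assuming the standard `Device.defaultTFEager` accessor:

```swift
// Select the default eager-mode TensorFlow backend instead of X10.
let device = Device.defaultTFEager
```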

The next step is to configure the dataset, model, and optimizer to use with your training loop:

```swift
let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)
```

and then set up the training loop:

```swift
var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])
```

The training loop assumes that the dataset you're using conforms to the Epochs API, and allows you
to specify which splits within the dataset to use for training and validation. Any loss function
can be used once placed into a compatible wrapper, as `softmaxCrossEntropy` is
[here](https://github.com/tensorflow/swift-models/blob/main/TrainingLoop/LossFunctions.swift).
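
For example, a loss function that isn't already wrapped there can be adapted to the same
(predictions, labels) shape. This is only an illustrative sketch; the exact wrapper signatures
live in `LossFunctions.swift`, and the parameter types used here are assumptions:

```swift
import TensorFlow

// Illustrative sketch: exposing mean squared error in the same
// (predictions, labels) -> loss shape used by the wrapped softmaxCrossEntropy.
@differentiable
public func meanSquaredErrorLoss(
  _ predictions: Tensor<Float>, _ labels: Tensor<Float>
) -> Tensor<Float> {
  meanSquaredError(predicted: predictions, expected: labels)
}
```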

The current metrics that can be captured include:

- `loss`
- `accuracy`
- `top5Accuracy`
- `matthewsCorrelationCoefficient`
- `perplexity`

Finally, to perform training, you call the following:

```swift
try! trainingLoop.fit(&model, epochs: 10, on: device)
```

This will train the model for 10 epochs using the accelerator backend we specified. During
training, statistics will be displayed to the console using an animated prompt.

## Callbacks

Customization of this generalized training loop occurs via the use of callbacks. These callbacks
can be hooked into various points within the loop.

Several built-in callbacks provide functionality that can be added to any training loop. These
include:

- Logging statistics to comma-separated-value (CSV) files
- Adjusting the learning rate according to a custom schedule
- Monitoring and graphing training progress via TensorBoard

In addition to these, you can create your own custom callbacks to add a range of additional
functionality to a standard training loop.

### CSV logging

The [`CSVLogger`](https://github.com/tensorflow/swift-models/blob/main/TrainingLoop/Callbacks/CSVLogger.swift)
class encapsulates a callback that will write out training statistics in a comma-separated-value
format to a file of your choosing. This file will start with columns labeled `epoch`, `batch`, and
whatever metrics you have enabled within your training loop. One row will then be written for each
batch, with the current values of those columns.

To add CSV logging to your training loop, add something like the following to an array of callbacks
provided to the `callbacks:` parameter for your `TrainingLoop`:

```swift
try! CSVLogger(path: "file.csv").log
```

As an example, the [`LeNet-MNIST` sample](https://github.com/tensorflow/swift-models/blob/main/Examples/LeNet-MNIST/main.swift#L52)
uses this within its training loop.
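
Putting this together with the setup from the Usage section above, the wiring might look roughly
like the following (the file name is arbitrary, and the remaining arguments simply mirror the
earlier example):

```swift
var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy],
  callbacks: [try! CSVLogger(path: "stats.csv").log])
```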

### Learning rate schedules

It's common when training a model to change the learning rate provided to an optimizer during the
training process. This can be as simple as a linear decrease over time, or as complex as warmup and
decline cycles described by complicated functions.

The [`learningRateScheduler`](https://github.com/tensorflow/swift-models/blob/main/TrainingLoop/Callbacks/LearningRateScheduler/LearningRateScheduler.swift)
callback provides the means of describing learning rate schedules composed of different segments,
each with its own distinct shape. This is accomplished by defining a
[`LearningRateSchedule`](https://github.com/tensorflow/swift-models/blob/main/TrainingLoop/Callbacks/LearningRateScheduler/LearningRateSchedule.swift)
composed of `ScheduleSegment`s that each have a `Shape` defined by a function, an initial learning
rate, and a final learning rate.

For example, the [BERT-CoLA sample](https://github.com/tensorflow/swift-models/blob/main/Examples/BERT-CoLA/main.swift)
uses a linear increase in the learning rate during a warmup period and a linear decrease after that.
To do this, the learning rate schedule callback is defined as follows:

```swift
learningRateScheduler(
  schedule: makeSchedule(
    [
      ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
      ScheduleSegment(shape: linear, endRate: 0)
    ]
  )
)
```

The two `ScheduleSegment`s define a learning rate that starts at 0 and increases linearly to
`peakLearningRate` over a series of 10 discrete steps, then starts at the final learning rate from
the previous step and decreases linearly to 0 by the end of the training process.
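
As a simpler sketch (reusing the `makeSchedule`, `ScheduleSegment`, and `linear` helpers shown
above, with an arbitrary starting rate), a one-segment schedule that decays linearly from an
initial rate to zero over the whole run could be created and passed in the `callbacks:` array like
any other callback:

```swift
// Hypothetical single-segment schedule: linear decay from 0.001 to 0 over the full run.
let decaySchedule = learningRateScheduler(
  schedule: makeSchedule([
    ScheduleSegment(shape: linear, startRate: 0.001, endRate: 0)
  ]))
```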

### TensorBoard integration

[TensorBoard](https://www.tensorflow.org/tensorboard) is a powerful visualization tool for
monitoring model training, analyzing training when completed, or comparing training runs. Swift for
TensorFlow supports TensorBoard visualization through the use of the
[`TensorBoard`](https://github.com/tensorflow/swift-models/tree/main/TensorBoard) module in the
models repository, which provides callbacks that log training metrics.

The [GPT2-WikiText2](https://github.com/tensorflow/swift-models/tree/main/Examples/GPT2-WikiText2)
sample illustrates how to add TensorBoard logging to your model training. First, import the
`TensorBoard` module. Then it's as simple as adding `tensorBoardStatisticsLogger()` to your
`TrainingLoop`'s `callbacks:` array.
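
For example (a sketch that mirrors the constructor from the Usage section; only the `callbacks:`
entry is new here):

```swift
import TensorBoard

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy],
  callbacks: [tensorBoardStatisticsLogger()])
```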

By default, that will log each training run within a `run/tensorboard/stats` directory. To view
this within TensorBoard, run

```sh
tensorboard --logdir ./run/tensorboard/stats
```

and TensorBoard should start a local server where you can view your training metrics. Training and
validation results should be shown separately, and each run has a unique timestamp to allow for
easy comparison between multiple runs of the same model.

The design of the Swift for TensorFlow TensorBoard integration was inspired by
[tensorboardX](https://github.com/lanpa/tensorboardX). The TensorBoard callbacks directly create the
appropriate event and summary protocol buffers and write them within a log file during training.

### Custom callbacks

In addition to the built-in callbacks described above, you can customize the behavior of training
loops by creating your own callbacks. These callbacks are functions with a signature similar to the
following:

```swift
func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
  if event == .updateStart {
    // Perform your custom logic just before each optimizer update.
  }
}
```

The training loop and associated state are passed in as the first parameter. The current part of
the loop that the callback is responding to is provided via `event`. The training loop event has
one of the following states, each corresponding to a different point in the loop's life cycle:

- `fitStart`
- `fitEnd`
- `epochStart`
- `epochEnd`
- `trainingStart`
- `trainingEnd`
- `validationStart`
- `validationEnd`
- `batchStart`
- `batchEnd`
- `updateStart`
- `inferencePredictionEnd`

Your callback function can choose to activate its logic on any combination of the above states,
which allows for extracting data from or otherwise controlling the training loop in many ways.
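
As a concrete sketch (the event names come from the list above; everything else is illustrative), a
callback that simply reports epoch boundaries might look like this:

```swift
// A hypothetical callback that prints a message at the start and end of every epoch.
// It only inspects the event, so it doesn't rely on any particular training loop state.
func epochReporter<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
  switch event {
  case .epochStart:
    print("Starting an epoch")
  case .epochEnd:
    print("Finished an epoch")
  default:
    break
  }
}
```

Like the built-in callbacks, it would then be provided in the `callbacks:` array when constructing
your `TrainingLoop`.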