# Model checkpoints

The ability to save and restore the state of a model is vital for a number of applications, such
as transfer learning or performing inference with pretrained models. Saving the parameters of a
model (weights, biases, etc.) to a checkpoint file or directory is one way to accomplish this.

This module provides a high-level interface for loading and saving
[TensorFlow v2 format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/tensor_bundle/tensor_bundle.h)
checkpoints, as well as lower-level components that write to and read from this file format.
## Loading and saving simple models

By conforming to the `Checkpointable` protocol, many simple models can be serialized to
checkpoints without any additional code:

```swift
import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")
```
The same checkpoint can then be read back with:

```swift
try model.readCheckpoint(from: directory, name: "LeNet")
```
The default implementation of model loading and saving uses a path-based naming scheme for
each tensor in the model, derived from the names of the properties within the model's structs.
For example, the weights and biases within the first convolution in
[the LeNet-5 model](https://github.com/tensorflow/swift-models/blob/main/Models/ImageClassification/LeNet-5.swift#L26)
are saved with the names `conv1/filter` and `conv1/bias`, respectively. When loading, the
checkpoint reader searches for tensors with these names.
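As a rough illustration of this scheme, the sketch below uses hypothetical layer names (not
taken from any model in the repository) to show how nested property names are joined into
tensor paths:

```swift
import TensorFlow
import Checkpoints

// Hypothetical layers, used only to illustrate the generated tensor paths.
struct Block: Layer {
    var conv = Conv2D<Float>(filterShape: (3, 3, 16, 16), padding: .same, activation: relu)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        conv(input)
    }
}

struct TinyModel: Layer {
    var stem = Conv2D<Float>(filterShape: (3, 3, 3, 16), padding: .same, activation: relu)
    var block = Block()

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        input.sequenced(through: stem, block)
    }
}

extension TinyModel: Checkpointable {}

// Assuming the default naming scheme and separator, the checkpoint would contain:
//   "stem/filter", "stem/bias", "block/conv/filter", "block/conv/bias"
```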
## Customizing model loading and saving

If you want greater control over which tensors are saved and loaded, or over how those
tensors are named, the `Checkpointable` protocol offers a few points of customization.

To ignore properties on certain types, you can provide an implementation of
`ignoredTensorPaths` on your model that returns a `Set` of strings of the form
`Type.property`. For example, to ignore the `scale` property on every `Attention` layer, you
could return `["Attention.scale"]`.
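A minimal sketch of such a conformance (assuming `ignoredTensorPaths` is a computed
`Set<String>` property; `MyTransformer` and its `Attention` layers are hypothetical types):

```swift
// Sketch only: `MyTransformer` is a hypothetical model type, and the exact
// declaration of `ignoredTensorPaths` is assumed to be a computed property.
extension MyTransformer: Checkpointable {
    public var ignoredTensorPaths: Set<String> {
        ["Attention.scale"]
    }
}
```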
By default, a forward slash is used to separate each deeper level in a model. This can be
customized by implementing `checkpointSeparator` on your model and providing a different
string to use as the separator.
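For example, switching to a period-separated scheme might look like the following sketch
(assuming `checkpointSeparator` is a simple `String` property; `MyModel` is hypothetical):

```swift
// Sketch only: produces names like "conv1.filter" instead of "conv1/filter".
extension MyModel: Checkpointable {
    public var checkpointSeparator: String { "." }
}
```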
Finally, for the greatest degree of control over tensor naming, you can implement
`tensorNameMap` and provide a function that maps from the default string name generated for
a tensor in the model to a desired string name in the checkpoint. Most commonly, this is used
to interoperate with checkpoints generated by other frameworks, each of which has its own
naming conventions and model structures.

Some standard helper functions are provided, such as the default
`CheckpointWriter.identityMap` (which simply uses the automatically generated tensor path
name for checkpoints) and the `CheckpointWriter.lookupMap(table:)` function, which builds
a mapping function from a dictionary.
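A sketch of both approaches follows. It assumes `tensorNameMap` is exposed as a
`(String) -> String` property on the conforming model, that the lookup table maps generated
names (keys) to checkpoint names (values), and that the checkpoint-side names shown are
hypothetical:

```swift
extension MyModel: Checkpointable {
    public var tensorNameMap: (String) -> String {
        // Either supply a closure directly...
        //   { generatedName in "model/" + generatedName }
        // ...or build one from a dictionary with the provided helper:
        CheckpointWriter.lookupMap(table: [
            "conv1/filter": "model/h0/conv/w",
            "conv1/bias": "model/h0/conv/b",
        ])
    }
}
```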
For an example of custom mapping, see
[the GPT-2 model](https://github.com/tensorflow/swift-models/blob/main/Models/Text/GPT2/CheckpointWriter.swift),
which uses a mapping function to match the exact naming scheme used for OpenAI's
checkpoints.

## The CheckpointReader and CheckpointWriter components

For checkpoint writing, the extension provided by the `Checkpointable` protocol uses
reflection and key paths to iterate over a model's properties and build a dictionary that maps
string tensor paths to `Tensor` values. This dictionary is provided to an underlying
`CheckpointWriter`, along with a directory in which to write the checkpoint. The
`CheckpointWriter` then handles generating the on-disk checkpoint from that dictionary.
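Used directly, that flow might look like the sketch below. The `CheckpointWriter` initializer
and method shown are assumed signatures rather than confirmed API; only the overall shape
(a `[String: Tensor]` dictionary plus a target directory and checkpoint name) comes from the
description above:

```swift
import Checkpoints
import TensorFlow
import Foundation

// Sketch only: the initializer and write method below are assumed signatures.
let tensors: [String: Tensor<Float>] = [
    "conv1/filter": Tensor(zeros: [5, 5, 1, 6]),
    "conv1/bias": Tensor(zeros: [6]),
]
let directory = URL(fileURLWithPath: "/tmp/checkpoints")
let writer = CheckpointWriter()                                      // assumed initializer
try writer.write(tensors: tensors, to: directory, name: "MyModel")   // assumed method
```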
Reading is the reverse of this process: a `CheckpointReader` is given the location of an
on-disk checkpoint directory, reads that checkpoint, and builds a dictionary that maps the
names of tensors within the checkpoint to their saved values. This dictionary is then used to
replace the current tensors in the model with the saved ones.

For both loading and saving, the `Checkpointable` protocol maps the string paths to tensors
to the corresponding on-disk tensor names using the mapping function described above.

If the `Checkpointable` protocol lacks needed functionality, or if more control is desired over
the checkpoint loading and saving process, the `CheckpointReader` and `CheckpointWriter`
classes can be used on their own.
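For instance, reading individual tensors back out directly might look like the following sketch
(the `CheckpointReader` initializer and accessors shown are assumed signatures, flagged in
the comments):

```swift
import Checkpoints
import TensorFlow
import Foundation

// Sketch only: the initializer and accessors below are assumed signatures.
let directory = URL(fileURLWithPath: "/tmp/checkpoints")
let reader = try CheckpointReader(checkpointLocation: directory, modelName: "MyModel")
if reader.containsTensor(named: "conv1/filter") {
    let filter: ShapedArray<Float> = reader.loadTensor(named: "conv1/filter")
    print(filter.shape)
}
```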
## The TensorFlow v2 checkpoint format

The TensorFlow v2 checkpoint format, briefly described in
[this header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/tensor_bundle/tensor_bundle.h),
is the second-generation format for TensorFlow model checkpoints. It has been in use since
late 2016 and has a number of improvements over the v1 checkpoint format. TensorFlow
SavedModels use v2 checkpoints internally to store model parameters.

A TensorFlow v2 checkpoint consists of a directory with a structure like the following:

```
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
```

where the first file stores the metadata for the checkpoint and the remaining files are binary
shards holding the serialized parameters for the model.
The index metadata file contains the types, sizes, locations, and string names of all serialized
tensors contained in the shards. The index file is the most structurally complex part of the
checkpoint and is based on `tensorflow::table`, which is itself based on SSTable / LevelDB.
It consists of a series of key-value pairs, where the keys are strings and the values are
protocol buffers. The keys are sorted and prefix-compressed: if one entry's key is
`conv1/bias` and the next is `conv1/weight`, the second entry stores only the `weight` part,
since the shared `conv1/` prefix is not repeated.
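A simplified illustration of that layout (the actual byte-level encoding also stores value offsets
and restart points, which are omitted here):

```
key 1: shared = 0, unshared = "conv1/bias"   -> reconstructs to "conv1/bias"
key 2: shared = 6, unshared = "weight"       -> reconstructs to "conv1/weight"
```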
The index file is sometimes compressed using
[Snappy compression](https://github.com/google/snappy). The
`SnappyDecompression.swift` file provides a native Swift implementation of Snappy
decompression from a compressed `Data` instance.

The index header metadata and tensor metadata are stored as protocol buffers and are
encoded / decoded directly via [Swift Protobuf](https://github.com/apple/swift-protobuf).

The `CheckpointIndexReader` and `CheckpointIndexWriter` classes handle loading and
saving these index files as part of the overarching `CheckpointReader` and
`CheckpointWriter` classes. The latter use the index files as the basis for determining what to
read from and write to the structurally simpler binary shards that contain the tensor data.