This repository was archived by the owner on Mar 30, 2022. It is now read-only.

Commit 4854e4c

Adding a checkpointing guide and a link to the WordSeg design. (#608)

1 parent f3e704f · commit 4854e4c

File tree

2 files changed: +138 −0 lines changed


docs/site/_book.yaml

+5
```diff
@@ -30,8 +30,13 @@ upper_tabs:
       - heading: "Machine learning models"
       - title: Datasets
         path: /swift/guide/datasets
+      - title: Model checkpoints
+        path: /swift/guide/checkpoints
       - title: Model summaries
         path: /swift/guide/model_summary
+      - title: "Behind the scenes: WordSeg"
+        path: https://docs.google.com/document/d/1NlFH0_89gB_qggtgzJIKYHL2xPI3IQjWjv18pnT1M0E
+        status: external
       - title: Swift for TensorFlow model garden
         path: https://github.com/tensorflow/swift-models
         status: external
```

docs/site/guide/checkpoints.md

+133
# Model checkpoints

The ability to save and restore the state of a model is vital for a number of applications, such as transfer learning or performing inference with pretrained models. Saving the parameters of a model (weights, biases, etc.) in a checkpoint file or directory is one way to accomplish this.

This module provides a high-level interface for loading and saving [TensorFlow v2 format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/tensor_bundle/tensor_bundle.h) checkpoints, as well as lower-level components that write to and read from this file format.

## Loading and saving simple models

By conforming to the `Checkpointable` protocol, many simple models can be serialized to checkpoints without any additional code:

```swift
import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")
```

and then that same checkpoint can be read by using:

```swift
try model.readCheckpoint(from: directory, name: "LeNet")
```

This default implementation for model loading and saving uses a path-based naming scheme for each tensor in the model, based on the names of the properties within the model structs. For example, the weights and biases within the first convolution in [the LeNet-5 model](https://github.com/tensorflow/swift-models/blob/main/Models/ImageClassification/LeNet-5.swift#L26) will be saved with the names `conv1/filter` and `conv1/bias`, respectively. When loading, the checkpoint reader will search for tensors with these names.

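To make the naming scheme concrete, here is a sketch using an invented two-layer model (`TinyModel` and its property names are hypothetical, not part of the library); the tensor path names follow directly from the property names:

```swift
import Checkpoints
import Foundation
import TensorFlow

// Hypothetical model, used only to illustrate the path-based naming scheme.
struct TinyModel: Layer {
    var dense1 = Dense<Float>(inputSize: 4, outputSize: 8, activation: relu)
    var dense2 = Dense<Float>(inputSize: 8, outputSize: 2)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        input.sequenced(through: dense1, dense2)
    }
}

extension TinyModel: Checkpointable {}

// Writing this model produces tensors named after its properties:
//   dense1/weight, dense1/bias, dense2/weight, dense2/bias
var model = TinyModel()
try model.writeCheckpoint(to: URL(fileURLWithPath: "/tmp/tiny"), name: "TinyModel")
```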
## Customizing model loading and saving

If you want greater control over which tensors are saved and loaded, or over the naming of those tensors, the `Checkpointable` protocol offers a few points of customization.

To ignore properties on certain types, you can provide an implementation of `ignoredTensorPaths` on your model that returns a `Set` of strings in the form of `Type.property`. For example, to ignore the `scale` property on every `Attention` layer, you could return `["Attention.scale"]`.

By default, a forward slash is used to separate each deeper level in a model. This can be customized by implementing `checkpointSeparator` on your model and providing a new string to use for this separator.

Finally, for the greatest degree of customization in tensor naming, you can implement `tensorNameMap` and provide a function that maps from the default string name generated for a tensor in the model to a desired string name in the checkpoint. Most commonly, this is used to interoperate with checkpoints generated by other frameworks, each of which has its own naming conventions and model structures. A custom mapping function gives full control over how these tensors are named.

Some standard helper functions are provided, like the default `CheckpointWriter.identityMap` (which simply uses the automatically generated tensor path name for checkpoints), or the `CheckpointWriter.lookupMap(table:)` function, which can build a mapping from a dictionary.

For an example of how custom mapping can be accomplished, please see [the GPT-2 model](https://github.com/tensorflow/swift-models/blob/main/Models/Text/GPT2/CheckpointWriter.swift), which uses a mapping function to match the exact naming scheme used for OpenAI's checkpoints.

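Taken together, these customization points might be used like the following sketch. `MyTransformer` and the entries in the lookup table are invented for illustration; the requirements themselves (`ignoredTensorPaths`, `checkpointSeparator`, `tensorNameMap`) are the ones described above, but check the Checkpoints module source for their exact signatures:

```swift
import Checkpoints

// `MyTransformer` is a hypothetical model type used only for this example.
extension MyTransformer: Checkpointable {
    // Skip the `scale` property on every Attention layer.
    public var ignoredTensorPaths: Set<String> {
        ["Attention.scale"]
    }

    // Use "." instead of the default "/" between nesting levels.
    public var checkpointSeparator: String { "." }

    // Map generated tensor path names to checkpoint names via a dictionary;
    // the table entries here are illustrative.
    public var tensorNameMap: (String) -> String {
        CheckpointWriter.lookupMap(table: [
            "encoder.layer1.weight": "transformer/layer_0/kernel",
        ])
    }
}
```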
## The CheckpointReader and CheckpointWriter components

For checkpoint writing, the extension provided by the `Checkpointable` protocol uses reflection and keypaths to iterate over a model's properties and generate a dictionary that maps string tensor paths to Tensor values. This dictionary is provided to an underlying `CheckpointWriter`, along with a directory in which to write the checkpoint. That `CheckpointWriter` handles the task of generating the on-disk checkpoint from that dictionary.

The reverse of this process is reading, where a `CheckpointReader` is given the location of an on-disk checkpoint directory. It then reads from that checkpoint and forms a dictionary that maps the names of tensors within the checkpoint to their saved values. This dictionary is used to replace the current tensors in a model with the saved ones.

For both loading and saving, the `Checkpointable` protocol maps the string paths to tensors to corresponding on-disk tensor names using the above-described mapping function.

If the `Checkpointable` protocol lacks needed functionality, or more control is desired over the checkpoint loading and saving process, the `CheckpointReader` and `CheckpointWriter` classes can be used by themselves.

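When the protocol's defaults are not enough, `CheckpointReader` can be driven directly. The initializer and method names in this sketch are assumptions based on the swift-models Checkpoints module and should be verified against the source:

```swift
import Checkpoints
import Foundation
import TensorFlow

// Assumed API shape; verify against the Checkpoints module before use.
let checkpointURL = URL(fileURLWithPath: "/tmp/LeNet")
let reader = try CheckpointReader(checkpointLocation: checkpointURL, modelName: "LeNet")

// Look up an individual tensor by its saved name.
if reader.containsTensor(named: "conv1/filter") {
    let filter: ShapedArray<Float> = reader.loadTensor(named: "conv1/filter")
    print(filter.shape)
}
```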
## The TensorFlow v2 checkpoint format

The TensorFlow v2 checkpoint format, as briefly described in [this header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/tensor_bundle/tensor_bundle.h), is the second-generation format for TensorFlow model checkpoints. It has been in use since late 2016, and has a number of improvements over the v1 checkpoint format. TensorFlow SavedModels use v2 checkpoints within them to save model parameters.

A TensorFlow v2 checkpoint consists of a directory with a structure like the following:

```
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
```

where the first file stores the metadata for the checkpoint and the remaining files are binary shards holding the serialized parameters for the model.

The index metadata file contains the types, sizes, locations, and string names of all serialized tensors contained in the shards. That index file is the most structurally complex part of the checkpoint, and is based on `tensorflow::table`, which is itself based on SSTable / LevelDB. This index file is composed of a series of key-value pairs, where the keys are strings and the values are protocol buffers. The strings are sorted and prefix-compressed. For example: if the first entry is `conv1/weight` and the next is `conv1/bias`, the second entry stores only the `bias` part.

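The prefix-compression idea can be sketched in a few lines of plain Swift. This illustrates the general scheme only, not the exact on-disk encoding used by `tensorflow::table`:

```swift
// For each key in sorted order, store how many leading characters are shared
// with the previous key, plus the non-shared suffix.
func prefixCompress(_ sortedKeys: [String]) -> [(shared: Int, suffix: String)] {
    var previous = ""
    return sortedKeys.map { key in
        let shared = zip(previous, key).prefix { $0.0 == $0.1 }.count
        previous = key
        return (shared, String(key.dropFirst(shared)))
    }
}

let entries = prefixCompress(["conv1/bias", "conv1/filter"])
// entries == [(shared: 0, suffix: "conv1/bias"), (shared: 6, suffix: "filter")]
// The second entry shares the 6-character "conv1/" prefix with the first,
// so only "filter" needs to be stored.
```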
This overall index file is sometimes compressed using [Snappy compression](https://github.com/google/snappy). The `SnappyDecompression.swift` file provides a native Swift implementation of Snappy decompression from a compressed `Data` instance.

The index header metadata and tensor metadata are encoded as protocol buffers and encoded / decoded directly via [Swift Protobuf](https://github.com/apple/swift-protobuf).

The `CheckpointIndexReader` and `CheckpointIndexWriter` classes handle loading and saving these index files as part of the overarching `CheckpointReader` and `CheckpointWriter` classes. The latter use the index files as the basis for determining what to read from and write to the structurally simpler binary shards that contain the tensor data.
