# coding=utf-8
# Copyright 2023 The Mesh TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A toy model using Mesh TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import mesh_tensorflow as mtf
import numpy
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow_estimator.python.estimator import estimator as estimator_lib
from tensorflow_estimator.python.estimator.tpu import tpu_config # pylint: disable=g-deprecated-tf-checker
from tensorflow_estimator.python.estimator.tpu import tpu_estimator # pylint: disable=g-deprecated-tf-checker
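
# Note: this example targets the TF1-era Estimator/TPUEstimator APIs (hence
# tensorflow.compat.v1 and the deprecated tensorflow_estimator imports above);
# v2 behavior is disabled in __main__ below.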

FLAGS = flags.FLAGS

tf.flags.DEFINE_integer('batch_size', 64, 'Training batch size.')
tf.flags.DEFINE_integer('io_size', 16, 'Number of channels per feature.')
tf.flags.DEFINE_integer('hidden_size', 16, 'Size of each hidden layer.')
tf.flags.DEFINE_integer('num_hidden_layers', 1, 'Number of hidden layers.')
tf.flags.DEFINE_string('master_dtype', 'bfloat16', 'dtype for master vars.')
tf.flags.DEFINE_string('slice_dtype', 'float32', 'dtype for slice vars.')
tf.flags.DEFINE_string('activation_dtype', 'float32', 'dtype for activations.')
tf.flags.DEFINE_string('optimizer', 'SGD', 'Optimizer (SGD or Adafactor).')
tf.flags.DEFINE_float('lr', 1e-4, 'Learning rate.')
tf.flags.DEFINE_string('mesh_shape', 'all:8', 'Mesh shape.')
tf.flags.DEFINE_string('layout', 'hidden_odd:all', 'Layout rules.')
tf.flags.DEFINE_integer(
    'iterations', 100, 'Number of iterations per training loop.')
tf.flags.DEFINE_integer(
    'step_with_nan', -1,
    'If >= 0, a NaN tensor is added in the forward pass at this global step.')
tf.flags.DEFINE_integer('train_steps', 10000, 'Maximum number of train steps.')
tf.flags.DEFINE_integer('steps_per_checkpoint', 200, 'Steps per checkpoint.')
tf.flags.DEFINE_string(
    'model_dir',
    default='',
    help='The directory where the model will be stored.')
tf.flags.DEFINE_bool('use_tpu', True, 'Use TPU.')

# Cloud TPU cluster resolver flags.
tf.flags.DEFINE_string(
    'tpu',
    default=None,
    help='The Cloud TPU to use for training. This should be either the name '
    'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
    'url.')
tf.flags.DEFINE_string(
    'gcp_project',
    default=None,
    help='Project name for the Cloud TPU-enabled project. If not specified, '
    'we will attempt to automatically detect the GCE project from metadata.')
tf.flags.DEFINE_string(
    'tpu_zone',
    default=None,
    help='GCE zone where the Cloud TPU is located. If not specified, we will '
    'attempt to automatically detect the zone from metadata.')
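
# Example invocation (hypothetical TPU name and GCS path; an 8-core TPU
# matches the default --mesh_shape of 'all:8'):
#
#   python toy_model_tpu.py --tpu=my-tpu --tpu_zone=us-central1-b \
#       --gcp_project=my-project --model_dir=gs://my-bucket/toy_model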


class ToyModelInput(object):
  """Wrapper class that acts as the input_fn to TPUEstimator."""

  def __init__(self):
    self._num_examples = 10000  # 10k
    self._images = numpy.random.uniform(
        0, 1.0, [self._num_examples, FLAGS.io_size]).astype(numpy.float32)
    self._labels = self._images
    logging.info('init ToyModelInput()')

  def __call__(self, params):
    """Input function which provides a single batch for train or eval."""
    # Retrieves the batch size for the current shard. The number of shards is
    # computed according to the input pipeline deployment. See
    # `tf.estimator.tpu.RunConfig` for details.
    batch_size = params['batch_size']
    logging.info('call ToyModelInput() with batch size {}'.format(batch_size))
    ds = Dataset.from_tensor_slices((self._images, self._labels)).repeat()
    dataset = ds.batch(batch_size, drop_remainder=True).prefetch(2)
    return dataset
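
# Note: TPUEstimator chooses the batch size handed to the input_fn via
# params['batch_size']; with the BROADCAST input mode configured in
# run_toy_model_tpu() below, the input pipeline is invoked once and its
# batches are broadcast to all replicas.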


def toy_model(features, mesh):
  """A toy model implemented in Mesh TensorFlow."""
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  x = mtf.cast(x, activation_dtype)
  h = x
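
  # A Mesh TF tensor shape cannot repeat a dimension name, so consecutive
  # hidden layers alternate between 'hidden_even' and 'hidden_odd'; under the
  # default --layout, only 'hidden_odd' is split across the mesh.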
  for lnum in range(1, FLAGS.num_hidden_layers + 2):
    if lnum + 1 == FLAGS.num_hidden_layers + 2:
      # Output layer.
      dim = io_dim
    elif lnum % 2 == 0:
      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
    else:
      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
    h = mtf.layers.dense(
        h, dim,
        use_bias=False,
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        name='layer_%d' % lnum)
  y = h

  g = tf.train.get_global_step()
  if FLAGS.step_with_nan >= 0:
    # Trigger a NaN in the forward pass at the requested step; this is used to
    # test whether Mesh TensorFlow can handle an occasional NaN value.
    y += mtf.import_tf_tensor(
        mesh,
        tf.divide(
            0.0,
            tf.cond(tf.equal(g, FLAGS.step_with_nan), lambda: 0., lambda: 1.)),
        mtf.Shape([]))

  loss = mtf.reduce_mean(mtf.square(y - x))
  return y, loss
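
# For reference: ignoring dtype casts and sharding, toy_model() is a stack of
# bias-free linear layers trained to reconstruct its input. A plain-numpy
# sketch of the same computation (hypothetical helper, not used here):
#
#   def toy_model_reference(x, weights):
#     h = x
#     for w in weights:  # one [d_in, d_out] matrix per dense layer, no bias
#       h = numpy.matmul(h, w)
#     return h, numpy.mean((h - x) ** 2)  # prediction and MSE loss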


def model_fn(features, labels, mode, params):
  """The model_fn, called by TPUEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.use_tpu:
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list)
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(
        device_list, devices_memory_usage)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
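
  # On TPU, SimdMeshImpl emits a single program that runs on every core, with
  # cross-core communication handled by XLA collectives; the CPU/GPU fallback
  # above instead places each mesh slice on an explicit device (the default
  # device here, since mesh_devices is a list of empty strings).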
  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode.
  if mode == tf_estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    if FLAGS.optimizer == 'Adafactor':
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert FLAGS.optimizer == 'SGD'
      optimizer = mtf.optimize.SgdOptimizer(learning_rate=FLAGS.lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # For now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)
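
  # Lowering converts the abstract Mesh TensorFlow graph into concrete
  # TensorFlow ops for the chosen mesh implementation; everything below
  # operates on the lowered (plain TF) tensors and operations.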
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf_estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf_estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf_estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf_estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf_estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
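
# For intuition, the default --mesh_shape/--layout flags describe an 8-way
# mesh with only the 'hidden_odd' dimension split across it. A minimal sketch
# of what the two strings mean (values are the flag defaults):
#
#   shape = mtf.convert_to_shape('all:8')  # one mesh axis named 'all', size 8
#   rules = mtf.convert_to_layout_rules('hidden_odd:all')
#   # Tensors with a 'hidden_odd' dimension are split 8 ways along 'all';
#   # all other dimensions (batch, io, hidden_even) are replicated.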


def run_toy_model_tpu():
  """Run a toy model on TPU."""
  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  iterations_per_loop = FLAGS.iterations
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  config = tpu_config.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=None,  # Disable the default saver.
      save_checkpoints_secs=None,  # Disable the default saver.
      log_step_count_steps=iterations_per_loop,
      save_summary_steps=iterations_per_loop,
      tpu_config=tpu_config.TPUConfig(
          num_shards=mesh_shape.size,
          iterations_per_loop=iterations_per_loop,
          num_cores_per_replica=1,
          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
  classifier = tpu_estimator.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size)
  current_step = estimator_lib._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
  logging.info('Current step %d', current_step)
  if FLAGS.steps_per_checkpoint == 0:
    classifier.train(input_fn=ToyModelInput(), max_steps=FLAGS.train_steps)
    return
  while current_step < FLAGS.train_steps:
    next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
                          FLAGS.train_steps)
    classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
    current_step = next_checkpoint
    logging.info('Starting to evaluate.')
    eval_results = classifier.evaluate(
        input_fn=ToyModelInput(),
        steps=156)  # Since we have 10000 examples and batch_size = 64 per host.
    logging.info('Eval results: %s', eval_results)


def main(_):
  run_toy_model_tpu()


if __name__ == '__main__':
  tf.disable_v2_behavior()
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()