This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 9e69fb9

nshazeer authored and copybara-github committed on Oct 3, 2019
Update mesh to be Tensorflow 2.0 compatible.
PiperOrigin-RevId: 272767933
1 parent 6979713 commit 9e69fb9

33 files changed (+175 −128 lines)
 

examples/mnist.py (+3 −1)

@@ -24,7 +24,9 @@
 
 import mesh_tensorflow as mtf
 import mnist_dataset as dataset  # local file import
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 tf.flags.DEFINE_string("data_dir", "/tmp/mnist_data",
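
The pattern in this hunk repeats across the commit: import the TF1 API surface from tensorflow.compat.v1 and disable v2 behavior at import time, so existing graph/session code keeps running on a TF 2.x install. A minimal standalone sketch of the shim (plain TensorFlow, not code from this repo):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# TF1-style graph construction and Session.run still work under TF 2.x
# once v2 behavior is disabled.
x = tf.placeholder(tf.float32, shape=[None])
y = 2.0 * x
with tf.Session() as sess:
  print(sess.run(y, feed_dict={x: [1.0, 2.0]}))  # [2. 4.]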

examples/mnist_dataset.py (+3 −1)

@@ -39,7 +39,9 @@
 
 import numpy as np
 from six.moves import urllib
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 def read32(bytestream):

examples/toy_model_tpu.py (+5 −3)

@@ -21,7 +21,7 @@
 
 import mesh_tensorflow as mtf
 import numpy
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 
 from tensorflow.contrib.tpu.python.tpu import tpu_config
 from tensorflow.contrib.tpu.python.tpu import tpu_estimator
@@ -30,6 +30,8 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow_estimator.python.estimator import estimator as estimator_lib
 
+tf.disable_v2_behavior()
+
 FLAGS = flags.FLAGS
 
 tf.flags.DEFINE_integer('batch_size', 64, 'Training batch size.')
@@ -89,7 +91,7 @@ def __call__(self, params):
    """Input function which provides a single batch for train or eval."""
    # Retrieves the batch size for the current shard. The # of shards is
    # computed according to the input pipeline deployment. See
-    # `tf.contrib.tpu.RunConfig` for details.
+    # `tf.estimator.tpu.RunConfig` for details.
    batch_size = params['batch_size']
    logging.info('call ToyModelInput() with batch size {}'.format(batch_size))
 
@@ -242,7 +244,7 @@ def metric_fn(tf_logits):
 
 def run_toy_model_tpu():
   """Run a toy model on TPU."""
-  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
+  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
       FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
 
   iterations_per_loop = FLAGS.iterations
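
Beyond the import shim, this file moves off two contrib endpoints that TF 2.0 drops: tf.contrib.cluster_resolver.TPUClusterResolver becomes tf.distribute.cluster_resolver.TPUClusterResolver, and the comment now points at tf.estimator.tpu.RunConfig. A hedged sketch of the renamed endpoints together (the TPU name, zone, project, and model_dir below are placeholders, not values from this commit, and running it requires a real TPU):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Same constructor arguments as the old contrib resolver, new namespace.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='my-tpu', zone='us-central1-b', project='my-gcp-project')

# The comment's new reference point: the estimator-based TPU RunConfig.
run_config = tf.estimator.tpu.RunConfig(
    cluster=resolver,
    model_dir='/tmp/toy_model',
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))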

mesh_tensorflow/auto_mtf/api.py (+3 −1)

@@ -39,7 +39,9 @@
 import mesh_tensorflow as mtf
 from mesh_tensorflow.auto_mtf import layout_optimizer
 from mesh_tensorflow.auto_mtf import memory_estimator
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 def layout(mtf_graph, mesh_shape, mtf_outputs=()):

mesh_tensorflow/auto_mtf/api_test.py (+3 −1)

@@ -22,7 +22,9 @@
 import mesh_tensorflow as mtf
 import mesh_tensorflow.auto_mtf  # pylint: disable=unused-import
 import mesh_tensorflow.auto_mtf.api
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 class LayoutTest(tf.test.TestCase):

mesh_tensorflow/auto_mtf/graph_interface.py (+3 −1)

@@ -29,9 +29,11 @@
 import collections
 import math
 import mesh_tensorflow as mtf
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 from tensorflow.core.framework import cost_graph_pb2
 
+tf.disable_v2_behavior()
+
 
 class GraphInterface(object):
   """tf.Graph & mtf.Graph common representation which produces a CostGraphDef.

mesh_tensorflow/auto_mtf/graph_interface_test.py (+3 −1)

@@ -21,11 +21,13 @@
 
 import mesh_tensorflow as mtf
 from mesh_tensorflow.auto_mtf import graph_interface
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 from tensorflow.core.framework import cost_graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
 
+tf.disable_v2_behavior()
+
 
 class GraphInterfaceTest(tf.test.TestCase):
 
mesh_tensorflow/auto_mtf/layout_optimizer_test.py (+3 −1)

@@ -23,7 +23,9 @@
 from mesh_tensorflow.auto_mtf import layout_optimizer
 from mesh_tensorflow.auto_mtf import memory_estimator
 import six
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 class VariableNamesTest(tf.test.TestCase):

mesh_tensorflow/auto_mtf/memory_estimator_test.py (+3 −1)

@@ -21,7 +21,9 @@
 
 import mesh_tensorflow as mtf
 from mesh_tensorflow.auto_mtf import memory_estimator
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 class MemoryEstimatorTest(tf.test.TestCase):

mesh_tensorflow/auto_mtf/scheduler_test.py (+3 −1)

@@ -23,7 +23,9 @@
 import mesh_tensorflow as mtf
 from mesh_tensorflow.auto_mtf import graph_interface
 from mesh_tensorflow.auto_mtf import scheduler
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 class SchedulerTest(parameterized.TestCase, tf.test.TestCase):

mesh_tensorflow/auto_mtf/valid_layouts_test.py (+3 −1)

@@ -21,7 +21,9 @@
 
 import mesh_tensorflow as mtf
 from mesh_tensorflow.auto_mtf import valid_layouts
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 class LayoutValidatorTest(tf.test.TestCase):

mesh_tensorflow/beam_search.py (+3 −1)

@@ -21,7 +21,9 @@
 
 import gin
 from mesh_tensorflow import ops_with_redefined_builtins as mtf
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 # Assuming EOS_ID is 1
 EOS_ID = 1

mesh_tensorflow/import_test.py (+3 −1)

@@ -20,7 +20,9 @@
 from __future__ import print_function
 
 import mesh_tensorflow as mtf  # pylint: disable=unused-import
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 class ImportTest(tf.test.TestCase):

mesh_tensorflow/layers.py (+3 −1)

@@ -21,7 +21,9 @@
 
 from mesh_tensorflow import ops_with_redefined_builtins as mtf
 
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
 
 
 def dense(x, output_dim, reduced_dims=None, expert_dims=None,

mesh_tensorflow/layers_test.py (+9 −5)

@@ -25,8 +25,12 @@
 import numpy as np
 
 from tensor2tensor.layers import common_layers
+from tensor2tensor.utils import test_utils
 
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
+
+tf.disable_v2_behavior()
+tf.enable_eager_execution()
 
 
 class LayersTest(parameterized.TestCase, tf.test.TestCase):
@@ -69,7 +73,7 @@ def testDense(self, units, use_bias):
 
    self.assertEqual(actual.shape, expected.shape)
 
-  @tf.contrib.eager.run_test_in_graph_and_eager_modes()
+  @test_utils.run_in_graph_and_eager_modes()
  def testLayerNorm(self):
    batch = 2
    channels = 3
@@ -98,7 +102,7 @@ def testLayerNorm(self):
 
    self.assertEqual(actual.shape, expected.shape)
 
-  @tf.contrib.eager.run_test_in_graph_and_eager_modes()
+  @test_utils.run_in_graph_and_eager_modes()
  def testBatchNorm(self):
    batch = 2
    channels = 3
@@ -138,7 +142,7 @@ def testBatchNorm(self):
    self.assertAllClose(actual_0, expected)
    self.assertAllClose(actual_1, expected)
 
-  @tf.contrib.eager.run_test_in_graph_and_eager_modes()
+  @test_utils.run_in_graph_and_eager_modes()
  def testWeightsNonzero(self):
    inputs = tf.constant([[3, 1, 0], [1, 0, 0]])
 
@@ -162,7 +166,7 @@ def testWeightsNonzero(self):
 
    self.assertAllEqual(actual, expected)
 
-  @tf.contrib.eager.run_test_in_graph_and_eager_modes()
+  @test_utils.run_in_graph_and_eager_modes()
  def testDenseReluDense(self):
    batch = 2
    channels = 3
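
The decorator swap is the substantive change here: tf.contrib.eager.run_test_in_graph_and_eager_modes disappears along with tf.contrib, so the tests switch to the equivalent helper in tensor2tensor.utils.test_utils. A minimal sketch of the replacement in use (assumes tensor2tensor is installed; the test case itself is illustrative, not from this repo):

from tensor2tensor.utils import test_utils
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.enable_eager_execution()


class ExampleTest(tf.test.TestCase):

  @test_utils.run_in_graph_and_eager_modes()
  def test_add(self):
    # The decorated body runs twice: once in graph mode, once eagerly.
    result = self.evaluate(tf.constant(1) + tf.constant(2))
    self.assertEqual(result, 3)


if __name__ == "__main__":
  tf.test.main()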

mesh_tensorflow/ops.py (+17 −48)

@@ -29,12 +29,13 @@
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.ops.gen_nn_ops import conv3d_backprop_input_v2
 from tensorflow.python.ops.nn_ops import conv3d_backprop_filter_v2
 
+tf.disable_v2_behavior()
 
 Dimension = collections.namedtuple("Dimension", ["name", "size"])
 
@@ -1218,6 +1219,8 @@ def einsum(self, equation, *slices):
    Args:
      equation: a string
      *slices: a list of tf.Tensor
+    Returns:
+      a Tensor
    """
    return tf.einsum(equation, *slices)
 
@@ -4973,6 +4976,8 @@ def gather(weights, indices, dim, output_shape=None):
  dim = convert_to_dimension(dim)
  output_shape = convert_to_shape(output_shape)
  if not isinstance(indices, Tensor):
+    # TODO(noam): when `indices` is an integer, gather can be implemented
+    # more directly with mtf_slice() and reshape()
    indices = constant(weights.mesh, indices, dtype=tf.int32)
  if weights.dtype == tf.bool:
    return cast(gather(to_float(weights), indices, dim, output_shape), tf.bool)
@@ -6087,60 +6092,24 @@ def body_fn(microbatch_num):
  return combined_grads, combined_outputs
 
 
-class NthSmallestElementOperation(Operation):
-  """Reduce out last dimension - output is nth-smallest (or largest) element.
-
-  TODO(noam): make n a tensor instead of an integer
-  """
-
-  def __init__(self, x, n, reverse, name=None):
-    super(NthSmallestElementOperation, self).__init__(
-        [x], name=name or "nth_element")
-    reduced_dim = x.shape.dims[-1]
-    output_shape = x.shape - reduced_dim
-    self._outputs = [Tensor(self, output_shape, x.dtype)]
-    self._n = n
-    self._initialize_splittable_and_unsplittable_dims(
-        "splittable", [reduced_dim])
-    self._reverse = reverse
-
-  def gradient(self, grad_ys):
-    raise NotImplementedError("TODO(noam): implement gradient")
-
-  def lower(self, lowering):
-    mesh_impl = lowering.mesh_impl(self)
-    def slicewise_fn(x):
-      return tf.contrib.nn.nth_element(x, self._n, reverse=self._reverse)
-    y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]])
-    lowering.set_tensor_lowering(self.outputs[0], y)
-
-
-def nth_smallest_element(x, n, reduced_dim=None, reverse=False, name=None):
-  """Nth-smallest (or largest) reduction on specified axis.
+def nth_largest_element(x, n, reduced_dim, name=None):
+  """Nth-largest reduction on specified axis.
 
   Note that n is zero-indexed.
 
-  In the case that reduced_dim is split, we do something inefficient:
-  shift data around so that it is replicated and do the computation
-  everywhere.
-
   Args:
     x: a Tensor
     n: an integer
-    reduced_dim: an optional Dimension - defaults to the last dimension of n
-    reverse: a boolean
+    reduced_dim: a Dimension
    name: an optional string
   Returns:
     a Tensor
   """
-  if reduced_dim is None:
-    reduced_dim = x.shape.dims[-1]
-  # remove the reduced dimension from the shape and insert it at the end
-  x = transpose(x, x.shape - reduced_dim + reduced_dim)
-  # Since the NthSmallestElementOperation does not know how to reduce over a
-  # split dimension, we rename the reduced dimension so that we ensure that it
-  # is not split. This may cause the tensor to get all-concatenated, causing
-  # redundant computation.
-  unsplit_dim = Dimension("_unsplit", reduced_dim.size)
-  x = replace_dimensions(x, reduced_dim, unsplit_dim)
-  return NthSmallestElementOperation(x, n, reverse, name).outputs[0]
+  # Compute the top k=n+1 values, then take the last one.
+  k_dim = Dimension("_top_k_", n + 1)
+  values, _ = top_k(x, reduced_dim=reduced_dim, k_dim=k_dim, name=name)
+  return gather(values, n, k_dim)
+
+
+def nth_smallest_element(x, n, reduced_dim, name=None):
+  return -nth_largest_element(-x, n, reduced_dim, name=name)
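
The rewrite deletes NthSmallestElementOperation, whose lowering depended on the removed tf.contrib.nn.nth_element, and re-expresses the reduction with existing mtf ops: the nth-largest value (zero-indexed) is the last of the top n + 1 values, and nth-smallest is nth-largest of the negation. A NumPy sketch of the identity (illustrative only, not Mesh TensorFlow code):

import numpy as np

def nth_largest(x, n):
  # Sort descending and take index n, the same "top k = n + 1 values,
  # then gather the last one" shape as the new mtf implementation.
  return np.sort(x)[::-1][n]

x = np.array([5.0, 1.0, 9.0, 7.0])
assert nth_largest(x, 0) == 9.0    # largest
assert nth_largest(x, 2) == 5.0    # third-largest
assert -nth_largest(-x, 1) == 5.0  # second-smallest, via negation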
