 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.ops.gen_nn_ops import conv3d_backprop_input_v2
 from tensorflow.python.ops.nn_ops import conv3d_backprop_filter_v2
 
+tf.disable_v2_behavior()
 
 Dimension = collections.namedtuple("Dimension", ["name", "size"])
 
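For context, this is the standard TF1-compatibility pattern: the v1 API is imported under the usual tf name and the v2 defaults (eager execution and related behavior) are switched off, so the graph-mode code in this file keeps working when run under TensorFlow 2. An illustrative sketch of the pattern, not part of the change:

    # Illustrative only: graph-mode code still works once v2 behavior is disabled.
    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()

    x = tf.placeholder(tf.float32, shape=[None])
    y = x * 2.0
    with tf.Session() as sess:
      print(sess.run(y, feed_dict={x: [1.0, 2.0]}))  # [2.0, 4.0]
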
@@ -1218,6 +1219,8 @@ def einsum(self, equation, *slices):
     Args:
       equation: a string
       *slices: a list of tf.Tensor
+    Returns:
+      a Tensor
     """
     return tf.einsum(equation, *slices)
 
@@ -4973,6 +4976,8 @@ def gather(weights, indices, dim, output_shape=None):
   dim = convert_to_dimension(dim)
   output_shape = convert_to_shape(output_shape)
   if not isinstance(indices, Tensor):
+    # TODO(noam): when `indices` is an integer, gather can be implemented
+    # more directly with mtf_slice() and reshape()
     indices = constant(weights.mesh, indices, dtype=tf.int32)
   if weights.dtype == tf.bool:
     return cast(gather(to_float(weights), indices, dim, output_shape), tf.bool)
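The new TODO points at a cheaper path for plain integer indices. A minimal sketch of that idea, assuming the mtf_slice(x, begin, size, slice_dim_name) and reshape(x, new_shape) helpers defined elsewhere in this file behave as their names suggest; gather_int_index is a hypothetical helper, not part of the change:

    def gather_int_index(weights, index, dim):
      # Hypothetical sketch of the TODO above: for a Python integer index,
      # take a size-1 slice along `dim` and reshape that dimension away,
      # instead of building an indices Tensor and doing a full gather.
      sliced = mtf_slice(weights, begin=index, size=1, slice_dim_name=dim.name)
      # Drop the now size-1 dimension named `dim.name` from the output shape.
      out_dims = [d for d in sliced.shape.dims if d.name != dim.name]
      return reshape(sliced, Shape(out_dims))
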
@@ -6087,60 +6092,24 @@ def body_fn(microbatch_num):
   return combined_grads, combined_outputs
 
 
-class NthSmallestElementOperation(Operation):
-  """Reduce out last dimension - output is nth-smallest (or largest) element.
-
-  TODO(noam): make n a tensor instead of an integer
-  """
-
-  def __init__(self, x, n, reverse, name=None):
-    super(NthSmallestElementOperation, self).__init__(
-        [x], name=name or "nth_element")
-    reduced_dim = x.shape.dims[-1]
-    output_shape = x.shape - reduced_dim
-    self._outputs = [Tensor(self, output_shape, x.dtype)]
-    self._n = n
-    self._initialize_splittable_and_unsplittable_dims(
-        "splittable", [reduced_dim])
-    self._reverse = reverse
-
-  def gradient(self, grad_ys):
-    raise NotImplementedError("TODO(noam): implement gradient")
-
-  def lower(self, lowering):
-    mesh_impl = lowering.mesh_impl(self)
-    def slicewise_fn(x):
-      return tf.contrib.nn.nth_element(x, self._n, reverse=self._reverse)
-    y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]])
-    lowering.set_tensor_lowering(self.outputs[0], y)
-
-
-def nth_smallest_element(x, n, reduced_dim=None, reverse=False, name=None):
-  """Nth-smallest (or largest) reduction on specified axis.
+def nth_largest_element(x, n, reduced_dim, name=None):
+  """Nth-largest reduction on specified axis.
 
   Note that n is zero-indexed.
 
-  In the case that reduced_dim is split, we do something inefficient:
-  shift data around so that it is replicated and do the computation
-  everywhere.
-
   Args:
     x: a Tensor
     n: an integer
-    reduced_dim: an optional Dimension - defaults to the last dimension of n
-    reverse: a boolean
+    reduced_dim: a Dimension
     name: an optional string
   Returns:
     a Tensor
   """
-  if reduced_dim is None:
-    reduced_dim = x.shape.dims[-1]
-  # remove the reduced dimension from the shape and insert it at the end
-  x = transpose(x, x.shape - reduced_dim + reduced_dim)
-  # Since the NthSmallestElementOperation does not know how to reduce over a
-  # split dimension, we rename the reduced dimension so that we ensure that it
-  # is not split. This may cause the tensor to get all-concatenated, causing
-  # redundant computation.
-  unsplit_dim = Dimension("_unsplit", reduced_dim.size)
-  x = replace_dimensions(x, reduced_dim, unsplit_dim)
-  return NthSmallestElementOperation(x, n, reverse, name).outputs[0]
+  # Compute the top k=n+1 values, then take the last one.
+  k_dim = Dimension("_top_k_", n + 1)
+  values, _ = top_k(x, reduced_dim=reduced_dim, k_dim=k_dim, name=name)
+  return gather(values, n, k_dim)
+
+
+def nth_smallest_element(x, n, reduced_dim, name=None):
+  return -nth_largest_element(-x, n, reduced_dim, name=name)
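For reference, a small NumPy check of the identities the rewrite relies on: the nth-largest value (zero-indexed) is the last of the top n + 1 values, and the nth-smallest value is the negation of the nth-largest of the negated input. This is an illustrative sketch, not part of the change:

    import numpy as np

    x = np.array([3., 1., 4., 1., 5.])
    n = 1  # zero-indexed: "second largest" / "second smallest"

    top = np.sort(x)[::-1][:n + 1]                 # top n + 1 values, descending
    nth_largest = top[-1]                          # 4.0
    nth_smallest = -np.sort(-x)[::-1][:n + 1][-1]  # 1.0, via the negation trick

    assert nth_largest == sorted(x, reverse=True)[n]
    assert nth_smallest == sorted(x)[n]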