
Commit 0e0a94a

saberkun authored and tensorflower-gardener committed
Remove compute_output_shape.
Keras: "manual" shape inference is only required if the layer is dynamic (otherwise we use TF's static shape inference capabilities).

PiperOrigin-RevId: 290821518
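For context on the rationale above, here is a minimal sketch of the distinction Keras draws (the layer names are illustrative, not part of this commit): a non-dynamic layer gets its output shape from TF's static shape inference when call() is traced, while a layer constructed with dynamic=True runs eagerly and must supply compute_output_shape itself.

import tensorflow as tf


class StaticDouble(tf.keras.layers.Layer):
  """Non-dynamic: call() is traceable, so TF's static shape inference
  derives the output shape and no manual override is needed."""

  def call(self, inputs):
    return tf.concat([inputs, inputs], axis=-1)


class EagerDouble(tf.keras.layers.Layer):
  """dynamic=True: call() only runs eagerly, so Keras cannot trace it
  and needs a manual compute_output_shape override."""

  def __init__(self, **kwargs):
    super(EagerDouble, self).__init__(dynamic=True, **kwargs)

  def call(self, inputs):
    return tf.concat([inputs, inputs], axis=-1)

  def compute_output_shape(self, input_shape):
    shape = tf.TensorShape(input_shape).as_list()
    return tf.TensorShape(shape[:-1] + [shape[-1] * 2])

The layers touched below are all non-dynamic, which is why their overrides are redundant.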
1 parent ac97f01 commit 0e0a94a

File tree

5 files changed: +0, -37 lines changed


official/nlp/modeling/layers/attention.py (-8 lines)

@@ -118,14 +118,6 @@ def __init__(self,
 
     self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
 
-  def compute_output_shape(self, input_shape):
-    # TODO(momernick): validate tensor dimensions.
-    from_tensor_shape = tf.TensorShape(input_shape[0])
-    batch = from_tensor_shape[0]
-    from_tensor_length = from_tensor_shape[1]
-    return tf.TensorShape(
-        (batch, from_tensor_length, self._num_heads, self._head_size))
-
   def get_config(self):
     config = {
         "num_heads":

official/nlp/modeling/layers/dense_einsum.py (-12 lines)

@@ -143,18 +143,6 @@ def build(self, input_shape):
       self._bias = None
     super(DenseEinsum, self).build(input_shape)
 
-  def compute_output_shape(self, input_shape):
-    input_shape = tf.TensorShape(input_shape)
-    input_shape = input_shape.with_rank_at_least(self._num_summed_dimensions +
-                                                 1)
-    for i in range(self._num_summed_dimensions):
-      if tf.dimension_value(input_shape[-1 * i]) is None:
-        raise ValueError(
-            "The %s dimension of input_shape must be defined, but saw: %s" %
-            (-1 * i, input_shape))
-    return input_shape[:-1 * self._num_summed_dimensions].concatenate(
-        self._units)
-
   def get_config(self):
     config = {
         "output_shape":

official/nlp/modeling/layers/transformer.py (-7 lines)

@@ -158,13 +158,6 @@ def build(self, input_shape):
 
     super(Transformer, self).build(input_shape)
 
-  def compute_output_shape(self, input_shape):
-    data_tensor_shape = tf.TensorShape(input_shape[0])
-    batch = data_tensor_shape[0]
-    sequence_length = data_tensor_shape[1]
-
-    return tf.TensorShape((batch, sequence_length, self._output_einsum_shape))
-
   def get_config(self):
     config = {
         "num_attention_heads":

official/nlp/modeling/layers/transformer_scaffold.py (-7 lines)

@@ -175,13 +175,6 @@ def build(self, input_shape):
 
     super(TransformerScaffold, self).build(input_shape)
 
-  def compute_output_shape(self, input_shape):
-    data_tensor_shape = tf.TensorShape(input_shape[0])
-    batch = data_tensor_shape[0]
-    sequence_length = data_tensor_shape[1]
-
-    return tf.TensorShape((batch, sequence_length, self._output_einsum_shape))
-
   def get_config(self):
     config = {
         "attention_cls":

official/nlp/modeling/networks/masked_lm.py (-3 lines)

@@ -168,9 +168,6 @@ def build(self, input_shape):
 
     super(Bias, self).build(input_shape)
 
-  def compute_output_shape(self, input_shape):
-    return input_shape
-
   def get_config(self):
     config = {
         'activation': tf.keras.activations.serialize(self._activation),
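With the overrides gone, these layers fall back to the base tf.keras.layers.Layer.compute_output_shape, which in TF 2.x infers the output shape by building the layer and tracing call() on placeholder inputs. A quick sketch of that fallback (this Bias stand-in is illustrative, not the exact class from masked_lm.py):

import tensorflow as tf


class Bias(tf.keras.layers.Layer):

  def build(self, input_shape):
    # Output shape equals input shape, so the base class can infer it.
    self._bias = self.add_weight("bias", shape=[input_shape[-1]])
    super(Bias, self).build(input_shape)

  def call(self, inputs):
    return inputs + self._bias


layer = Bias()
# No override defined: the default implementation traces call().
print(layer.compute_output_shape((None, 128)))  # (None, 128)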
