Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 8d2465e

Browse files
adarob authored and copybara-github committed
Add reduce_dims to avoid warning.
PiperOrigin-RevId: 281172010
1 parent 1bedb03 commit 8d2465e

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

mesh_tensorflow/transformer/moe.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,13 @@ def transformer_moe_layer_v1(
368368
# Now feed the expert inputs through the experts.
369369
h = mtf.layers.dense(
370370
expert_inputs, hidden_dim, expert_dims=[experts_dim],
371+
reduced_dims=expert_inputs.shape.dims[-1:],
371372
activation=activation, use_bias=False,
372373
variable_dtype=variable_dtype, name="wi")
373374

374375
expert_output = mtf.layers.dense(
375376
h, output_dim, expert_dims=[experts_dim], use_bias=False,
376-
variable_dtype=variable_dtype,
377+
reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
377378
name="wo")
378379

379380
expert_output = mtf.reshape(
@@ -630,10 +631,12 @@ def transformer_moe_layer_v2(
630631

631632
hidden_output = mtf.layers.dense(
632633
expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
634+
reduced_dims=expert_inputs_y.shape.dims[-1:],
633635
activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype,
634636
name="wi")
635637
expert_output = mtf.layers.dense(
636638
hidden_output, output_dim, expert_dims=[y0, x1],
639+
reduced_dims=hidden_output.shape.dims[-1:],
637640
use_bias=False, variable_dtype=variable_dtype,
638641
name="wo")
639642

mesh_tensorflow/transformer/transformer.py

+1
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def _call_internal(self, context, inputs, targets=None):
591591
logits = mtf.layers.dense(
592592
x, self.output_vocab_dim, use_bias=False,
593593
variable_dtype=context.variable_dtype,
594+
reduced_dims=x.shape.dims[-1:],
594595
name="logits")
595596
if targets is not None and context.losses is not None:
596597
context.losses.append(

mesh_tensorflow/transformer/transformer_layers.py

+2
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ def call(self, context, x, losses=None):
5656
h = mtf.layers.dense(x, hidden_channels,
5757
use_bias=False, activation=mtf.relu,
5858
variable_dtype=context.variable_dtype,
59+
reduced_dims=x.shape.dims[-1:],
5960
name="wi", expert_dims=expert_dims)
6061
if context.train and self.dropout_rate != 0.0:
6162
h = mtf.dropout(h, 1.0 - self.dropout_rate,
6263
noise_shape=h.shape - context.length_dim)
6364
return mtf.layers.dense(h, io_channels, use_bias=False, activation=None,
6465
variable_dtype=context.variable_dtype,
66+
reduced_dims=h.shape.dims[-1:],
6567
name="wo", expert_dims=expert_dims)
6668

6769

0 commit comments

Comments (0)