Skip to content

Commit 0df691d

Browse files
titaiwangmspytorchmergebot
authored andcommitted
[ONNX] Support aten::broadcast_to (pytorch#101833)
Support aten::broadcast as the way we support on aten::expand. Fix pytorch#92678 pytorch#101768 Pull Request resolved: pytorch#101833 Approved by: https://github.com/thiagocrepaldi
1 parent 113b670 commit 0df691d

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

test/onnx/test_op_consistency.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@
5656
[
5757
"atan",
5858
"atan2",
59+
"broadcast_to",
5960
"ceil",
61+
"expand",
6062
"flatten",
6163
"logical_not",
6264
"nn.functional.scaled_dot_product_attention",

torch/onnx/symbolic_opset9.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"bitwise_or",
5353
"bmm",
5454
"broadcast_tensors",
55+
"broadcast_to",
5556
"bucketize",
5657
"cat",
5758
"cdist",
@@ -965,6 +966,27 @@ def expand(g: jit_utils.GraphContext, self, size, implicit):
965966
return g.op("Expand", self, size)
966967

967968

969+
@_onnx_symbolic("aten::broadcast_to")
970+
@symbolic_helper.quantized_args(True)
971+
@_beartype.beartype
972+
def broadcast_to(g: jit_utils.GraphContext, self, size):
973+
size = symbolic_helper._maybe_get_const(size, "is")
974+
if not symbolic_helper._is_value(size):
975+
size = g.op("Constant", value_t=torch.LongTensor(size))
976+
elif symbolic_helper._is_packed_list(size):
977+
# Expand with -1 dim value means dim is unchanged.
978+
# Since onnx::expand supports two-way broadcasting,
979+
# -1 dim value can be exported to onnx as 1
980+
size = symbolic_helper._reshape_helper(
981+
g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))
982+
)
983+
dtype = _type_utils.JitScalarType.INT64
984+
ones = ones_like(g, size, dtype)
985+
neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
986+
size = where(g, g.op("Equal", size, neg_ones), ones, size)
987+
return g.op("Expand", self, size)
988+
989+
968990
@_onnx_symbolic("aten::expand_as")
969991
@symbolic_helper.quantized_args(True, True)
970992
@_beartype.beartype

0 commit comments

Comments
 (0)