|
52 | 52 | "bitwise_or",
|
53 | 53 | "bmm",
|
54 | 54 | "broadcast_tensors",
|
| 55 | + "broadcast_to", |
55 | 56 | "bucketize",
|
56 | 57 | "cat",
|
57 | 58 | "cdist",
|
@@ -965,6 +966,27 @@ def expand(g: jit_utils.GraphContext, self, size, implicit):
|
965 | 966 | return g.op("Expand", self, size)
|
966 | 967 |
|
967 | 968 |
|
| 969 | +@_onnx_symbolic("aten::broadcast_to") |
| 970 | +@symbolic_helper.quantized_args(True) |
| 971 | +@_beartype.beartype |
| 972 | +def broadcast_to(g: jit_utils.GraphContext, self, size): |
| 973 | + size = symbolic_helper._maybe_get_const(size, "is") |
| 974 | + if not symbolic_helper._is_value(size): |
| 975 | + size = g.op("Constant", value_t=torch.LongTensor(size)) |
| 976 | + elif symbolic_helper._is_packed_list(size): |
| 977 | + # Expand with -1 dim value means dim is unchanged. |
| 978 | + # Since onnx::expand supports two-way broadcasting, |
| 979 | + # -1 dim value can be exported to onnx as 1 |
| 980 | + size = symbolic_helper._reshape_helper( |
| 981 | + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) |
| 982 | + ) |
| 983 | + dtype = _type_utils.JitScalarType.INT64 |
| 984 | + ones = ones_like(g, size, dtype) |
| 985 | + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) |
| 986 | + size = where(g, g.op("Equal", size, neg_ones), ones, size) |
| 987 | + return g.op("Expand", self, size) |
| 988 | + |
| 989 | + |
968 | 990 | @_onnx_symbolic("aten::expand_as")
|
969 | 991 | @symbolic_helper.quantized_args(True, True)
|
970 | 992 | @_beartype.beartype
|
|
0 commit comments