This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5d6ebd7: Generate cumsum and cumprod from a spec. (#1071)
Parent: adaa050
5 files changed: +186 -26 lines
Sources/CX10/xla_tensor_ops_wrapper.cc

+67
@@ -26,5 +26,72 @@
 #include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+
+namespace at {
+xla::hash_t Hash(const c10::optional<at::ScalarType>& dtype) {
+  return xla::util::Hash(swift_xla::OptionalOr<int>(dtype, -1));
+}
+}
+namespace swift_xla {
+void OpFieldToString(std::ostream& stream, const char* field_name, const c10::optional<at::ScalarType>& dtype) {
+  if (dtype) stream << ", " << field_name << "=" << *dtype;
+}
+void OpFieldToString(std::ostream& stream, const char* field_name, bool value) {
+  stream << ", " << field_name << "=" << value;
+}
+void OpFieldToString(std::ostream& stream, const char* field_name, xla::int64 value) {
+  stream << ", " << field_name << "=" << value;
+}
+}  // namespace swift_xla
+
+namespace swift_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+xla::XlaOp LowerCumSum(xla::XlaOp input, xla::int64 dim,
+                       c10::optional<at::ScalarType> dtype, bool exclusive,
+                       bool reverse) {
+  xla::XlaOp casted_input = CastToScalarType(input, dtype);
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(casted_input);
+  xla::XlaOp init = XlaHelpers::ScalarValue<float>(
+      0, input_shape.element_type(), casted_input.builder());
+  xla::XlaComputation reducer =
+      XlaHelpers::CreateAddComputation(input_shape.element_type());
+  return BuildCumulativeComputation(casted_input, dim, reducer, init, exclusive,
+                                    reverse);
+}
+
+xla::XlaOp LowerCumProd(xla::XlaOp input, xla::int64 dim,
+                        c10::optional<at::ScalarType> dtype, bool exclusive,
+                        bool reverse) {
+  xla::XlaOp casted_input = CastToScalarType(input, dtype);
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(casted_input);
+  xla::XlaOp init =
+      xla::One(casted_input.builder(), input_shape.element_type());
+  xla::XlaComputation reducer =
+      XlaHelpers::CreateMulComputation(input_shape.element_type());
+  return BuildCumulativeComputation(casted_input, dim, reducer, init, exclusive,
+                                    reverse);
+}
+
+xla::Shape CumOpShapeFn(const Value& input, xla::int64 dim,
+                        c10::optional<at::ScalarType> dtype, bool exclusive,
+                        bool reverse) {
+  if (dtype) {
+    return xla::ShapeUtil::ChangeElementType(
+        input.shape(), MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
+  }
+  return input.shape();
+}
+
+}  // namespace
+}  // namespace ops
+}  // namespace ir
+}  // namespace swift_xla
 
 #include "xla_tensor_ops_wrapper_generated.cc.inc"

Sources/CX10/xla_tensor_ops_wrapper_generated.cc.inc

+93 -2

@@ -1,10 +1,88 @@
 // Autogenerated by codegen.py. Do not modify.
-
+
 namespace swift_xla {
 namespace ir {
 namespace ops {
 namespace {
 
+class Cumprod : public Node {
+ public:
+  Cumprod(const Value& input, xla::int64 dim, c10::optional<at::ScalarType> dtype, bool exclusive, bool reverse)
+      : Node(ir::OpKind(at::aten::cumprod),
+             {input}, CumOpShapeFn(input, dim, dtype, exclusive, reverse),
+             /*num_outputs=*/1, xla::util::MHash(dim, dtype, exclusive, reverse)),
+        dim_(dim),
+        dtype_(dtype),
+        exclusive_(exclusive),
+        reverse_(reverse) {}
+
+  NodePtr Clone(OpList operands) const override {
+    return MakeNode<Cumprod>(
+        operands.at(0), dim_, dtype_, exclusive_, reverse_);
+  }
+
+  XlaOpVector Lower(LoweringContext* loctx) const override {
+    xla::XlaOp result = LowerCumProd(
+        loctx->GetOutputOp(operand(0)), dim_, dtype_, exclusive_, reverse_);
+    return ReturnOp(result, loctx);
+  }
+
+  std::string ToString() const override {
+    std::stringstream ss;
+    ss << Node::ToString();
+    OpFieldToString(ss, "dim", dim_);
+    OpFieldToString(ss, "dtype", dtype_);
+    OpFieldToString(ss, "exclusive", exclusive_);
+    OpFieldToString(ss, "reverse", reverse_);
+    return ss.str();
+  }
+
+ private:
+  xla::int64 dim_;
+  c10::optional<at::ScalarType> dtype_;
+  bool exclusive_;
+  bool reverse_;
+};
+
+class Cumsum : public Node {
+ public:
+  Cumsum(const Value& input, xla::int64 dim, c10::optional<at::ScalarType> dtype, bool exclusive, bool reverse)
+      : Node(ir::OpKind(at::aten::cumsum),
+             {input}, CumOpShapeFn(input, dim, dtype, exclusive, reverse),
+             /*num_outputs=*/1, xla::util::MHash(dim, dtype, exclusive, reverse)),
+        dim_(dim),
+        dtype_(dtype),
+        exclusive_(exclusive),
+        reverse_(reverse) {}
+
+  NodePtr Clone(OpList operands) const override {
+    return MakeNode<Cumsum>(
+        operands.at(0), dim_, dtype_, exclusive_, reverse_);
+  }
+
+  XlaOpVector Lower(LoweringContext* loctx) const override {
+    xla::XlaOp result = LowerCumSum(
+        loctx->GetOutputOp(operand(0)), dim_, dtype_, exclusive_, reverse_);
+    return ReturnOp(result, loctx);
+  }
+
+  std::string ToString() const override {
+    std::stringstream ss;
+    ss << Node::ToString();
+    OpFieldToString(ss, "dim", dim_);
+    OpFieldToString(ss, "dtype", dtype_);
+    OpFieldToString(ss, "exclusive", exclusive_);
+    OpFieldToString(ss, "reverse", reverse_);
+    return ss.str();
+  }
+
+ private:
+  xla::int64 dim_;
+  c10::optional<at::ScalarType> dtype_;
+  bool exclusive_;
+  bool reverse_;
+};
+
 class LogSoftmaxBackward : public Node {
  public:
   LogSoftmaxBackward(const Value& grad_output, const Value& output, xla::int64 dim)

@@ -26,7 +104,8 @@ class LogSoftmaxBackward : public Node {
 
   std::string ToString() const override {
     std::stringstream ss;
-    ss << Node::ToString() << ", dim=" << dim_;
+    ss << Node::ToString();
+    OpFieldToString(ss, "dim", dim_);
     return ss.str();
   }
 

@@ -39,6 +118,18 @@ class LogSoftmaxBackward : public Node {
 }  // namespace ir
 }  // namespace swift_xla
 
+OpaqueXLATensor* XLATensor_cumprod(OpaqueXLATensor* input, int64_t dim, Optional_XLAScalarType dtype, bool exclusive, bool reverse) {
+  auto input_ir_value = input->GetIrValue();
+  return new swift_xla::XLATensor(input->CreateFrom(
+      swift_xla::ir::MakeNode<swift_xla::ir::ops::Cumprod>(input_ir_value, swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim, input_ir_value.shape().rank()), dtype.value(), exclusive, reverse)));
+}
+
+OpaqueXLATensor* XLATensor_cumsum(OpaqueXLATensor* input, int64_t dim, Optional_XLAScalarType dtype, bool exclusive, bool reverse) {
+  auto input_ir_value = input->GetIrValue();
+  return new swift_xla::XLATensor(input->CreateFrom(
+      swift_xla::ir::MakeNode<swift_xla::ir::ops::Cumsum>(input_ir_value, swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim, input_ir_value.shape().rank()), dtype.value(), exclusive, reverse)));
+}
+
 OpaqueXLATensor* XLATensor_log_softmax_backward(OpaqueXLATensor* grad_output, OpaqueXLATensor* output, int64_t dim) {
   auto grad_output_ir_value = grad_output->GetIrValue();
   auto output_ir_value = output->GetIrValue();
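
Note (not part of the commit): the generated C wrappers canonicalize dim through swift_xla::XlaHelpers::GetCanonicalDimensionIndex, so negative dims count from the last axis. A rough Python equivalent of that canonicalization, as a sketch only:

# Sketch of the assumed dim-canonicalization behavior, not the actual C++.
def canonical_dim(dim, rank):
    # Maps dim in [-rank, rank) onto [0, rank).
    if not -rank <= dim < rank:
        raise ValueError(f"dim {dim} out of range for rank {rank}")
    return dim + rank if dim < 0 else dim

assert canonical_dim(-1, 4) == 3  # last axis
assert canonical_dim(0, 4) == 0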

Sources/CX10/xla_tensor_wrapper.cc

-12
@@ -294,18 +294,6 @@ OpaqueXLATensor* XLATensor_acos(OpaqueXLATensor* a) {
 OpaqueXLATensor* XLATensor_acosh(OpaqueXLATensor* a) {
   return new XLATensor(XLATensor::acosh(*a));
 }
-OpaqueXLATensor* XLATensor_cumprod(OpaqueXLATensor* a, int64_t dim,
-                                   Optional_XLAScalarType dtype, bool exclusive,
-                                   bool reverse) {
-  return new XLATensor(
-      XLATensor::cumprod(*a, dim, dtype.value(), exclusive, reverse));
-}
-OpaqueXLATensor* XLATensor_cumsum(OpaqueXLATensor* a, int64_t dim,
-                                  Optional_XLAScalarType dtype, bool exclusive,
-                                  bool reverse) {
-  return new XLATensor(
-      XLATensor::cumsum(*a, dim, dtype.value(), exclusive, reverse));
-}
 OpaqueXLATensor* XLATensor_add(OpaqueXLATensor* a, OpaqueXLATensor* b) {
   return new XLATensor(XLATensor::add(*a, *b));
 }

Sources/x10/swift_bindings/generate_ops.py

+16 -12

@@ -8,7 +8,6 @@
 FLAGS = flags.FLAGS
 
 flags.DEFINE_string("def_file", None, "path to list of ops")
-flags.DEFINE_string("swift_out", None, "path for the generated swift file")
 flags.DEFINE_string("cc_output", None, "path for the generated cc file")
 
 HEADER = """// Autogenerated by codegen.py. Do not modify.

@@ -21,12 +20,14 @@ def node_type_define(op):
     if arg[1] == "Tensor": tensor_args.append(arg)
     else: attr_args.append(arg)
   def format_pretty_print(arg):
-    return f" << \", {arg[0]}=\" << {arg[0]}_"
+    return f"    OpFieldToString(ss, \"{arg[0]}\", {arg[0]}_);\n"
   def format_ctor_arg(arg):
     name, stype = arg
     if stype == "Tensor": return f"const Value& {name}"
     if stype == "Int64": return f"xla::int64 {name}"
-    raise f"Problem: no such type: {stype}"
+    if stype == "Bool": return f"bool {name}"
+    if stype == "ScalarType?": return f"c10::optional<at::ScalarType> {name}"
+    raise ValueError(f"Problem: no such type: {stype}")
   lower_arg_i = 0
   def format_lower_arg(arg):
     nonlocal lower_arg_i

@@ -35,8 +36,7 @@ def format_lower_arg(arg):
       i = lower_arg_i
      lower_arg_i += 1
      return "loctx->GetOutputOp(operand(" + str(i) + "))"
-    if stype == "Int64": return f"{name}_"
-    raise f"Problem: no such type: {stype}"
+    return f"{name}_"
   clone_arg_i = 0
   def format_clone_arg(arg):
     nonlocal clone_arg_i

@@ -45,12 +45,13 @@
       i = clone_arg_i
      clone_arg_i += 1
      return "operands.at(" + str(i) + ")"
-    if stype == "Int64": return f"{name}_"
-    raise f"Problem: no such type: {stype}"
+    return f"{name}_"
   def format_attr_define(arg):
     name, stype = arg
     if stype == "Int64": return f"  xla::int64 {name}_;\n"
-    raise f"Problem: no such type: {stype}"
+    if stype == "Bool": return f"  bool {name}_;\n"
+    if stype == "ScalarType?": return f"  c10::optional<at::ScalarType> {name}_;\n"
+    raise ValueError(f"Problem: no such type: {stype}")
   def format_attr_init(arg):
     return f",\n        {arg[0]}_({arg[0]})"
   shape_fn = f"""{{}}\n#error no shape function for {op["op_node_name"]}\n"""

@@ -84,8 +85,8 @@ class {op["op_node_name"]} : public Node {{
 
   std::string ToString() const override {{
     std::stringstream ss;
-    ss << Node::ToString(){"".join(format_pretty_print(arg) for arg in attr_args)};
-    return ss.str();
+    ss << Node::ToString();
+{"".join(format_pretty_print(arg) for arg in attr_args)}    return ss.str();
   }}
 
  private:

@@ -97,13 +98,16 @@ def format_arg_def(arg):
     name, stype = arg
     if stype == "Tensor": return "OpaqueXLATensor* " + name
     if stype == "Int64": return "int64_t " + name
-    raise "problem unknown type: " + stype
+    if stype == "Bool": return f"bool {name}"
+    if stype == "ScalarType?": return f"Optional_XLAScalarType {name}"
+    raise ValueError("problem unknown type: " + stype)
   def format_arg_ref(arg):
     name, stype = arg
     if stype == "Tensor": return name + "_ir_value"
     for extra in op["extras"]:
       if extra[0] == "canonicalize" and extra[1] == name:
         return f"swift_xla::XlaHelpers::GetCanonicalDimensionIndex({name}, {extra[2]}_ir_value.shape().rank())"
+    if stype == "ScalarType?": return f"{name}.value()"
     return name
   def unpack_arg(arg):
     name, stype = arg

@@ -122,7 +126,7 @@ def snake_to_camel(name):
   return "".join(map(lambda x: x.capitalize(),name.split("_")))
 
 def canonicalize_op(op):
-  tokens = re.findall("(\w+|[\(\),:]|->)", op["def"])
+  tokens = re.findall("(\w+\??|[\(\),:]|->)", op["def"])
   op["c_name"] = tokens[0]
   def expect(cond):
     if not cond: raise ValueError(f"""invalid format: {repr(op["def"])}""")
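
Note (not part of the commit): the regex change in canonicalize_op is what lets optional types parse; `\w+\??` keeps the trailing `?` of ScalarType? attached to its token instead of dropping it. A quick standalone check of the new pattern against the cumsum spec line:

import re

spec = ("cumsum(input: Tensor, dim: Int64, dtype: ScalarType?, "
        "exclusive: Bool, reverse: Bool) -> Tensor")
tokens = re.findall(r"(\w+\??|[\(\),:]|->)", spec)
print(tokens[0])                 # 'cumsum'
print("ScalarType?" in tokens)   # True: the '?' stays on the type token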

Sources/x10/swift_bindings/ops_list.txt

+10
@@ -1,3 +1,13 @@
+- def: "cumprod(input: Tensor, dim: Int64, dtype: ScalarType?, exclusive: Bool, reverse: Bool) -> Tensor"
+  extras: ["canonicalize dim input"]
+  x10_enum: at::aten::cumprod
+  shape_fn: CumOpShapeFn
+  lower_fn: LowerCumProd
+- def: "cumsum(input: Tensor, dim: Int64, dtype: ScalarType?, exclusive: Bool, reverse: Bool) -> Tensor"
+  extras: ["canonicalize dim input"]
+  x10_enum: at::aten::cumsum
+  shape_fn: CumOpShapeFn
+  lower_fn: LowerCumSum
 - def: "log_softmax_backward(grad_output: Tensor, output: Tensor, dim: Int64) -> Tensor"
   extras: ["canonicalize dim grad_output"]
   x10_enum: at::aten::_log_softmax_backward_data
