This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b719d99

Convert more ops to be generated from a spec. (#1083)
This includes: index, topk, tf_Conv, tf_ConvBackpropFilter, tf_ConvBackpropInput, tf_MirrorPad, tf_MirrorPadGrad, and xla_pad.
1 parent 2aa292e commit b719d99

8 files changed (+737, -88 lines)
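
Note: the diff below adds hand-written lowering helpers (topk, the TF-style convolution ops, mirror padding, and xla_pad) together with Hash and OpFieldToString overloads for their attribute types, so the spec-generated wrappers included from xla_tensor_ops_wrapper_generated.cc.inc can hash and print those attributes. As a rough illustration of the flatten-then-hash idea used for xla::PaddingConfig, here is a minimal standalone C++ sketch; PaddingDim and CombineHash are simplified stand-ins for the real xla::PaddingConfig proto and xla::util::MHash, not project APIs.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Stand-in for xla::PaddingConfig::PaddingConfigDimension (assumption, for
// illustration only).
struct PaddingDim {
  int64_t edge_padding_low;
  int64_t edge_padding_high;
  int64_t interior_padding;
};

// Boost-style hash combine; the real code folds the flattened vectors with
// xla::util::MHash instead.
std::size_t CombineHash(std::size_t seed, int64_t value) {
  return seed ^ (std::hash<int64_t>{}(value) + 0x9e3779b97f4a7c15ULL +
                 (seed << 6) + (seed >> 2));
}

// Mirrors the shape of Hash(const xla::PaddingConfig&) in the diff: walk the
// per-dimension low/high/interior padding and reduce it to one hash value.
std::size_t HashPaddingConfig(const std::vector<PaddingDim>& dims) {
  std::size_t seed = 0;
  for (const PaddingDim& d : dims) {
    seed = CombineHash(seed, d.edge_padding_low);
    seed = CombineHash(seed, d.edge_padding_high);
    seed = CombineHash(seed, d.interior_padding);
  }
  return seed;
}
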

Sources/CX10/xla_tensor_ops_wrapper.cc (+154, -7)
@@ -29,17 +29,20 @@
 #include "tensorflow/compiler/tf2xla/xla_tensor/matrix.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/nll_loss.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/ops/tf_create_conv_attrs.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/segment_reduction_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/xla_lower_util.h"
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
 #include "tensorflow/compiler/tf2xla/lib/random.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/prng.h"
 #include "tensorflow/compiler/xla/client/lib/qr.h"
 #include "tensorflow/compiler/xla/client/lib/svd.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "xla_tensor_wrapper.h"
 
 namespace at {
@@ -51,18 +54,40 @@ xla::hash_t Hash(const at::Scalar& value) {
                                  : xla::util::Hash(value.toLong());
 }
 }
+namespace tensorflow {
+xla::hash_t Hash(tensorflow::MirrorPadMode mode) {
+  return xla::util::Hash(static_cast<int>(mode));
+}
+}  // namespace tensorflow
+namespace xla {
+xla::hash_t Hash(const xla::PaddingConfig& padding_config) {
+  std::vector<xla::int64> low;
+  std::vector<xla::int64> high;
+  std::vector<xla::int64> interior;
+  for (const xla::PaddingConfig::PaddingConfigDimension& dim_padding :
+       padding_config.dimensions()) {
+    low.push_back(dim_padding.edge_padding_low());
+    high.push_back(dim_padding.edge_padding_high());
+    interior.push_back(dim_padding.interior_padding());
+  }
+  return xla::util::MHash(low, high, interior);
+}
+}  // namespace xla
 namespace swift_xla {
 void OpFieldToString(std::ostream& stream, const char* field_name, const c10::optional<at::ScalarType>& dtype) {
   if (dtype) stream << ", " << field_name << "=" << *dtype;
 }
-void OpFieldToString(std::ostream& stream, const char* field_name, bool value) {
+template <typename T>
+void OpFieldToString(std::ostream& stream, const char* field_name, T value) {
   stream << ", " << field_name << "=" << value;
 }
-void OpFieldToString(std::ostream& stream, const char* field_name, xla::int64 value) {
-  stream << ", " << field_name << "=" << value;
+void OpFieldToString(std::ostream& stream, const char* field_name,
+                     tensorflow::MirrorPadMode value) {
+  stream << ", " << field_name << "=" << static_cast<int>(value);
 }
-void OpFieldToString(std::ostream& stream, const char* field_name, float value) {
-  stream << ", " << field_name << "=" << value;
+void OpFieldToString(std::ostream& stream, const char* field_name,
+                     const xla::PaddingConfig& value) {
+  stream << ", " << field_name << "=" << xla::PaddingConfigToString(value);
 }
 void OpFieldToString(std::ostream& stream, const char* field_name,
                      const std::vector<xla::int64>& value) {
@@ -219,12 +244,18 @@ std::vector<xla::int64> CanonicalizeExpand(xla::Shape shape,
   return dimensions;
 }
 
-xla::XlaOp LowerPad(xla::XlaOp input, absl::Span<const xla::int64> pad,
-                    const at::Scalar& value) {
+xla::XlaOp LowerPad(xla::XlaOp input, const at::Scalar& value,
+                    const xla::PaddingConfig& config) {
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   return xla::Pad(input,
                   XlaHelpers::ScalarValue(value, input_shape.element_type(),
                                           input.builder()),
+                  config);
+}
+
+xla::XlaOp LowerPad(xla::XlaOp input, absl::Span<const xla::int64> pad,
+                    const at::Scalar& value) {
+  return LowerPad(input, value,
                   XlaHelpers::MakeXlaPaddingConfigFromNdPadding(pad));
 }
 
@@ -267,6 +298,11 @@ xla::XlaOp LowerWhere(xla::XlaOp condition, xla::XlaOp input,
   return xla::Select(pred_condition, input, other);
 }
 
+std::vector<xla::XlaOp> BuildTopK(xla::XlaOp input, xla::int64 k,
+                                  xla::int64 dim, bool largest) {
+  return CreateTopK(input, k, dim, largest, true);
+}
+
 xla::XlaOp BuildOneHot(xla::XlaOp indices, xla::XlaOp on_value,
                        xla::XlaOp off_value, xla::int64 depth,
                        xla::int64 axis) {
@@ -474,9 +510,120 @@ xla::Shape ShapeOfXlaOpList(absl::Span<const xla::XlaOp> ops) {
   return result;
 }
 
+xla::XlaOp BuildTfConv(xla::XlaOp input, xla::XlaOp filter, bool depthwise,
+                       absl::Span<const xla::int64> strides,
+                       tensorflow::Padding padding,
+                       absl::Span<const xla::int64> explicit_paddings,
+                       tensorflow::TensorFormat data_format,
+                       absl::Span<const xla::int64> dilations) {
+  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  int num_spatial_dims = input_shape.rank() - 2;
+  tensorflow::ConvOpAttrs attrs =
+      CreateConvOpAttrs(num_spatial_dims, depthwise, strides, padding,
+                        explicit_paddings, data_format, dilations);
+  xla::PrecisionConfig precision_config =
+      XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision());
+  return ConsumeValue(tensorflow::MakeXlaForwardConvOp(
+      /*type_string=*/"TfConv", /*conv_input=*/input, /*filter=*/filter,
+      /*attrs=*/attrs, /*precision_config=*/&precision_config));
+}
+
+xla::XlaOp BuildTfConvBackpropFilter(
+    xla::XlaOp input, absl::Span<const xla::int64> filter_sizes,
+    xla::XlaOp out_backprop, bool depthwise,
+    absl::Span<const xla::int64> strides, tensorflow::Padding padding,
+    absl::Span<const xla::int64> explicit_paddings,
+    tensorflow::TensorFormat data_format,
+    absl::Span<const xla::int64> dilations) {
+  int num_spatial_dims = filter_sizes.size() - 2;
+  tensorflow::ConvOpAttrs attrs =
+      CreateConvOpAttrs(num_spatial_dims, depthwise, strides, padding,
+                        explicit_paddings, data_format, dilations);
+  xla::PrecisionConfig precision_config =
+      XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision());
+  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  return ConsumeValue(tensorflow::MakeXlaBackpropFilterConvOp(
+      /*type_string=*/"TfConvBackpropFilter", /*activations=*/input,
+      /*filter_shape=*/
+      xla::ShapeUtil::MakeShape(input_shape.element_type(), filter_sizes),
+      /*gradients=*/out_backprop,
+      /*attrs=*/attrs, /*precision_config=*/&precision_config));
+}
+
+xla::XlaOp BuildTfConvBackpropInput(
+    absl::Span<const xla::int64> input_sizes, xla::XlaOp filter,
+    xla::XlaOp out_backprop, bool depthwise,
+    absl::Span<const xla::int64> strides, tensorflow::Padding padding,
+    absl::Span<const xla::int64> explicit_paddings,
+    tensorflow::TensorFormat data_format,
+    absl::Span<const xla::int64> dilations) {
+  int num_spatial_dims = input_sizes.size() - 2;
+  tensorflow::ConvOpAttrs attrs =
+      CreateConvOpAttrs(num_spatial_dims, depthwise, strides, padding,
+                        explicit_paddings, data_format, dilations);
+  xla::PrecisionConfig precision_config =
+      XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision());
+  xla::Shape filter_shape = XlaHelpers::ShapeOfXlaOp(filter);
+  return ConsumeValue(tensorflow::MakeXlaBackpropInputConvOp(
+      /*type_string=*/"TfConvBackpropInput",
+      /*input_shape=*/
+      xla::ShapeUtil::MakeShape(filter_shape.element_type(), input_sizes),
+      /*filter=*/filter,
+      /*out_backprop=*/out_backprop,
+      /*attrs=*/attrs, /*precision_config=*/&precision_config));
+}
+
 }  // namespace
 }  // namespace ops
 }  // namespace ir
 }  // namespace swift_xla
 
+namespace {
+
+tensorflow::Padding ToTFPadding(TFPadding padding) {
+  switch (padding) {
+    case TFPadding_VALID: {
+      return tensorflow::VALID;
+    }
+    case TFPadding_SAME: {
+      return tensorflow::SAME;
+    }
+    case TFPadding_EXPLICIT: {
+      return tensorflow::EXPLICIT;
+    }
+    default: {
+      LOG(FATAL) << "Invalid padding: " << padding;
+    }
+  }
+}
+
+tensorflow::MirrorPadMode ToTFMirrorPadMode(TFMirrorPadMode mode) {
+  switch (mode) {
+    case TFMirrorPadMode_REFLECT: {
+      return tensorflow::MirrorPadMode::REFLECT;
+    }
+    case TFMirrorPadMode_SYMMETRIC: {
+      return tensorflow::MirrorPadMode::SYMMETRIC;
+    }
+    default: {
+      LOG(FATAL) << "Invalid mirror pad mode: " << mode;
+    }
+  }
+}
+
+xla::PaddingConfig ToXLAPaddingConfig(PaddingConfig padding_config) {
+  xla::PaddingConfig xla_padding_config;
+  for (size_t i = 0; i < padding_config.count; ++i) {
+    xla::PaddingConfig::PaddingConfigDimension* dims =
+        xla_padding_config.add_dimensions();
+    const PaddingConfigDimension& padding_dim = padding_config.dimensions[i];
+    dims->set_edge_padding_low(padding_dim.edge_padding_low);
+    dims->set_edge_padding_high(padding_dim.edge_padding_high);
+    dims->set_interior_padding(padding_dim.interior_padding);
+  }
+  return xla_padding_config;
+}
+
+}  // namespace
+
 #include "xla_tensor_ops_wrapper_generated.cc.inc"
