29
29
#include " tensorflow/compiler/tf2xla/xla_tensor/matrix.h"
30
30
#include " tensorflow/compiler/tf2xla/xla_tensor/nll_loss.h"
31
31
#include " tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
32
+ #include " tensorflow/compiler/tf2xla/xla_tensor/ops/tf_create_conv_attrs.h"
32
33
#include " tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
33
34
#include " tensorflow/compiler/tf2xla/xla_tensor/segment_reduction_ops.h"
34
35
#include " tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
35
36
#include " tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
36
37
#include " tensorflow/compiler/tf2xla/xla_tensor/xla_lower_util.h"
38
+ #include " tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
37
39
#include " tensorflow/compiler/tf2xla/lib/random.h"
38
40
#include " tensorflow/compiler/xla/client/lib/constants.h"
39
41
#include " tensorflow/compiler/xla/client/lib/math.h"
40
42
#include " tensorflow/compiler/xla/client/lib/prng.h"
41
43
#include " tensorflow/compiler/xla/client/lib/qr.h"
42
44
#include " tensorflow/compiler/xla/client/lib/svd.h"
45
+ #include " tensorflow/compiler/xla/service/hlo_instruction.h"
43
46
#include " xla_tensor_wrapper.h"
44
47
45
48
namespace at {
@@ -51,18 +54,40 @@ xla::hash_t Hash(const at::Scalar& value) {
51
54
: xla::util::Hash (value.toLong ());
52
55
}
53
56
}
57
+ namespace tensorflow {
58
+ xla::hash_t Hash (tensorflow::MirrorPadMode mode) {
59
+ return xla::util::Hash (static_cast <int >(mode));
60
+ }
61
+ } // namespace tensorflow
62
+ namespace xla {
63
+ xla::hash_t Hash (const xla::PaddingConfig& padding_config) {
64
+ std::vector<xla::int64> low;
65
+ std::vector<xla::int64> high;
66
+ std::vector<xla::int64> interior;
67
+ for (const xla::PaddingConfig::PaddingConfigDimension& dim_padding :
68
+ padding_config.dimensions ()) {
69
+ low.push_back (dim_padding.edge_padding_low ());
70
+ high.push_back (dim_padding.edge_padding_high ());
71
+ interior.push_back (dim_padding.interior_padding ());
72
+ }
73
+ return xla::util::MHash (low, high, interior);
74
+ }
75
+ } // namespace xla
54
76
namespace swift_xla {
55
77
void OpFieldToString (std::ostream& stream, const char * field_name, const c10::optional<at::ScalarType>& dtype) {
56
78
if (dtype) stream << " , " << field_name << " =" << *dtype;
57
79
}
58
- void OpFieldToString (std::ostream& stream, const char * field_name, bool value) {
80
+ template <typename T>
81
+ void OpFieldToString (std::ostream& stream, const char * field_name, T value) {
59
82
stream << " , " << field_name << " =" << value;
60
83
}
61
- void OpFieldToString (std::ostream& stream, const char * field_name, xla::int64 value) {
62
- stream << " , " << field_name << " =" << value;
84
+ void OpFieldToString (std::ostream& stream, const char * field_name,
85
+ tensorflow::MirrorPadMode value) {
86
+ stream << " , " << field_name << " =" << static_cast <int >(value);
63
87
}
64
- void OpFieldToString (std::ostream& stream, const char * field_name, float value) {
65
- stream << " , " << field_name << " =" << value;
88
+ void OpFieldToString (std::ostream& stream, const char * field_name,
89
+ const xla::PaddingConfig& value) {
90
+ stream << " , " << field_name << " =" << xla::PaddingConfigToString (value);
66
91
}
67
92
void OpFieldToString (std::ostream& stream, const char * field_name,
68
93
const std::vector<xla::int64>& value) {
@@ -219,12 +244,18 @@ std::vector<xla::int64> CanonicalizeExpand(xla::Shape shape,
219
244
return dimensions;
220
245
}
221
246
222
- xla::XlaOp LowerPad (xla::XlaOp input, absl::Span< const xla::int64> pad ,
223
- const at::Scalar& value ) {
247
+ xla::XlaOp LowerPad (xla::XlaOp input, const at::Scalar& value ,
248
+ const xla::PaddingConfig& config ) {
224
249
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp (input);
225
250
return xla::Pad (input,
226
251
XlaHelpers::ScalarValue (value, input_shape.element_type (),
227
252
input.builder ()),
253
+ config);
254
+ }
255
+
256
+ xla::XlaOp LowerPad (xla::XlaOp input, absl::Span<const xla::int64> pad,
257
+ const at::Scalar& value) {
258
+ return LowerPad (input, value,
228
259
XlaHelpers::MakeXlaPaddingConfigFromNdPadding (pad));
229
260
}
230
261
@@ -267,6 +298,11 @@ xla::XlaOp LowerWhere(xla::XlaOp condition, xla::XlaOp input,
267
298
return xla::Select (pred_condition, input, other);
268
299
}
269
300
301
+ std::vector<xla::XlaOp> BuildTopK (xla::XlaOp input, xla::int64 k,
302
+ xla::int64 dim, bool largest) {
303
+ return CreateTopK (input, k, dim, largest, true );
304
+ }
305
+
270
306
xla::XlaOp BuildOneHot (xla::XlaOp indices, xla::XlaOp on_value,
271
307
xla::XlaOp off_value, xla::int64 depth,
272
308
xla::int64 axis) {
@@ -474,9 +510,120 @@ xla::Shape ShapeOfXlaOpList(absl::Span<const xla::XlaOp> ops) {
474
510
return result;
475
511
}
476
512
513
+ xla::XlaOp BuildTfConv (xla::XlaOp input, xla::XlaOp filter, bool depthwise,
514
+ absl::Span<const xla::int64> strides,
515
+ tensorflow::Padding padding,
516
+ absl::Span<const xla::int64> explicit_paddings,
517
+ tensorflow::TensorFormat data_format,
518
+ absl::Span<const xla::int64> dilations) {
519
+ xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp (input);
520
+ int num_spatial_dims = input_shape.rank () - 2 ;
521
+ tensorflow::ConvOpAttrs attrs =
522
+ CreateConvOpAttrs (num_spatial_dims, depthwise, strides, padding,
523
+ explicit_paddings, data_format, dilations);
524
+ xla::PrecisionConfig precision_config =
525
+ XlaHelpers::BuildPrecisionConfig (XlaHelpers::mat_mul_precision ());
526
+ return ConsumeValue (tensorflow::MakeXlaForwardConvOp (
527
+ /* type_string=*/ " TfConv" , /* conv_input=*/ input, /* filter=*/ filter,
528
+ /* attrs=*/ attrs, /* precision_config=*/ &precision_config));
529
+ }
530
+
531
+ xla::XlaOp BuildTfConvBackpropFilter (
532
+ xla::XlaOp input, absl::Span<const xla::int64> filter_sizes,
533
+ xla::XlaOp out_backprop, bool depthwise,
534
+ absl::Span<const xla::int64> strides, tensorflow::Padding padding,
535
+ absl::Span<const xla::int64> explicit_paddings,
536
+ tensorflow::TensorFormat data_format,
537
+ absl::Span<const xla::int64> dilations) {
538
+ int num_spatial_dims = filter_sizes.size () - 2 ;
539
+ tensorflow::ConvOpAttrs attrs =
540
+ CreateConvOpAttrs (num_spatial_dims, depthwise, strides, padding,
541
+ explicit_paddings, data_format, dilations);
542
+ xla::PrecisionConfig precision_config =
543
+ XlaHelpers::BuildPrecisionConfig (XlaHelpers::mat_mul_precision ());
544
+ xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp (input);
545
+ return ConsumeValue (tensorflow::MakeXlaBackpropFilterConvOp (
546
+ /* type_string=*/ " TfConvBackpropFilter" , /* activations=*/ input,
547
+ /* filter_shape=*/
548
+ xla::ShapeUtil::MakeShape (input_shape.element_type (), filter_sizes),
549
+ /* gradients=*/ out_backprop,
550
+ /* attrs=*/ attrs, /* precision_config=*/ &precision_config));
551
+ }
552
+
553
+ xla::XlaOp BuildTfConvBackpropInput (
554
+ absl::Span<const xla::int64> input_sizes, xla::XlaOp filter,
555
+ xla::XlaOp out_backprop, bool depthwise,
556
+ absl::Span<const xla::int64> strides, tensorflow::Padding padding,
557
+ absl::Span<const xla::int64> explicit_paddings,
558
+ tensorflow::TensorFormat data_format,
559
+ absl::Span<const xla::int64> dilations) {
560
+ int num_spatial_dims = input_sizes.size () - 2 ;
561
+ tensorflow::ConvOpAttrs attrs =
562
+ CreateConvOpAttrs (num_spatial_dims, depthwise, strides, padding,
563
+ explicit_paddings, data_format, dilations);
564
+ xla::PrecisionConfig precision_config =
565
+ XlaHelpers::BuildPrecisionConfig (XlaHelpers::mat_mul_precision ());
566
+ xla::Shape filter_shape = XlaHelpers::ShapeOfXlaOp (filter);
567
+ return ConsumeValue (tensorflow::MakeXlaBackpropInputConvOp (
568
+ /* type_string=*/ " TfConvBackpropInput" ,
569
+ /* input_shape=*/
570
+ xla::ShapeUtil::MakeShape (filter_shape.element_type (), input_sizes),
571
+ /* filter=*/ filter,
572
+ /* out_backprop=*/ out_backprop,
573
+ /* attrs=*/ attrs, /* precision_config=*/ &precision_config));
574
+ }
575
+
477
576
} // namespace
478
577
} // namespace ops
479
578
} // namespace ir
480
579
} // namespace swift_xla
481
580
581
+ namespace {
582
+
583
+ tensorflow::Padding ToTFPadding (TFPadding padding) {
584
+ switch (padding) {
585
+ case TFPadding_VALID: {
586
+ return tensorflow::VALID;
587
+ }
588
+ case TFPadding_SAME: {
589
+ return tensorflow::SAME;
590
+ }
591
+ case TFPadding_EXPLICIT: {
592
+ return tensorflow::EXPLICIT;
593
+ }
594
+ default : {
595
+ LOG (FATAL) << " Invalid padding: " << padding;
596
+ }
597
+ }
598
+ }
599
+
600
+ tensorflow::MirrorPadMode ToTFMirrorPadMode (TFMirrorPadMode mode) {
601
+ switch (mode) {
602
+ case TFMirrorPadMode_REFLECT: {
603
+ return tensorflow::MirrorPadMode::REFLECT;
604
+ }
605
+ case TFMirrorPadMode_SYMMETRIC: {
606
+ return tensorflow::MirrorPadMode::SYMMETRIC;
607
+ }
608
+ default : {
609
+ LOG (FATAL) << " Invalid mirror pad mode: " << mode;
610
+ }
611
+ }
612
+ }
613
+
614
+ xla::PaddingConfig ToXLAPaddingConfig (PaddingConfig padding_config) {
615
+ xla::PaddingConfig xla_padding_config;
616
+ for (size_t i = 0 ; i < padding_config.count ; ++i) {
617
+ xla::PaddingConfig::PaddingConfigDimension* dims =
618
+ xla_padding_config.add_dimensions ();
619
+ const PaddingConfigDimension& padding_dim = padding_config.dimensions [i];
620
+ dims->set_edge_padding_low (padding_dim.edge_padding_low );
621
+ dims->set_edge_padding_high (padding_dim.edge_padding_high );
622
+ dims->set_interior_padding (padding_dim.interior_padding );
623
+ }
624
+ return xla_padding_config;
625
+ }
626
+
627
+ } // namespace
628
+
482
629
#include " xla_tensor_ops_wrapper_generated.cc.inc"
0 commit comments