|
26 | 26 | #include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
|
27 | 27 | #include "tensorflow/compiler/tf2xla/xla_tensor/layout_manager.h"
|
28 | 28 | #include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
|
| 29 | +#include "tensorflow/compiler/tf2xla/xla_tensor/matrix.h" |
| 30 | +#include "tensorflow/compiler/tf2xla/xla_tensor/nll_loss.h" |
29 | 31 | #include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
|
30 | 32 | #include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
|
31 | 33 | #include "tensorflow/compiler/tf2xla/xla_tensor/segment_reduction_ops.h"
|
@@ -77,6 +79,19 @@ void OpFieldToString(std::ostream& stream, const char* field_name,
|
77 | 79 | else
|
78 | 80 | stream << value.toLong();
|
79 | 81 | }
|
| 82 | +const XLATensor* FirstTensor(OpaqueXLATensorArrayRef array) { |
| 83 | + XLA_CHECK_GE(array.size, 1); |
| 84 | + return array.data[0]; |
| 85 | +} |
| 86 | +std::vector<ir::Value> UnpackIrValues(OpaqueXLATensorArrayRef array) { |
| 87 | + std::vector<ir::Value> out; |
| 88 | + out.reserve(array.size); |
| 89 | + for (size_t i = 0; i < array.size; ++i) { |
| 90 | + out.push_back(array.data[i]->GetIrValue()); |
| 91 | + } |
| 92 | + return out; |
| 93 | +} |
| 94 | + |
80 | 95 | } // namespace swift_xla
|
81 | 96 |
|
82 | 97 | namespace swift_xla {
|
@@ -325,6 +340,89 @@ xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
|
325 | 340 | }
|
326 | 341 | }
|
327 | 342 |
|
| 343 | +xla::XlaOp LowerNllLoss(xla::XlaOp logits, xla::XlaOp labels, |
| 344 | + xla::int64 ignore_index) { |
| 345 | + return BuildNllLoss(logits, labels, absl::nullopt, ignore_index, |
| 346 | + ReductionMode::kMean); |
| 347 | +} |
| 348 | + |
| 349 | +xla::XlaOp LowerProd(xla::XlaOp input, |
| 350 | + const std::vector<xla::int64>& dimensions, |
| 351 | + bool keep_reduced_dimensions, |
| 352 | + c10::optional<at::ScalarType> dtype) { |
| 353 | + xla::XlaOp casted_input; |
| 354 | + if (dtype) { |
| 355 | + casted_input = ConvertTo(input, XlaHelpers::TypeOfXlaOp(input), |
| 356 | + MakeXlaPrimitiveType(*dtype, /*device=*/nullptr), |
| 357 | + /*device=*/nullptr); |
| 358 | + } else { |
| 359 | + casted_input = ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input)); |
| 360 | + } |
| 361 | + return BuildProd(casted_input, dimensions, keep_reduced_dimensions); |
| 362 | +} |
| 363 | + |
| 364 | +xla::XlaOp LowerSelect(xla::XlaOp input, xla::int64 dim, xla::int64 index) { |
| 365 | + auto input_shape = XlaHelpers::ShapeOfXlaOp(input); |
| 366 | + index = |
| 367 | + XlaHelpers::GetCanonicalPosition(input_shape.dimensions(), dim, index); |
| 368 | + return SqueezeTrivialDimension( |
| 369 | + xla::SliceInDim(input, index, index + 1, 1, dim), dim); |
| 370 | +} |
| 371 | + |
| 372 | +std::vector<xla::XlaOp> GetArrayOperands(LoweringContext* loctx, |
| 373 | + absl::Span<const Output> operands, |
| 374 | + size_t offset) { |
| 375 | + std::vector<xla::XlaOp> inputs; |
| 376 | + operands = operands.subspan(offset); |
| 377 | + inputs.reserve(operands.size()); |
| 378 | + for (auto& operand : operands) { |
| 379 | + inputs.push_back(loctx->GetOutputOp(operand)); |
| 380 | + } |
| 381 | + return inputs; |
| 382 | +} |
| 383 | + |
| 384 | +std::vector<xla::XlaOp> MakeParameterList(xla::XlaBuilder* b, size_t offset, |
| 385 | + absl::Span<const Value> inputs, |
| 386 | + const char* name) { |
| 387 | + std::vector<xla::XlaOp> out; |
| 388 | + out.reserve(inputs.size()); |
| 389 | + for (size_t i = 0; i < inputs.size(); ++i) { |
| 390 | + out.push_back(xla::Parameter(b, offset + i, inputs[i].shape(), |
| 391 | + absl::StrCat(name, "_", i))); |
| 392 | + } |
| 393 | + return out; |
| 394 | +} |
| 395 | + |
| 396 | +std::vector<Value> TensorArgsConcat(absl::Span<const Value> inputa, |
| 397 | + absl::Span<const Value> inputb) { |
| 398 | + std::vector<Value> out; |
| 399 | + out.reserve(inputa.size() + inputb.size()); |
| 400 | + out.insert(out.end(), inputa.begin(), inputa.end()); |
| 401 | + out.insert(out.end(), inputb.begin(), inputb.end()); |
| 402 | + return out; |
| 403 | +} |
| 404 | + |
| 405 | +xla::int64 CanonicalizeStack(absl::Span<const Value> inputs, xla::int64 dim) { |
| 406 | + XLA_CHECK_GE(inputs.size(), 1); |
| 407 | + return swift_xla::XlaHelpers::GetCanonicalDimensionIndex( |
| 408 | + dim, inputs[0].shape().rank() + 1); |
| 409 | +} |
| 410 | + |
| 411 | +xla::int64 CanonicalizeCat(absl::Span<const Value> inputs, xla::int64 dim) { |
| 412 | + XLA_CHECK_GE(inputs.size(), 1); |
| 413 | + xla::Shape first_shape = inputs[0].shape(); |
| 414 | + dim = swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim, |
| 415 | + first_shape.rank()); |
| 416 | + first_shape.DeleteDimension(dim); |
| 417 | + for (size_t i = 1; i < inputs.size(); ++i) { |
| 418 | + xla::Shape tensor_shape = inputs[i].shape(); |
| 419 | + tensor_shape.DeleteDimension(dim); |
| 420 | + XLA_CHECK(xla::ShapeUtil::Compatible(first_shape, tensor_shape)) |
| 421 | + << first_shape << " vs. " << tensor_shape; |
| 422 | + } |
| 423 | + return dim; |
| 424 | +} |
| 425 | + |
328 | 426 | } // namespace
|
329 | 427 | } // namespace ops
|
330 | 428 | } // namespace ir
|
|
0 commit comments