This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c086579

Convert some more xla ops to be generated from a spec. (#1078)
Ops include: cat, diagonal_value, dynamic_slice, dynamic_update_slice, nll_loss, permute_value, physical_cast, prod, repeat, resize_value, round_to_even, select, and stack.
1 parent: c4e338f · commit: c086579

7 files changed: +807 −69 lines
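Most of these ops previously had handwritten IR node classes and C wrappers; after this change those pieces are emitted by the op code generator from a spec entry, and the file below keeps only the handwritten lowering and canonicalization helpers that the generated code calls. As a rough illustration of that split, a spec-generated lowering for select would be expected to delegate to the LowerSelect helper added in the diff below; the node class and member names in this sketch are assumptions for illustration, not code from this commit:

  // Illustrative sketch only: a generated node's Lower method delegating to
  // the handwritten helper. Select, dim_, and index_ are assumed names.
  XlaOpVector Select::Lower(LoweringContext* loctx) const {
    xla::XlaOp input = loctx->GetOutputOp(operand(0));  // sole tensor operand
    return ReturnOp(LowerSelect(input, dim_, index_), loctx);
  }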

Sources/CX10/xla_tensor_ops_wrapper.cc (+98)
@@ -26,6 +26,8 @@
 #include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/layout_manager.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/matrix.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/nll_loss.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/segment_reduction_ops.h"
@@ -77,6 +79,19 @@ void OpFieldToString(std::ostream& stream, const char* field_name,
   else
     stream << value.toLong();
 }
+const XLATensor* FirstTensor(OpaqueXLATensorArrayRef array) {
+  XLA_CHECK_GE(array.size, 1);
+  return array.data[0];
+}
+std::vector<ir::Value> UnpackIrValues(OpaqueXLATensorArrayRef array) {
+  std::vector<ir::Value> out;
+  out.reserve(array.size);
+  for (size_t i = 0; i < array.size; ++i) {
+    out.push_back(array.data[i]->GetIrValue());
+  }
+  return out;
+}
+
 }  // namespace swift_xla
 
 namespace swift_xla {
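These two helpers back the variadic ops in this change (cat and stack take arrays of tensors): FirstTensor supplies a representative tensor, e.g. for device bookkeeping, and UnpackIrValues flattens the C array into the ir::Value operand list an IR node expects. A minimal sketch of how a generated array-taking wrapper might use them; the wrapper name, the node construction, and the CreateFrom call are assumptions for illustration:

  // Illustrative sketch only: everything except FirstTensor and
  // UnpackIrValues is an assumed name, not code from this commit.
  OpaqueXLATensor* XLATensor_stack(OpaqueXLATensorArrayRef tensors,
                                   int64_t dim) {
    auto values = swift_xla::UnpackIrValues(tensors);  // operands as IR values
    auto* first = swift_xla::FirstTensor(tensors);     // device reference
    return new swift_xla::XLATensor(first->CreateFrom(
        swift_xla::ir::MakeNode<swift_xla::ir::ops::Stack>(values, dim)));
  }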
@@ -325,6 +340,89 @@ xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
   }
 }
 
+xla::XlaOp LowerNllLoss(xla::XlaOp logits, xla::XlaOp labels,
+                        xla::int64 ignore_index) {
+  return BuildNllLoss(logits, labels, absl::nullopt, ignore_index,
+                      ReductionMode::kMean);
+}
+
+xla::XlaOp LowerProd(xla::XlaOp input,
+                     const std::vector<xla::int64>& dimensions,
+                     bool keep_reduced_dimensions,
+                     c10::optional<at::ScalarType> dtype) {
+  xla::XlaOp casted_input;
+  if (dtype) {
+    casted_input = ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
+                             MakeXlaPrimitiveType(*dtype, /*device=*/nullptr),
+                             /*device=*/nullptr);
+  } else {
+    casted_input = ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input));
+  }
+  return BuildProd(casted_input, dimensions, keep_reduced_dimensions);
+}
+
+xla::XlaOp LowerSelect(xla::XlaOp input, xla::int64 dim, xla::int64 index) {
+  auto input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  index =
+      XlaHelpers::GetCanonicalPosition(input_shape.dimensions(), dim, index);
+  return SqueezeTrivialDimension(
+      xla::SliceInDim(input, index, index + 1, 1, dim), dim);
+}
+
+std::vector<xla::XlaOp> GetArrayOperands(LoweringContext* loctx,
+                                         absl::Span<const Output> operands,
+                                         size_t offset) {
+  std::vector<xla::XlaOp> inputs;
+  operands = operands.subspan(offset);
+  inputs.reserve(operands.size());
+  for (auto& operand : operands) {
+    inputs.push_back(loctx->GetOutputOp(operand));
+  }
+  return inputs;
+}
+
+std::vector<xla::XlaOp> MakeParameterList(xla::XlaBuilder* b, size_t offset,
+                                          absl::Span<const Value> inputs,
+                                          const char* name) {
+  std::vector<xla::XlaOp> out;
+  out.reserve(inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    out.push_back(xla::Parameter(b, offset + i, inputs[i].shape(),
+                                 absl::StrCat(name, "_", i)));
+  }
+  return out;
+}
+
+std::vector<Value> TensorArgsConcat(absl::Span<const Value> inputa,
+                                    absl::Span<const Value> inputb) {
+  std::vector<Value> out;
+  out.reserve(inputa.size() + inputb.size());
+  out.insert(out.end(), inputa.begin(), inputa.end());
+  out.insert(out.end(), inputb.begin(), inputb.end());
+  return out;
+}
+
+xla::int64 CanonicalizeStack(absl::Span<const Value> inputs, xla::int64 dim) {
+  XLA_CHECK_GE(inputs.size(), 1);
+  return swift_xla::XlaHelpers::GetCanonicalDimensionIndex(
+      dim, inputs[0].shape().rank() + 1);
+}
+
+xla::int64 CanonicalizeCat(absl::Span<const Value> inputs, xla::int64 dim) {
+  XLA_CHECK_GE(inputs.size(), 1);
+  xla::Shape first_shape = inputs[0].shape();
+  dim = swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim,
+                                                          first_shape.rank());
+  first_shape.DeleteDimension(dim);
+  for (size_t i = 1; i < inputs.size(); ++i) {
+    xla::Shape tensor_shape = inputs[i].shape();
+    tensor_shape.DeleteDimension(dim);
+    XLA_CHECK(xla::ShapeUtil::Compatible(first_shape, tensor_shape))
+        << first_shape << " vs. " << tensor_shape;
+  }
+  return dim;
+}
+
 }  // namespace
 }  // namespace ops
 }  // namespace ir
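CanonicalizeStack and CanonicalizeCat differ in one detail worth noting: stack inserts a new axis, so a negative dim is wrapped against rank + 1, while cat keeps the rank and additionally checks that all inputs have compatible shapes once the concat dimension is deleted. LowerSelect uses the same convention via GetCanonicalPosition, then implements select as a width-1 SliceInDim followed by squeezing the now-trivial dimension (select(dim=0, index=2) on a [4, 5] input slices rows [2, 3) and squeezes to shape [5]). Below is a standalone demonstration of the wrap-around rule; CanonicalDim re-creates the behavior of XlaHelpers::GetCanonicalDimensionIndex for illustration only and is not code from this file:

  #include <cassert>
  #include <cstdint>

  // Re-creation, for illustration, of the dim wrap-around rule used by the
  // canonicalization helpers: a dim in [-rank, rank) maps into [0, rank).
  int64_t CanonicalDim(int64_t dim, int64_t rank) {
    return dim < 0 ? dim + rank : dim;
  }

  int main() {
    // stack([a, b], dim=-1) on rank-2 inputs canonicalizes against
    // rank + 1, because the stacked result gains an axis: -1 -> 2.
    assert(CanonicalDim(-1, 2 + 1) == 2);
    // cat([a, b], dim=-1) keeps rank, so on rank-2 inputs -1 -> 1.
    assert(CanonicalDim(-1, 2) == 1);
    return 0;
  }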
