Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 7a2aaf4

Browse files
authored
Add _Raw.replicaId() op. (#1018)
1 parent d029d4f commit 7a2aaf4

9 files changed

+104
-0
lines changed

Sources/CX10/xla_tensor_wrapper.cc

+3
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,9 @@ OpaqueXLATensor* XLATensor_repeat(OpaqueXLATensor* input,
572572
return new XLATensor(
573573
XLATensor::repeat(*input, XlaHelpers::I64List(repeats.slice())));
574574
}
575+
OpaqueXLATensor* XLATensor_replica_id(const struct CDevice device) {
576+
return new XLATensor(XLATensor::xla_replica_id(ConvertDevice(device)));
577+
}
575578
OpaqueXLATensor* XLATensor_resize_value(OpaqueXLATensor* a, Int64ArrayRef arr) {
576579
return new XLATensor(
577580
XLATensor::resize_value(*a, XlaHelpers::I64List(arr.slice())));

Sources/CX10/xla_tensor_wrapper.h

+1
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ XLA_API OpaqueXLATensor* XLATensor_relu(OpaqueXLATensor* a);
341341
XLA_API OpaqueXLATensor* XLATensor_rem(OpaqueXLATensor* a, OpaqueXLATensor* b);
342342
XLA_API OpaqueXLATensor* XLATensor_repeat(OpaqueXLATensor* input,
343343
Int64ArrayRef repeats);
344+
XLA_API OpaqueXLATensor* XLATensor_replica_id(const struct CDevice device);
344345
XLA_API OpaqueXLATensor*
345346
XLATensor_resize_value(OpaqueXLATensor* a, Int64ArrayRef arr);
346347
XLA_API OpaqueXLATensor* XLATensor_round_to_even(OpaqueXLATensor* a);

Sources/x10/swift_bindings/XLATensor.swift

+4
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,10 @@ extension XLATensor {
743743
return XLATensor(_handle: XLATensor_relu(a.handle))
744744
}
745745

746+
static func replica_id(_ device: Device) -> XLATensor {
747+
return XLATensor(_handle: XLATensor_replica_id(device.cdevice));
748+
}
749+
746750
static func resize_value(_ value: XLATensor, _ dims: [Int64]) -> XLATensor {
747751
defer { _fixLifetime(value) }
748752
return dims.withArrayRef { dims in

Sources/x10/swift_bindings/apis/RawOpsManual.swift

+4
Original file line numberDiff line numberDiff line change
@@ -2929,6 +2929,10 @@ public enum _RawXLA {
29292929
return Tensor(_xla: XLATensor.threshold_backward(gradients.xlaTensor, features.xlaTensor, 0))
29302930
}
29312931

2932+
public static func replicaId(_ device: Device) -> Tensor<Int32> {
2933+
return Tensor(_xla: XLATensor.replica_id(device))
2934+
}
2935+
29322936
/// Reshapes a tensor.
29332937
///
29342938
/// Given `tensor`, this operation returns a tensor that has the same values

Sources/x10/xla_tensor/aten_compat.h

+1
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,7 @@
757757
_(aten, xla_max_pool_grad) \
758758
_(aten, xla_pad) \
759759
_(aten, xla_rem) \
760+
_(aten, xla_replica_id) \
760761
_(aten, xla_slice) \
761762
_(aten, xla_truncated_normal) \
762763
_(aten, xla_is_finite) \
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright 2020 TensorFlow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/replica_id.h"
16+
17+
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
18+
19+
namespace swift_xla {
20+
namespace ir {
21+
namespace ops {
22+
23+
ReplicaId::ReplicaId() : Node(
24+
ir::OpKind(at::aten::xla_replica_id),
25+
{},
26+
xla::ShapeUtil::MakeShape(xla::S32, {}),
27+
1, 0x18923728) {}
28+
29+
NodePtr ReplicaId::Clone(OpList operands) const {
30+
return MakeNode<ReplicaId>();
31+
}
32+
33+
XlaOpVector ReplicaId::Lower(LoweringContext* loctx) const {
34+
return ReturnOp(xla::ConvertElementType(
35+
xla::ReplicaId(loctx->builder()), xla::S32), loctx);
36+
}
37+
38+
std::string ReplicaId::ToString() const {
39+
std::stringstream ss;
40+
ss << Node::ToString();
41+
return ss.str();
42+
}
43+
44+
} // namespace ops
45+
} // namespace ir
46+
} // namespace swift_xla
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright 2020 TensorFlow Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include "tensorflow/compiler/tf2xla/xla_tensor/ir.h"
20+
21+
namespace swift_xla {
22+
namespace ir {
23+
namespace ops {
24+
25+
class ReplicaId : public Node {
26+
public:
27+
ReplicaId();
28+
29+
std::string ToString() const override;
30+
31+
NodePtr Clone(OpList operands) const override;
32+
33+
XlaOpVector Lower(LoweringContext* loctx) const override;
34+
};
35+
36+
} // namespace ops
37+
} // namespace ir
38+
} // namespace swift_xla

Sources/x10/xla_tensor/tensor.h

+2
Original file line numberDiff line numberDiff line change
@@ -1264,6 +1264,8 @@ class XLATensor {
12641264

12651265
static XLATensor xla_truncated_normal(const XLATensor& input);
12661266

1267+
static XLATensor xla_replica_id(const Device& device);
1268+
12671269
private:
12681270
struct SyncTensorsConfig {
12691271
// Whether we want to force XLA data on the target tensors (hence trimming

Sources/x10/xla_tensor/tensor_methods.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/reflection_pad2d.h"
9696
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/reflection_pad2d_backward.h"
9797
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/repeat.h"
98+
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/replica_id.h"
9899
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/replication_pad.h"
99100
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/replication_pad_backward.h"
100101
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/resize.h"
@@ -3037,4 +3038,8 @@ XLATensor XLATensor::xla_truncated_normal(const XLATensor& input) {
30373038
return input.CreateFrom(ir::ops::XlaTruncatedNormal(input.GetIrValue()));
30383039
}
30393040

3041+
XLATensor XLATensor::xla_replica_id(const Device& device) {
3042+
return XLATensor::Create(ir::MakeNode<ir::ops::ReplicaId>(), device);
3043+
}
3044+
30403045
} // namespace swift_xla

0 commit comments

Comments
 (0)