This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit d0ea5cb (authored Oct 23, 2020)
Some more XLATensor cleanup. (#1110)
1 parent d91e91e, commit d0ea5cb

9 files changed: +441 -292 lines

‎Sources/CX10/xla_tensor_ops_wrapper.cc

+30 -3

@@ -144,6 +144,11 @@ xla::XlaOp LowerBinaryValueOp(xla::XlaOp lhs, xla::XlaOp rhs) {
   return T(lhs, rhs);
 }
 
+std::vector<xla::XlaOp> LowerBroadcastTensors(xla::XlaOp lhs, xla::XlaOp rhs) {
+  std::tie(lhs, rhs) = XlaHelpers::PromoteValues(lhs, rhs);
+  return {lhs, rhs};
+}
+
 xla::XlaOp LowerSqueeze(xla::XlaOp input, int dim) {
   if (dim == -1) return SqueezeAllTrivialDimensions(input);
   XLA_CHECK_GE(dim, 0);
@@ -324,9 +329,7 @@ xla::XlaOp LowerTfUnsortedSegmentSum(xla::XlaOp data, xla::XlaOp indices,
                                combine);
 }
 
-xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
-                                         xla::XlaOp minval, xla::XlaOp maxval,
-                                         LoweringContext* loctx = nullptr) {
+xla::BitGeneratorTy GetBestGenerator(LoweringContext* loctx = nullptr) {
   xla::BitGeneratorTy generator;
   if (!loctx || loctx->device().hw_type == swift_xla::DeviceType::TPU) {
     generator = xla::ThreeFryBitGenerator;
@@ -336,6 +339,30 @@ xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
       return xla::PhiloxBitGenerator(key, state, shape);
     };
   }
+  return generator;
+}
+
+xla::XlaOp LowerTfStatelessRandomNormal(xla::Shape shape, xla::XlaOp seeds,
+                                        at::ScalarType dtype,
+                                        LoweringContext* loctx = nullptr) {
+  auto generator = GetBestGenerator(loctx);
+  xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
+  xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
+  xla::XlaOp initial_state =
+      xla::ConstantR0WithType(seeds.builder(), xla::U64, 0);
+  xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
+                   ShiftLeft(ConvertElementType(seed1, xla::U64),
+                             ConstantR0WithType(seeds.builder(), xla::U64, 32));
+  xla::XlaOp normal =
+      xla::NormalFloatingPointDistribution(key, initial_state, generator, shape)
+          .value;
+  return normal;
+}
+
+xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
+                                         xla::XlaOp minval, xla::XlaOp maxval,
+                                         LoweringContext* loctx = nullptr) {
+  auto generator = GetBestGenerator(loctx);
   xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
   xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
   xla::XlaOp key = ConvertElementType(seed0, xla::U64) |

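Aside on the RNG lowerings above: both stateless ops derive the 64-bit generator key by packing the two 32-bit seed values as key = seed0 | (seed1 << 32). A minimal Swift sketch of that packing arithmetic, for reference only (the function name here is hypothetical; the real computation is the ConvertElementType/ShiftLeft sequence in the C++ lowering above):

// Packs two 32-bit seeds into the single 64-bit key fed to the XLA bit generator.
func packRNGKey(_ seed0: UInt32, _ seed1: UInt32) -> UInt64 {
  return UInt64(seed0) | (UInt64(seed1) << 32)
}

// Example: seeds (1, 2) give key 0x0000_0002_0000_0001.
let key = packRNGKey(1, 2)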
‎Sources/CX10/xla_tensor_ops_wrapper_generated.cc.inc

+191 -78 (large diff not rendered by default)

‎Sources/CX10/xla_tensor_wrapper.cc

+0 -19

@@ -296,14 +296,6 @@ OpaqueXLATensor* XLATensor_arange(XLAScalar start, XLAScalar end,
                                  ToScalarType(type));
   return new XLATensor(out);
 }
-OpaqueXLATensor_pair XLATensor_broadcast_tensors(OpaqueXLATensor* a,
-                                                 OpaqueXLATensor* b) {
-  OpaqueXLATensor_pair result;
-  auto output = XLATensor::broadcast_tensors({*a, *b});
-  result.x = new XLATensor(output[0]);
-  result.y = new XLATensor(output[1]);
-  return result;
-}
 OpaqueXLATensorArrayRef XLATensor_cross_replica_sum(
     OpaqueXLATensorArrayRef inputs, double scale) {
   auto token = swift_xla::ir::MakeNode<swift_xla::ir::ops::Token>();
@@ -334,17 +326,6 @@ OpaqueXLATensor* XLATensor_linspace(XLAScalar start, XLAScalar stop,
 OpaqueXLATensor* XLATensor_replica_id(const struct CDevice device) {
   return new XLATensor(XLATensor::xla_replica_id(ConvertDevice(device)));
 }
-OpaqueXLATensor* XLATensor_tf_StatelessRandomNormal(
-    Int64ArrayRef size, OpaqueXLATensor* seeds, const struct CDevice device,
-    enum XLATensorScalarType type) {
-  return new XLATensor(XLATensor::tf_StatelessRandomNormal(
-      size.slice(), *seeds, ConvertDevice(device), ToScalarType(type)));
-}
-OpaqueXLATensor* XLATensor_threshold_backward(OpaqueXLATensor* grad_output,
-                                              OpaqueXLATensor* input,
-                                              float threshold) {
-  return XLATensor_threshold(input, grad_output, threshold, 0);
-}
 OpaqueXLATensor* XLATensor_to(OpaqueXLATensor* a, const CDevice* device,
                               Optional_XLAScalarType dtype) {
   return new XLATensor(XLATensor::to(*a, AsOptional(device), dtype.value()));

‎Sources/CX10/xla_tensor_wrapper.h

+1 -4

@@ -401,16 +401,13 @@ XLA_API OpaqueXLATensor*
 XLATensor_tf_OneHot(OpaqueXLATensor* indices, OpaqueXLATensor* on_value,
                     OpaqueXLATensor* off_value, int64_t depth, int64_t axis);
 XLA_API OpaqueXLATensor* XLATensor_tf_StatelessRandomNormal(
-    Int64ArrayRef size, OpaqueXLATensor* seeds, const struct CDevice device,
-    enum XLATensorScalarType type);
+    Int64ArrayRef size, OpaqueXLATensor* seeds, enum XLATensorScalarType type);
 XLA_API OpaqueXLATensor* XLATensor_tf_StatelessRandomUniform(
     Int64ArrayRef size, OpaqueXLATensor* seeds, OpaqueXLATensor* minvalue,
     OpaqueXLATensor* maxvalue);
 XLA_API OpaqueXLATensor*
 XLATensor_tf_UnsortedSegmentSum(OpaqueXLATensor* data, OpaqueXLATensor* indices,
                                 int64_t num_segments);
-XLA_API OpaqueXLATensor* XLATensor_threshold_backward(
-    OpaqueXLATensor* grad_output, OpaqueXLATensor* input, float threshold);
 XLA_API OpaqueXLATensor* XLATensor_threshold(
     OpaqueXLATensor* input, OpaqueXLATensor* output, float threshold, float value);
 XLA_API OpaqueXLATensor* XLATensor_truncated_normal(OpaqueXLATensor* input);

‎Sources/x10/swift_bindings/RawOpsXLAGenerated.swift

+135 -11

@@ -125,6 +125,20 @@ extension _RawXLA {
     return Tensor(_xlaHandle: XLATensor_atanh(input.xlaHandle))
   }
 
+  public static func broadcast_tensors<
+    T: TensorFlowScalar
+  >(
+    _ lhs: Tensor<T>,
+    _ rhs: Tensor<T>
+  ) -> (Tensor<T>, Tensor<T>) {
+    defer { _fixLifetime(lhs) }
+    defer { _fixLifetime(rhs) }
+    checkSameDevice(lhs.device, rhs.device)
+    checkSamePrecision(lhs, rhs)
+    let tuple_output = XLATensor_broadcast_tensors(lhs.xlaHandle, rhs.xlaHandle)
+    return (Tensor(_xlaHandle: tuple_output.x), Tensor(_xlaHandle: tuple_output.y))
+  }
+
   public static func concat<
     T: TensorFlowScalar
   >(
@@ -474,6 +488,17 @@ extension _RawXLA {
     return Tensor(_xlaHandle: XLATensor_logicalAnd(lhs.xlaHandle, rhs.xlaHandle))
   }
 
+  static func logicalCast<
+    Srct: TensorFlowScalar,
+    Dstt: TensorFlowScalar
+  >(
+    _ input: Tensor<Srct>,
+    destType: ScalarType
+  ) -> Tensor<Dstt> {
+    defer { _fixLifetime(input) }
+    return Tensor(_xlaHandle: XLATensor_logical_cast(input.xlaHandle, destType))
+  }
+
   public static func logicalNot(
     _ input: Tensor<Bool>
   ) -> Tensor<Bool> {
@@ -639,6 +664,16 @@ extension _RawXLA {
     }
   }
 
+  static func physicalCast<
+    T: TensorFlowScalar
+  >(
+    _ input: Tensor<T>,
+    destType: ScalarType
+  ) -> Tensor<T> {
+    defer { _fixLifetime(input) }
+    return Tensor(_xlaHandle: XLATensor_physical_cast(input.xlaHandle, destType))
+  }
+
   public static func pow<
     T: TensorFlowNumeric
   >(
@@ -990,7 +1025,7 @@ extension _RawXLA {
   }
 
   static func tf_MirrorPad<
-    T: TensorFlowNumeric
+    T: TensorFlowScalar
   >(
     _ input: Tensor<T>,
     _ padding: [Int64],
@@ -1023,20 +1058,86 @@
     Ti: TensorFlowInteger,
     T: TensorFlowScalar
   >(
-    indices: Tensor<Ti>,
-    on_value: Tensor<T>,
-    off_value: Tensor<T>,
-    depth: Int64,
-    axis: Int64
+    _ indices: Tensor<Ti>,
+    _ onValue: Tensor<T>,
+    _ offValue: Tensor<T>,
+    _ depth: Int64,
+    _ axis: Int64
   ) -> Tensor<T> {
     defer { _fixLifetime(indices) }
-    defer { _fixLifetime(on_value) }
-    defer { _fixLifetime(off_value) }
-    checkSameDevice(indices.device, on_value.device)
-    checkSameDevice(indices.device, off_value.device)
+    defer { _fixLifetime(onValue) }
+    defer { _fixLifetime(offValue) }
+    checkSameDevice(indices.device, onValue.device)
+    checkSameDevice(indices.device, offValue.device)
     return Tensor(
       _xlaHandle: XLATensor_tf_OneHot(
-        indices.xlaHandle, on_value.xlaHandle, off_value.xlaHandle, depth, axis))
+        indices.xlaHandle, onValue.xlaHandle, offValue.xlaHandle, depth, axis))
+  }
+
+  static func tf_StatelessRandomNormal<
+    T: TensorFlowScalar,
+    Ti: TensorFlowIndex
+  >(
+    _ shape: [Int64],
+    _ seeds: Tensor<Ti>,
+    dtype: ScalarType
+  ) -> Tensor<T> {
+    defer { _fixLifetime(seeds) }
+    return shape.withArrayRef { shape in
+      return Tensor(_xlaHandle: XLATensor_tf_StatelessRandomNormal(shape, seeds.xlaHandle, dtype))
+    }
+  }
+
+  static func tf_StatelessRandomUniform<
+    T: TensorFlowScalar,
+    Ti: TensorFlowIndex
+  >(
+    _ shape: [Int64],
+    _ seeds: Tensor<Ti>,
+    _ minvalue: Tensor<T>,
+    _ maxvalue: Tensor<T>
+  ) -> Tensor<T> {
+    defer { _fixLifetime(seeds) }
+    defer { _fixLifetime(minvalue) }
+    defer { _fixLifetime(maxvalue) }
+    checkSameDevice(seeds.device, minvalue.device)
+    checkSameDevice(seeds.device, maxvalue.device)
+    return shape.withArrayRef { shape in
+      return Tensor(
+        _xlaHandle: XLATensor_tf_StatelessRandomUniform(
+          shape, seeds.xlaHandle, minvalue.xlaHandle, maxvalue.xlaHandle))
+    }
+  }
+
+  static func tf_UnsortedSegmentSum<
+    T: TensorFlowNumeric,
+    Ti: TensorFlowIndex
+  >(
+    _ data: Tensor<T>,
+    indicies: Tensor<Ti>,
+    numSegments: Int64
+  ) -> Tensor<T> {
+    defer { _fixLifetime(data) }
+    defer { _fixLifetime(indicies) }
+    checkSameDevice(data.device, indicies.device)
+    return Tensor(
+      _xlaHandle: XLATensor_tf_UnsortedSegmentSum(data.xlaHandle, indicies.xlaHandle, numSegments))
+  }
+
+  static func threshold<
+    T: TensorFlowNumeric
+  >(
+    _ input: Tensor<T>,
+    output: Tensor<T>,
+    threshold: Float,
+    value: Float
+  ) -> Tensor<T> {
+    defer { _fixLifetime(input) }
+    defer { _fixLifetime(output) }
+    checkSameDevice(input.device, output.device)
+    checkSamePrecision(input, output)
+    return Tensor(
+      _xlaHandle: XLATensor_threshold(input.xlaHandle, output.xlaHandle, threshold, value))
   }
 
   public static func topk<
@@ -1052,6 +1153,15 @@ extension _RawXLA {
     return (Tensor(_xlaHandle: tuple_output.x), Tensor(_xlaHandle: tuple_output.y))
   }
 
+  static func truncatedNormal<
+    T: FloatingPoint & TensorFlowScalar
+  >(
+    _ input: Tensor<T>
+  ) -> Tensor<T> {
+    defer { _fixLifetime(input) }
+    return Tensor(_xlaHandle: XLATensor_truncated_normal(input.xlaHandle))
+  }
+
   public static func updateSlice<
     T: TensorFlowScalar
   >(
@@ -1085,6 +1195,20 @@ extension _RawXLA {
       _xlaHandle: XLATensor_where(condition.xlaHandle, input.xlaHandle, other.xlaHandle))
   }
 
+  static func xlaPad<
+    T: TensorFlowScalar
+  >(
+    _ input: Tensor<T>,
+    paddingValue: AnyScalar,
+    paddingConfig: [PaddingConfigDimension]
+  ) -> Tensor<T> {
+    defer { _fixLifetime(input) }
+    return paddingConfig.withArrayRef { paddingConfig in
+      return Tensor(
+        _xlaHandle: XLATensor_xla_pad(input.xlaHandle, paddingValue.xlaScalar, paddingConfig))
+    }
+  }
+
   public static func xlaSlice<
     T: TensorFlowScalar
   >(

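The generated wrappers above are thin shims over the C entry points. A usage sketch for the new public broadcast_tensors raw op (illustrative only; it assumes an XLA device is reachable via Device.defaultXLA, and the example values are arbitrary):

// Build two differently shaped tensors on an XLA device and run the raw op.
let device = Device.defaultXLA
let a = Tensor<Float>(shape: [1, 3], scalars: [1, 2, 3], on: device)
let b = Tensor<Float>(shape: [2, 1], scalars: [10, 20], on: device)
// Per the op's name and its LowerBroadcastTensors lowering, the pair is
// promoted to a common (broadcast) shape; both results stay on `device`.
let (x, y) = _RawXLA.broadcast_tensors(a, b)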
‎Sources/x10/swift_bindings/XLATensor.swift

+1 -109

@@ -210,7 +210,7 @@ extension Array where Element: AnyTensor {
 }
 
 extension Array where Element == PaddingConfigDimension {
-  func withPaddingConfig<Result>(_ body: (inout PaddingConfig) -> Result) -> Result {
+  func withArrayRef<Result>(_ body: (inout PaddingConfig) -> Result) -> Result {
     defer { _fixLifetime(self) }
     return withUnsafeBufferPointer {
       (_ dimensions: UnsafeBufferPointer<PaddingConfigDimension>) -> Result in
@@ -330,11 +330,6 @@ extension XLATensor {
         start.xlaScalar, stop.xlaScalar, num, cdevice, type))
   }
 
-  static func logicalCast(_ input: XLATensor, destType: XLATensorScalarType) -> XLATensor {
-    defer { _fixLifetime(input) }
-    return XLATensor(_handle: XLATensor_logical_cast(input.handle, destType))
-  }
-
   static func maxpool(
     _ input: XLATensor,
     _ ksize: [Int64],
@@ -368,104 +363,10 @@
     }
   }
 
-  static func mirrorPad(_ input: XLATensor, _ padding: [Int64], _ mode: TFMirrorPadMode)
-    -> XLATensor
-  {
-    defer { _fixLifetime(input) }
-    return padding.withArrayRef { padding in
-      XLATensor(_handle: XLATensor_tf_MirrorPad(input.handle, padding, mode))
-    }
-  }
-
-  static func mirrorPadGrad(
-    _ grad_output: XLATensor, _ inputSize: [Int64], _ padding: [Int64], _ mode: TFMirrorPadMode
-  )
-    -> XLATensor
-  {
-    defer { _fixLifetime(grad_output) }
-    return inputSize.withArrayRef { inputSize in
-      padding.withArrayRef { padding in
-        XLATensor(
-          _handle: XLATensor_tf_MirrorPadGrad(grad_output.handle, inputSize, padding, mode))
-      }
-    }
-  }
-
-  static func physicalCast(_ input: XLATensor, destType: XLATensorScalarType) -> XLATensor {
-    defer { _fixLifetime(input) }
-    return XLATensor(_handle: XLATensor_physical_cast(input.handle, destType))
-  }
-
   static func replica_id(_ device: Device) -> XLATensor {
     return XLATensor(_handle: XLATensor_replica_id(device.cdevice))
   }
 
-  static func tf_OneHot(
-    _ indices: XLATensor, _ on_value: XLATensor, _ off_value: XLATensor, _ depth: Int64,
-    _ axis: Int64
-  ) -> XLATensor {
-    defer { _fixLifetime(indices) }
-    defer { _fixLifetime(on_value) }
-    defer { _fixLifetime(off_value) }
-    return XLATensor(
-      _handle: XLATensor_tf_OneHot(indices.handle, on_value.handle, off_value.handle, depth, axis))
-  }
-
-  static func tf_StatelessRandomNormal(
-    _ dims: [Int64],
-    _ seeds: XLATensor,
-    _ dtype: XLAScalarType.Type,
-    _ device: Device
-  ) -> XLATensor {
-    defer { _fixLifetime(seeds) }
-    let cdevice = device.cdevice
-    return dims.withArrayRef { dims in
-      XLATensor(
-        _handle: XLATensor_tf_StatelessRandomNormal(
-          dims, seeds.handle, cdevice,
-          dtype.xlaTensorScalarType))
-    }
-  }
-
-  static func tf_StatelessRandomUniform(
-    _ dims: [Int64],
-    _ seeds: XLATensor,
-    _ minvalue: XLATensor,
-    _ maxvalue: XLATensor,
-    _ dtype: XLAScalarType.Type,
-    _ device: Device
-  ) -> XLATensor {
-    defer { _fixLifetime(seeds) }
-    return dims.withArrayRef { dims in
-      XLATensor(
-        _handle: XLATensor_tf_StatelessRandomUniform(
-          dims, seeds.handle, minvalue.handle, maxvalue.handle))
-    }
-  }
-
-  static func tf_UnsortedSegmentSum(
-    _ data: XLATensor, _ indices: XLATensor, _ numSegments: Int64
-  ) -> XLATensor {
-    defer { _fixLifetime(data) }
-    defer { _fixLifetime(indices) }
-    return XLATensor(
-      _handle: XLATensor_tf_UnsortedSegmentSum(data.handle, indices.handle, numSegments))
-  }
-
-  static func threshold_backward(_ grad_output: XLATensor, _ input: XLATensor, _ threshold: Float)
-    -> XLATensor
-  {
-    defer { _fixLifetime(grad_output) }
-    defer { _fixLifetime(input) }
-    return XLATensor(
-      _handle: XLATensor_threshold_backward(grad_output.handle, input.handle, threshold))
-  }
-
-  static func truncatedNormal(_ input: XLATensor) -> XLATensor {
-    defer { _fixLifetime(input) }
-    return XLATensor(_handle: XLATensor_truncated_normal(input.handle))
-  }
-
   static func to(
     _ a: XLATensor, _ device: Device?, _ dtype: XLAScalarType.Type?
   ) -> XLATensor {
@@ -477,15 +378,6 @@ extension XLATensor {
     }
   }
 
-  static func xlaPad(
-    _ input: XLATensor, paddingValue: XLAScalarType, paddingConfig: [PaddingConfigDimension]
-  ) -> XLATensor {
-    defer { _fixLifetime(input) }
-    return paddingConfig.withPaddingConfig { paddingConfig in
-      XLATensor(_handle: XLATensor_xla_pad(input.handle, paddingValue.xlaScalar, paddingConfig))
-    }
-  }
-
   struct StridedSliceSpec {
     let begin: [Int64]
     let end: [Int64]
‎Sources/x10/swift_bindings/apis/RawOpsManual.swift

+22 -47

@@ -26,6 +26,7 @@ public enum _RawXLA {
   public typealias Padding1 = _RawTFEager.Padding1
   public typealias Mode5 = _RawTFEager.Mode1
   typealias AnyScalar = XLAScalarType
+  typealias ScalarType = XLATensorScalarType
 
   private static func canonicalDims(_ dims: [Int64], _ rank: Int64) -> [Int64] {
     dims.map { $0 < 0 ? $0 + rank : $0 }
@@ -541,7 +542,7 @@ public enum _RawXLA {
     _ x: Tensor<Srct>,
     truncate: Bool = false
   ) -> Tensor<Dstt> {
-    Tensor(_xla: XLATensor.logicalCast(x.xlaTensor, destType: Dstt.xlaTensorScalarType))
+    return logicalCast(x, destType: Dstt.xlaTensorScalarType)
   }
 
   /// Concatenates tensors along one dimension.
@@ -1916,9 +1917,7 @@
     mode: Mode5
   ) -> Tensor<T> {
     let linearizedPaddings = paddings.scalars.map { Int64($0) }
-    return Tensor(
-      _xla: XLATensor.mirrorPad(
-        input.xlaTensor, reversedPaddings(linearizedPaddings), convertMirrorPadMode(mode)))
+    return tf_MirrorPad(input, reversedPaddings(linearizedPaddings), convertMirrorPadMode(mode))
   }
 
   /// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor.
@@ -1963,10 +1962,9 @@
       let totalPadding = linearizedPaddings[2 * dim] + linearizedPaddings[2 * dim + 1]
       return Int64(grad.shape.dimensions[dim]) - totalPadding
     }
-    return Tensor(
-      _xla: XLATensor.mirrorPadGrad(
-        grad.xlaTensor, inputDimensions, reversedPaddings(linearizedPaddings),
-        convertMirrorPadMode(mode)))
+    return tf_MirrorPadGrad(
+      grad, inputDimensions, reversedPaddings(linearizedPaddings),
+      convertMirrorPadMode(mode))
   }
 
   /// Returns the truth value of (x != y) element-wise.
@@ -2091,12 +2089,7 @@
     offValue: Tensor<T>,
     axis: Int64 = -1
   ) -> Tensor<T> {
-    checkSameDevice(onValue, offValue)
-    checkSameDevice(indices.device, onValue.device)
-    checkSamePrecision(onValue, offValue)
-    return Tensor(
-      _xla: XLATensor.tf_OneHot(
-        indices.xlaTensor, onValue.xlaTensor, offValue.xlaTensor, depth, axis))
+    return tf_OneHot(indices, onValue, offValue, depth, axis)
   }
   public static func oneHot<
     T: TensorFlowScalar,
@@ -2108,12 +2101,7 @@
     offValue: Tensor<T>,
     axis: Int64 = -1
   ) -> Tensor<T> {
-    checkSameDevice(onValue, offValue)
-    checkSameDevice(indices.device, onValue.device)
-    checkSamePrecision(onValue, offValue)
-    return Tensor(
-      _xla: XLATensor.tf_OneHot(
-        indices.xlaTensor, onValue.xlaTensor, offValue.xlaTensor, Int64(depth.scalarized()), axis))
+    return tf_OneHot(indices, onValue, offValue, Int64(depth.scalarized()), axis)
   }
 
   /// Returns a tensor of ones with the same shape and type as x.
@@ -2250,7 +2238,7 @@
   public static func physicalCast<T: TensorFlowScalar, R: TensorFlowScalar>(
     _ input: Tensor<T>, destType: R.Type
   ) -> Tensor<T> {
-    Tensor(_xla: XLATensor.physicalCast(input.xlaTensor, destType: destType.xlaTensorScalarType))
+    physicalCast(input, destType: destType.xlaTensorScalarType)
  }
 
   /// Computes the product of elements across dimensions of a tensor.
@@ -2379,9 +2367,7 @@
     gradients: Tensor<T>,
     features: Tensor<T>
   ) -> Tensor<T> {
-    checkSameDevice(gradients, features)
-    checkSamePrecision(gradients, features)
-    return Tensor(_xla: XLATensor.threshold_backward(gradients.xlaTensor, features.xlaTensor, 0))
+    return threshold(features, output: gradients, threshold: 0, value: 0)
   }
 
   public static func replicaId(_ device: Device) -> Tensor<Int32> {
@@ -3138,10 +3124,8 @@
     seed: Tensor<Tseed>,
     device: Device
   ) -> Tensor<Dtype> {
-    Tensor(
-      _xla: XLATensor.tf_StatelessRandomNormal(
-        shape.scalars.map { Int64($0) },
-        seed.xlaTensor, Dtype.self, device))
+    tf_StatelessRandomNormal(
+      shape.scalars.map { Int64($0) }, seed, dtype: Dtype.xlaTensorScalarType)
   }
 
   public static func statelessRandomNormal<
@@ -3179,11 +3163,10 @@
     seed: Tensor<Tseed>,
     device: Device
   ) -> Tensor<Dtype> {
-    Tensor(
-      _xla: XLATensor.tf_StatelessRandomUniform(
-        shape.scalars.map { Int64($0) },
-        seed.xlaTensor, Tensor<Dtype>(0, on: device).xlaTensor,
-        Tensor<Dtype>(1, on: device).xlaTensor, Dtype.self, device))
+    tf_StatelessRandomUniform(
+      shape.scalars.map { Int64($0) },
+      seed, Tensor<Dtype>(0, on: device),
+      Tensor<Dtype>(1, on: device))
   }
 
   public static func statelessRandomUniform<
@@ -3224,11 +3207,7 @@
     maxval: Tensor<Dtype>,
     device: Device
   ) -> Tensor<Dtype> {
-    Tensor(
-      _xla: XLATensor.tf_StatelessRandomUniform(
-        shape.scalars.map { Int64($0) },
-        seed.xlaTensor, minval.xlaTensor,
-        maxval.xlaTensor, Dtype.self, device))
+    tf_StatelessRandomUniform(shape.scalars.map { Int64($0) }, seed, minval, maxval)
   }
 
   public static func statelessRandomUniformInt<
@@ -3272,11 +3251,10 @@
   ) -> Tensor<Dtype> {
     let minval = Tensor<Dtype>(Dtype.leastNormalMagnitude, on: device)
     let maxval = Tensor<Dtype>(1, on: device)
-    let uniform = XLATensor.tf_StatelessRandomUniform(
+    let uniform = tf_StatelessRandomUniform(
       shape.scalars.map { Int64($0) },
-      seed.xlaTensor, minval.xlaTensor,
-      maxval.xlaTensor, Dtype.self, device)
-    return Tensor(_xla: XLATensor.truncatedNormal(uniform))
+      seed, minval, maxval)
+    return truncatedNormal(uniform)
   }
 
   public static func statelessTruncatedNormal<
@@ -3548,8 +3526,7 @@
     if !dimensionsToReverse.isEmpty {
       grad = flip(grad, dims: dimensionsToReverse)
     }
-    return Tensor(
-      _xla: XLATensor.xlaPad(grad.xlaTensor, paddingValue: 0, paddingConfig: paddingConfig))
+    return xlaPad(grad, paddingValue: 0, paddingConfig: paddingConfig)
  }
 
   /// Computes the sum of elements across dimensions of a tensor.
@@ -3819,9 +3796,7 @@
           + String(segmentIds.shape.dimensions[dim]))
       }
     }
-    return Tensor(
-      _xla: XLATensor.tf_UnsortedSegmentSum(
-        data.xlaTensor, segmentIds.xlaTensor, Int64(numSegments)))
+    return tf_UnsortedSegmentSum(data, indicies: segmentIds, numSegments: Int64(numSegments))
  }
 
   /// Returns 0 if x == 0, and x / y otherwise, elementwise.

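One mapping above is worth spelling out: reluGrad(gradients:features:) now routes through the generated threshold op as threshold(features, output: gradients, threshold: 0, value: 0), which passes a gradient through wherever the corresponding feature is above the threshold and substitutes value elsewhere. A plain-Swift reference sketch of that semantics, operating on scalar arrays rather than tensors (the helper name is hypothetical):

// Reference semantics of reluGrad via threshold: keep the incoming gradient
// where the feature is positive, otherwise use the substitute value 0.
func referenceReluGrad(gradients: [Float], features: [Float]) -> [Float] {
  return zip(features, gradients).map { feature, grad in feature > 0 ? grad : 0 }
}

// Example: features [-1, 2] with gradients [5, 7] produce [0, 7].
let grads = referenceReluGrad(gradients: [5, 7], features: [-1, 2])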
‎Sources/x10/swift_bindings/generate_ops.py

+20 -10

@@ -23,7 +23,7 @@
     "TFMirrorPadMode":
         ("enum TFMirrorPadMode", lambda name: f"ToTFMirrorPadMode({name})",
         "tensorflow::MirrorPadMode"),
-    "PaddingConfig":
+    "[PaddingConfigDimension]":
        ("PaddingConfig", lambda name: f"ToXLAPaddingConfig({name})",
         "xla::PaddingConfig"),
 }
@@ -229,6 +229,14 @@ def c_function_define(op):
   else:
     first_tensor = tensor_args[0][0]
 
+  def listify(l):
+    if type(l) is list:
+      return l
+    return [l]
+
+  dtypes = (([None] * op["n_results"])
+            if "result_dtype" not in op else listify(op["result_dtype"]))
+
   def format_arg_def(arg):
     name, stype, _ = arg
     if stype == "Tensor": return "OpaqueXLATensor* " + name
@@ -253,8 +261,18 @@ def format_arg_ref(arg):
     if stype == "[Tensor]":
       return name + "_ir_value"
     if name == "shape":
+      relement_type = f"{first_tensor}->shape().get().element_type()"
+      result_dtype_arg = None
+      if dtypes[0] and first_tensor != dtypes[0]:
+        for arg in args:
+          if arg[0] == dtypes[0]:
+            result_dtype_arg = arg
+      if result_dtype_arg:
+        relement_type = (
+            f"swift_xla::MakeXlaPrimitiveType({format_arg_ref(result_dtype_arg)},"
+            f" /*device=*/nullptr)")
       return ("swift_xla::MakeArrayShapeFromDimensions(shape.slice(), {}, " +
-              f"{first_tensor}->shape().get().element_type(), "
+              f"{relement_type}, "
              f"{first_tensor}->GetDevice().hw_type)")
     if stype in builtin_types:
       return builtin_types[stype][1](name)
@@ -298,14 +316,6 @@ def unpack_arg(arg):
        f"""{op["c_name"]} has unsupported number of return values {op["n_results"]}"""
     )
 
-  def listify(l):
-    if type(l) is list:
-      return l
-    return [l]
-
-  dtypes = (([None] * op["n_results"])
-            if "result_dtype" not in op else listify(op["result_dtype"]))
-
   def format_result(result_i=0, dtype=None):
     if not dtype:
       dtype = dtypes[result_i]
‎Sources/x10/swift_bindings/ops_list.txt

+41 -11

@@ -60,6 +60,10 @@
   lower_fn: xla::Atanh
   generics: {T: FloatingPoint & TensorFlowScalar}
 
+- def: "broadcast_tensors(_ lhs: Tensor<T>, _ rhs: Tensor<T>) -> (Tensor<T>, Tensor<T>)"
+  lower_fn: LowerBroadcastTensors
+  generics: {T: TensorFlowScalar}
+
 - def: "cat(_ input: [Tensor<T>], dim: Int64) -> Tensor<T>"
   extras: ["canonicalize dim input CanonicalizeCat"]
   swift_name: concat
@@ -235,11 +239,14 @@
   x10_enum: at::aten::logical_and
   lower_fn: LowerBinaryOp<xla::And>
 
-- def: "logical_cast(_ input: Tensor, dtype: ScalarType) -> Tensor"
+- def: "logical_cast(_ input: Tensor<Srct>, destType: ScalarType) -> Tensor<Dstt>"
   x10_enum: xla_symbols::cast
+  generics: {Srct: TensorFlowScalar, Dstt: TensorFlowScalar}
+  swift_name: logicalCast
+  protection: internal
   shape_fn: ShapeLogicalCast
   lower_fn: LowerLogicalCast
-  result_dtype: dtype
+  result_dtype: destType
 
 - def: "logicalNot(_ input: Tensor<Bool>) -> Tensor<Bool>"
   x10_enum: at::aten::bitwise_not
@@ -314,7 +321,10 @@
   generics: {T: TensorFlowScalar}
   lower_fn: xla::Transpose
 
-- def: "physical_cast(input: Tensor, dtype: ScalarType) -> Tensor"
+- def: "physical_cast(_ input: Tensor<T>, destType: ScalarType) -> Tensor<T>"
+  generics: {T: TensorFlowScalar}
+  swift_name: physicalCast
+  protection: internal
   x10_enum: xla_symbols::cast
   shape_fn: ShapeLogicalCast
   lower_fn: LowerLogicalCast
@@ -458,7 +468,7 @@
 
 - def: "tf_MirrorPad(_ input: Tensor<T>, _ padding: [Int64], _ mode: TFMirrorPadMode) -> Tensor<T>"
   x10_enum: at::aten::tf_mirror_pad
-  generics: {T: TensorFlowNumeric}
+  generics: {T: TensorFlowScalar}
   protection: internal
   lower_fn: BuildMirrorPad
 
@@ -468,25 +478,39 @@
   protection: internal
   lower_fn: BuildMirrorPadBackward
 
-- def: "tf_OneHot(indices: Tensor<Ti>, on_value: Tensor<T>, off_value: Tensor<T>, depth: Int64, axis: Int64) -> Tensor<T>"
-  result_dtype: on_value
+- def: "tf_OneHot(_ indices: Tensor<Ti>, _ onValue: Tensor<T>, _ offValue: Tensor<T>, _ depth: Int64, _ axis: Int64) -> Tensor<T>"
+  result_dtype: onValue
   generics: {Ti: TensorFlowInteger, T: TensorFlowScalar}
   protection: internal
   x10_enum: at::aten::tf_one_hot
   lower_fn: BuildOneHot
 
-- def: "tf_StatelessRandomUniform(shape: [Int64], seeds: Tensor, minvalue: Tensor, maxvalue: Tensor) -> Tensor"
+- def: "tf_StatelessRandomNormal(_ shape: [Int64], _ seeds: Tensor<Ti>, dtype: ScalarType) -> Tensor<T>"
+  x10_enum: at::aten::tf_stateless_random_normal
+  generics: {T: TensorFlowScalar, Ti: TensorFlowIndex}
+  extras: ["shape_fn shape", "needs_lowering_context"]
+  lower_fn: LowerTfStatelessRandomNormal
+  protection: internal
+  result_dtype: dtype
+
+- def: "tf_StatelessRandomUniform(_ shape: [Int64], _ seeds: Tensor<Ti>, _ minvalue: Tensor<T>, _ maxvalue: Tensor<T>) -> Tensor<T>"
   x10_enum: at::aten::tf_stateless_random_uniform
+  generics: {T: TensorFlowScalar, Ti: TensorFlowIndex}
   extras: ["shape_fn shape", "needs_lowering_context"]
   lower_fn: LowerTfStatelessRandomUniform
+  protection: internal
   result_dtype: minvalue
 
-- def: "tf_UnsortedSegmentSum(data: Tensor, indicies: Tensor, num_segments: Int64) -> Tensor"
+- def: "tf_UnsortedSegmentSum(_ data: Tensor<T>, indicies: Tensor<Ti>, numSegments: Int64) -> Tensor<T>"
+  generics: {T: TensorFlowNumeric, Ti: TensorFlowIndex}
   x10_enum: at::aten::tf_unsorted_segment_sum
+  protection: internal
   lower_fn: LowerTfUnsortedSegmentSum
 
-- def: "threshold(input: Tensor, output: Tensor, threshold: Float, value: Float) -> Tensor"
+- def: "threshold(_ input: Tensor<T>, output: Tensor<T>, threshold: Float, value: Float) -> Tensor<T>"
   x10_enum: at::aten::threshold_backward
+  generics: {T: TensorFlowNumeric}
+  protection: internal
   shape_fn: input
   lower_fn: BuildThreshold
 
@@ -496,8 +520,11 @@
   generics: {T: FloatingPoint & TensorFlowScalar}
   result_dtype: [input, Long]
 
-- def: "truncated_normal(input: Tensor) -> Tensor"
+- def: "truncated_normal(_ input: Tensor<T>) -> Tensor<T>"
   x10_enum: at::aten::xla_truncated_normal
+  generics: {T: FloatingPoint & TensorFlowScalar}
+  protection: internal
+  swift_name: truncatedNormal
   shape_fn: input
   lower_fn: tensorflow::TruncatedNormal
 
@@ -513,7 +540,10 @@
   swift_name: where_
   lower_fn: LowerWhere
 
-- def: "xla_pad(input: Tensor, padding_value: AnyScalar, padding_config: PaddingConfig) -> Tensor"
+- def: "xla_pad(_ input: Tensor<T>, paddingValue: AnyScalar, paddingConfig: [PaddingConfigDimension]) -> Tensor<T>"
+  generics: {T: TensorFlowScalar}
+  protection: internal
+  swift_name: xlaPad
   lower_fn: LowerPad
 
 - def: "xla_slice(_ input: Tensor<T>, start_indices: [Int64], limit_indices: [Int64], strides: [Int64]) -> Tensor<T>"
