@@ -26,6 +26,7 @@ public enum _RawXLA {
26
26
public typealias Padding1 = _RawTFEager . Padding1
27
27
public typealias Mode5 = _RawTFEager . Mode1
28
28
typealias AnyScalar = XLAScalarType
29
+ typealias ScalarType = XLATensorScalarType
29
30
30
31
private static func canonicalDims( _ dims: [ Int64 ] , _ rank: Int64 ) -> [ Int64 ] {
31
32
dims. map { $0 < 0 ? $0 + rank : $0 }
@@ -541,7 +542,7 @@ public enum _RawXLA {
541
542
_ x: Tensor < Srct > ,
542
543
truncate: Bool = false
543
544
) -> Tensor < Dstt > {
544
- Tensor ( _xla : XLATensor . logicalCast ( x. xlaTensor , destType: Dstt . xlaTensorScalarType) )
545
+ return logicalCast ( x, destType: Dstt . xlaTensorScalarType)
545
546
}
546
547
547
548
/// Concatenates tensors along one dimension.
@@ -1916,9 +1917,7 @@ public enum _RawXLA {
1916
1917
mode: Mode5
1917
1918
) -> Tensor < T > {
1918
1919
let linearizedPaddings = paddings. scalars. map { Int64 ( $0) }
1919
- return Tensor (
1920
- _xla: XLATensor . mirrorPad (
1921
- input. xlaTensor, reversedPaddings ( linearizedPaddings) , convertMirrorPadMode ( mode) ) )
1920
+ return tf_MirrorPad ( input, reversedPaddings ( linearizedPaddings) , convertMirrorPadMode ( mode) )
1922
1921
}
1923
1922
1924
1923
/// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor.
@@ -1963,10 +1962,9 @@ public enum _RawXLA {
1963
1962
let totalPadding = linearizedPaddings [ 2 * dim] + linearizedPaddings[ 2 * dim + 1 ]
1964
1963
return Int64 ( grad. shape. dimensions [ dim] ) - totalPadding
1965
1964
}
1966
- return Tensor (
1967
- _xla: XLATensor . mirrorPadGrad (
1968
- grad. xlaTensor, inputDimensions, reversedPaddings ( linearizedPaddings) ,
1969
- convertMirrorPadMode ( mode) ) )
1965
+ return tf_MirrorPadGrad (
1966
+ grad, inputDimensions, reversedPaddings ( linearizedPaddings) ,
1967
+ convertMirrorPadMode ( mode) )
1970
1968
}
1971
1969
1972
1970
/// Returns the truth value of (x != y) element-wise.
@@ -2091,12 +2089,7 @@ public enum _RawXLA {
2091
2089
offValue: Tensor < T > ,
2092
2090
axis: Int64 = - 1
2093
2091
) -> Tensor < T > {
2094
- checkSameDevice ( onValue, offValue)
2095
- checkSameDevice ( indices. device, onValue. device)
2096
- checkSamePrecision ( onValue, offValue)
2097
- return Tensor (
2098
- _xla: XLATensor . tf_OneHot (
2099
- indices. xlaTensor, onValue. xlaTensor, offValue. xlaTensor, depth, axis) )
2092
+ return tf_OneHot ( indices, onValue, offValue, depth, axis)
2100
2093
}
2101
2094
public static func oneHot<
2102
2095
T: TensorFlowScalar ,
@@ -2108,12 +2101,7 @@ public enum _RawXLA {
2108
2101
offValue: Tensor < T > ,
2109
2102
axis: Int64 = - 1
2110
2103
) -> Tensor < T > {
2111
- checkSameDevice ( onValue, offValue)
2112
- checkSameDevice ( indices. device, onValue. device)
2113
- checkSamePrecision ( onValue, offValue)
2114
- return Tensor (
2115
- _xla: XLATensor . tf_OneHot (
2116
- indices. xlaTensor, onValue. xlaTensor, offValue. xlaTensor, Int64 ( depth. scalarized ( ) ) , axis) )
2104
+ return tf_OneHot ( indices, onValue, offValue, Int64 ( depth. scalarized ( ) ) , axis)
2117
2105
}
2118
2106
2119
2107
/// Returns a tensor of ones with the same shape and type as x.
@@ -2250,7 +2238,7 @@ public enum _RawXLA {
2250
2238
public static func physicalCast< T: TensorFlowScalar , R: TensorFlowScalar > (
2251
2239
_ input: Tensor < T > , destType: R . Type
2252
2240
) -> Tensor < T > {
2253
- Tensor ( _xla : XLATensor . physicalCast ( input. xlaTensor , destType: destType. xlaTensorScalarType) )
2241
+ physicalCast ( input, destType: destType. xlaTensorScalarType)
2254
2242
}
2255
2243
2256
2244
/// Computes the product of elements across dimensions of a tensor.
@@ -2379,9 +2367,7 @@ public enum _RawXLA {
2379
2367
gradients: Tensor < T > ,
2380
2368
features: Tensor < T >
2381
2369
) -> Tensor < T > {
2382
- checkSameDevice ( gradients, features)
2383
- checkSamePrecision ( gradients, features)
2384
- return Tensor ( _xla: XLATensor . threshold_backward ( gradients. xlaTensor, features. xlaTensor, 0 ) )
2370
+ return threshold ( features, output: gradients, threshold: 0 , value: 0 )
2385
2371
}
2386
2372
2387
2373
public static func replicaId( _ device: Device ) -> Tensor < Int32 > {
@@ -3138,10 +3124,8 @@ public enum _RawXLA {
3138
3124
seed: Tensor < Tseed > ,
3139
3125
device: Device
3140
3126
) -> Tensor < Dtype > {
3141
- Tensor (
3142
- _xla: XLATensor . tf_StatelessRandomNormal (
3143
- shape. scalars. map { Int64 ( $0) } ,
3144
- seed. xlaTensor, Dtype . self, device) )
3127
+ tf_StatelessRandomNormal (
3128
+ shape. scalars. map { Int64 ( $0) } , seed, dtype: Dtype . xlaTensorScalarType)
3145
3129
}
3146
3130
3147
3131
public static func statelessRandomNormal<
@@ -3179,11 +3163,10 @@ public enum _RawXLA {
3179
3163
seed: Tensor < Tseed > ,
3180
3164
device: Device
3181
3165
) -> Tensor < Dtype > {
3182
- Tensor (
3183
- _xla: XLATensor . tf_StatelessRandomUniform (
3184
- shape. scalars. map { Int64 ( $0) } ,
3185
- seed. xlaTensor, Tensor < Dtype > ( 0 , on: device) . xlaTensor,
3186
- Tensor < Dtype > ( 1 , on: device) . xlaTensor, Dtype . self, device) )
3166
+ tf_StatelessRandomUniform (
3167
+ shape. scalars. map { Int64 ( $0) } ,
3168
+ seed, Tensor < Dtype > ( 0 , on: device) ,
3169
+ Tensor < Dtype > ( 1 , on: device) )
3187
3170
}
3188
3171
3189
3172
public static func statelessRandomUniform<
@@ -3224,11 +3207,7 @@ public enum _RawXLA {
3224
3207
maxval: Tensor < Dtype > ,
3225
3208
device: Device
3226
3209
) -> Tensor < Dtype > {
3227
- Tensor (
3228
- _xla: XLATensor . tf_StatelessRandomUniform (
3229
- shape. scalars. map { Int64 ( $0) } ,
3230
- seed. xlaTensor, minval. xlaTensor,
3231
- maxval. xlaTensor, Dtype . self, device) )
3210
+ tf_StatelessRandomUniform ( shape. scalars. map { Int64 ( $0) } , seed, minval, maxval)
3232
3211
}
3233
3212
3234
3213
public static func statelessRandomUniformInt<
@@ -3272,11 +3251,10 @@ public enum _RawXLA {
3272
3251
) -> Tensor < Dtype > {
3273
3252
let minval = Tensor < Dtype > ( Dtype . leastNormalMagnitude, on: device)
3274
3253
let maxval = Tensor < Dtype > ( 1 , on: device)
3275
- let uniform = XLATensor . tf_StatelessRandomUniform (
3254
+ let uniform = tf_StatelessRandomUniform (
3276
3255
shape. scalars. map { Int64 ( $0) } ,
3277
- seed. xlaTensor, minval. xlaTensor,
3278
- maxval. xlaTensor, Dtype . self, device)
3279
- return Tensor ( _xla: XLATensor . truncatedNormal ( uniform) )
3256
+ seed, minval, maxval)
3257
+ return truncatedNormal ( uniform)
3280
3258
}
3281
3259
3282
3260
public static func statelessTruncatedNormal<
@@ -3548,8 +3526,7 @@ public enum _RawXLA {
3548
3526
if !dimensionsToReverse. isEmpty {
3549
3527
grad = flip ( grad, dims: dimensionsToReverse)
3550
3528
}
3551
- return Tensor (
3552
- _xla: XLATensor . xlaPad ( grad. xlaTensor, paddingValue: 0 , paddingConfig: paddingConfig) )
3529
+ return xlaPad ( grad, paddingValue: 0 , paddingConfig: paddingConfig)
3553
3530
}
3554
3531
3555
3532
/// Computes the sum of elements across dimensions of a tensor.
@@ -3819,9 +3796,7 @@ public enum _RawXLA {
3819
3796
+ String( segmentIds. shape. dimensions [ dim] ) )
3820
3797
}
3821
3798
}
3822
- return Tensor (
3823
- _xla: XLATensor . tf_UnsortedSegmentSum (
3824
- data. xlaTensor, segmentIds. xlaTensor, Int64 ( numSegments) ) )
3799
+ return tf_UnsortedSegmentSum ( data, indicies: segmentIds, numSegments: Int64 ( numSegments) )
3825
3800
}
3826
3801
3827
3802
/// Returns 0 if x == 0, and x / y otherwise, elementwise.
0 commit comments