@@ -1860,8 +1860,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1860
1860
@inlinable
1861
1861
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1862
1862
public func max( squeezingAxes axes: [ Int ] ) -> Tensor {
1863
- // TODO(TF-433): Remove workaround for differentiating `map`.
1864
- let axes = { axes. map ( Int32 . init) } ( )
1863
+ let axes = axes. map ( Int32 . init)
1865
1864
return max ( squeezingAxes: Tensor < Int32 > ( axes, on: device) )
1866
1865
}
1867
1866
@@ -1890,8 +1889,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1890
1889
@inlinable
1891
1890
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1892
1891
public func min( squeezingAxes axes: [ Int ] ) -> Tensor {
1893
- // TODO(TF-433): Remove workaround for differentiating `map`.
1894
- let axes = { axes. map ( Int32 . init) } ( )
1892
+ let axes = axes. map ( Int32 . init)
1895
1893
return min ( squeezingAxes: Tensor < Int32 > ( axes, on: device) )
1896
1894
}
1897
1895
@@ -1942,8 +1940,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1942
1940
@inlinable
1943
1941
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1944
1942
public func min( alongAxes axes: [ Int ] ) -> Tensor {
1945
- // TODO(TF-433): Remove workaround for differentiating `map`.
1946
- let axes = { axes. map ( Int32 . init) } ( )
1943
+ let axes = axes. map ( Int32 . init)
1947
1944
return min ( alongAxes: Tensor < Int32 > ( axes, on: device) )
1948
1945
}
1949
1946
@@ -1975,8 +1972,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1975
1972
@inlinable
1976
1973
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1977
1974
public func max( alongAxes axes: [ Int ] ) -> Tensor {
1978
- // TODO(TF-433): Remove workaround for differentiating `map`.
1979
- let axes = { axes. map ( Int32 . init) } ( )
1975
+ let axes = axes. map ( Int32 . init)
1980
1976
return max ( alongAxes: Tensor < Int32 > ( axes, on: device) )
1981
1977
}
1982
1978
@@ -2114,8 +2110,7 @@ extension Tensor where Scalar: Numeric {
2114
2110
@inlinable
2115
2111
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2116
2112
public func sum( squeezingAxes axes: [ Int ] ) -> Tensor {
2117
- // TODO(TF-433): Remove workaround for differentiating `map`.
2118
- let axes = { axes. map ( Int64 . init) } ( )
2113
+ let axes = axes. map ( Int64 . init)
2119
2114
return _Raw. sum ( self , reductionIndices: axes, keepDims: false )
2120
2115
}
2121
2116
@@ -2150,8 +2145,7 @@ extension Tensor where Scalar: Numeric {
2150
2145
@inlinable
2151
2146
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2152
2147
public func sum( alongAxes axes: [ Int ] ) -> Tensor {
2153
- // TODO(TF-433): Remove workaround for differentiating `map`.
2154
- let axes = { axes. map ( Int64 . init) } ( )
2148
+ let axes = axes. map ( Int64 . init)
2155
2149
return _Raw. sum ( self , reductionIndices: axes, keepDims: true )
2156
2150
}
2157
2151
@@ -2184,8 +2178,7 @@ extension Tensor where Scalar: Numeric {
2184
2178
@inlinable
2185
2179
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2186
2180
public func product( squeezingAxes axes: [ Int ] ) -> Tensor {
2187
- // TODO(TF-433): Remove workaround for differentiating `map`.
2188
- let axes = { axes. map ( Int32 . init) } ( )
2181
+ let axes = axes. map ( Int32 . init)
2189
2182
return product ( squeezingAxes: Tensor < Int32 > ( axes, on: device) )
2190
2183
}
2191
2184
@@ -2221,8 +2214,7 @@ extension Tensor where Scalar: Numeric {
2221
2214
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2222
2215
@inlinable
2223
2216
public func product( alongAxes axes: [ Int ] ) -> Tensor {
2224
- // TODO(TF-433): Remove workaround for differentiating `map`.
2225
- let axes = { axes. map ( Int32 . init) } ( )
2217
+ let axes = axes. map ( Int32 . init)
2226
2218
return product ( alongAxes: Tensor < Int32 > ( axes, on: device) )
2227
2219
}
2228
2220
@@ -2253,8 +2245,7 @@ extension Tensor where Scalar: Numeric {
2253
2245
@inlinable
2254
2246
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2255
2247
public func mean( squeezingAxes axes: [ Int ] ) -> Tensor {
2256
- // TODO(TF-433): Remove workaround for differentiating `map`.
2257
- let axes = { axes. map ( Int64 . init) } ( )
2248
+ let axes = axes. map ( Int64 . init)
2258
2249
return _Raw. mean ( self , reductionIndices: axes, keepDims: false )
2259
2250
}
2260
2251
@@ -2291,8 +2282,7 @@ extension Tensor where Scalar: Numeric {
2291
2282
@inlinable
2292
2283
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2293
2284
public func mean( alongAxes axes: [ Int ] ) -> Tensor {
2294
- // TODO(TF-433): Remove workaround for differentiating `map`.
2295
- let axes = { axes. map ( Int64 . init) } ( )
2285
+ let axes = axes. map ( Int64 . init)
2296
2286
return _Raw. mean ( self , reductionIndices: axes, keepDims: true )
2297
2287
}
2298
2288
@@ -2327,8 +2317,7 @@ extension Tensor where Scalar: Numeric {
2327
2317
@inlinable
2328
2318
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2329
2319
public func variance( squeezingAxes axes: [ Int ] ) -> Tensor {
2330
- // TODO(TF-433): Remove workaround for differentiating `map`.
2331
- let axes = { axes. map ( Int32 . init) } ( )
2320
+ let axes = axes. map ( Int32 . init)
2332
2321
return variance ( squeezingAxes: Tensor < Int32 > ( axes, on: device) )
2333
2322
}
2334
2323
@@ -2369,8 +2358,7 @@ extension Tensor where Scalar: Numeric {
2369
2358
@inlinable
2370
2359
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2371
2360
public func variance( alongAxes axes: [ Int ] ) -> Tensor {
2372
- // TODO(TF-433): Remove workaround for differentiating `map`.
2373
- let axes = { axes. map ( Int32 . init) } ( )
2361
+ let axes = axes. map ( Int32 . init)
2374
2362
return variance ( alongAxes: Tensor < Int32 > ( axes, on: device) )
2375
2363
}
2376
2364
@@ -2791,8 +2779,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2791
2779
@inlinable
2792
2780
@differentiable ( wrt: self )
2793
2781
public func standardDeviation( alongAxes axes: [ Int ] ) -> Tensor {
2794
- // TODO(TF-433): Remove workaround for differentiating `map`.
2795
- let axes = { axes. map ( Int32 . init) } ( )
2782
+ let axes = axes. map ( Int32 . init)
2796
2783
return standardDeviation ( alongAxes: Tensor < Int32 > ( axes, on: device) )
2797
2784
}
2798
2785
@@ -2842,8 +2829,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2842
2829
@inlinable
2843
2830
@differentiable ( wrt: self )
2844
2831
public func logSumExp( squeezingAxes axes: [ Int ] ) -> Tensor {
2845
- // TODO(TF-433): Remove workaround for differentiating `map`.
2846
- let axes = withoutDerivative ( at: axes) { $0. map ( Int32 . init) }
2832
+ let axes = axes. map ( Int32 . init)
2847
2833
return logSumExp ( squeezingAxes: Tensor < Int32 > ( axes, on: device) )
2848
2834
}
2849
2835
@@ -2907,8 +2893,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2907
2893
@inlinable
2908
2894
@differentiable ( wrt: self )
2909
2895
public func logSumExp( alongAxes axes: [ Int ] ) -> Tensor {
2910
- // TODO(TF-433): Remove workaround for differentiating `map`.
2911
- let axes = withoutDerivative ( at: axes) { $0. map ( Int32 . init) }
2896
+ let axes = axes. map ( Int32 . init)
2912
2897
return logSumExp ( alongAxes: Tensor < Int32 > ( axes, on: device) )
2913
2898
}
2914
2899
0 commit comments