Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b7148bc

Browse files
authored
Revert non-active try-apply differentiation workarounds (#1069)
Since TF-433 was fixed, we can remove all the workarounds for (non-active) try-apply. For example, calls to `Array.map(_:)`.
1 parent 77f7885 commit b7148bc

File tree

1 file changed

+15
-30
lines changed

1 file changed

+15
-30
lines changed

Sources/TensorFlow/Operators/Math.swift

+15-30
Original file line numberDiff line numberDiff line change
@@ -1860,8 +1860,7 @@ extension Tensor where Scalar: Numeric & Comparable {
18601860
@inlinable
18611861
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
18621862
public func max(squeezingAxes axes: [Int]) -> Tensor {
1863-
// TODO(TF-433): Remove workaround for differentiating `map`.
1864-
let axes = { axes.map(Int32.init) }()
1863+
let axes = axes.map(Int32.init)
18651864
return max(squeezingAxes: Tensor<Int32>(axes, on: device))
18661865
}
18671866

@@ -1890,8 +1889,7 @@ extension Tensor where Scalar: Numeric & Comparable {
18901889
@inlinable
18911890
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
18921891
public func min(squeezingAxes axes: [Int]) -> Tensor {
1893-
// TODO(TF-433): Remove workaround for differentiating `map`.
1894-
let axes = { axes.map(Int32.init) }()
1892+
let axes = axes.map(Int32.init)
18951893
return min(squeezingAxes: Tensor<Int32>(axes, on: device))
18961894
}
18971895

@@ -1942,8 +1940,7 @@ extension Tensor where Scalar: Numeric & Comparable {
19421940
@inlinable
19431941
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
19441942
public func min(alongAxes axes: [Int]) -> Tensor {
1945-
// TODO(TF-433): Remove workaround for differentiating `map`.
1946-
let axes = { axes.map(Int32.init) }()
1943+
let axes = axes.map(Int32.init)
19471944
return min(alongAxes: Tensor<Int32>(axes, on: device))
19481945
}
19491946

@@ -1975,8 +1972,7 @@ extension Tensor where Scalar: Numeric & Comparable {
19751972
@inlinable
19761973
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
19771974
public func max(alongAxes axes: [Int]) -> Tensor {
1978-
// TODO(TF-433): Remove workaround for differentiating `map`.
1979-
let axes = { axes.map(Int32.init) }()
1975+
let axes = axes.map(Int32.init)
19801976
return max(alongAxes: Tensor<Int32>(axes, on: device))
19811977
}
19821978

@@ -2114,8 +2110,7 @@ extension Tensor where Scalar: Numeric {
21142110
@inlinable
21152111
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
21162112
public func sum(squeezingAxes axes: [Int]) -> Tensor {
2117-
// TODO(TF-433): Remove workaround for differentiating `map`.
2118-
let axes = { axes.map(Int64.init) }()
2113+
let axes = axes.map(Int64.init)
21192114
return _Raw.sum(self, reductionIndices: axes, keepDims: false)
21202115
}
21212116

@@ -2150,8 +2145,7 @@ extension Tensor where Scalar: Numeric {
21502145
@inlinable
21512146
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
21522147
public func sum(alongAxes axes: [Int]) -> Tensor {
2153-
// TODO(TF-433): Remove workaround for differentiating `map`.
2154-
let axes = { axes.map(Int64.init) }()
2148+
let axes = axes.map(Int64.init)
21552149
return _Raw.sum(self, reductionIndices: axes, keepDims: true)
21562150
}
21572151

@@ -2184,8 +2178,7 @@ extension Tensor where Scalar: Numeric {
21842178
@inlinable
21852179
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
21862180
public func product(squeezingAxes axes: [Int]) -> Tensor {
2187-
// TODO(TF-433): Remove workaround for differentiating `map`.
2188-
let axes = { axes.map(Int32.init) }()
2181+
let axes = axes.map(Int32.init)
21892182
return product(squeezingAxes: Tensor<Int32>(axes, on: device))
21902183
}
21912184

@@ -2221,8 +2214,7 @@ extension Tensor where Scalar: Numeric {
22212214
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
22222215
@inlinable
22232216
public func product(alongAxes axes: [Int]) -> Tensor {
2224-
// TODO(TF-433): Remove workaround for differentiating `map`.
2225-
let axes = { axes.map(Int32.init) }()
2217+
let axes = axes.map(Int32.init)
22262218
return product(alongAxes: Tensor<Int32>(axes, on: device))
22272219
}
22282220

@@ -2253,8 +2245,7 @@ extension Tensor where Scalar: Numeric {
22532245
@inlinable
22542246
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
22552247
public func mean(squeezingAxes axes: [Int]) -> Tensor {
2256-
// TODO(TF-433): Remove workaround for differentiating `map`.
2257-
let axes = { axes.map(Int64.init) }()
2248+
let axes = axes.map(Int64.init)
22582249
return _Raw.mean(self, reductionIndices: axes, keepDims: false)
22592250
}
22602251

@@ -2291,8 +2282,7 @@ extension Tensor where Scalar: Numeric {
22912282
@inlinable
22922283
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
22932284
public func mean(alongAxes axes: [Int]) -> Tensor {
2294-
// TODO(TF-433): Remove workaround for differentiating `map`.
2295-
let axes = { axes.map(Int64.init) }()
2285+
let axes = axes.map(Int64.init)
22962286
return _Raw.mean(self, reductionIndices: axes, keepDims: true)
22972287
}
22982288

@@ -2327,8 +2317,7 @@ extension Tensor where Scalar: Numeric {
23272317
@inlinable
23282318
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
23292319
public func variance(squeezingAxes axes: [Int]) -> Tensor {
2330-
// TODO(TF-433): Remove workaround for differentiating `map`.
2331-
let axes = { axes.map(Int32.init) }()
2320+
let axes = axes.map(Int32.init)
23322321
return variance(squeezingAxes: Tensor<Int32>(axes, on: device))
23332322
}
23342323

@@ -2369,8 +2358,7 @@ extension Tensor where Scalar: Numeric {
23692358
@inlinable
23702359
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
23712360
public func variance(alongAxes axes: [Int]) -> Tensor {
2372-
// TODO(TF-433): Remove workaround for differentiating `map`.
2373-
let axes = { axes.map(Int32.init) }()
2361+
let axes = axes.map(Int32.init)
23742362
return variance(alongAxes: Tensor<Int32>(axes, on: device))
23752363
}
23762364

@@ -2791,8 +2779,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
27912779
@inlinable
27922780
@differentiable(wrt: self)
27932781
public func standardDeviation(alongAxes axes: [Int]) -> Tensor {
2794-
// TODO(TF-433): Remove workaround for differentiating `map`.
2795-
let axes = { axes.map(Int32.init) }()
2782+
let axes = axes.map(Int32.init)
27962783
return standardDeviation(alongAxes: Tensor<Int32>(axes, on: device))
27972784
}
27982785

@@ -2842,8 +2829,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
28422829
@inlinable
28432830
@differentiable(wrt: self)
28442831
public func logSumExp(squeezingAxes axes: [Int]) -> Tensor {
2845-
// TODO(TF-433): Remove workaround for differentiating `map`.
2846-
let axes = withoutDerivative(at: axes) { $0.map(Int32.init) }
2832+
let axes = axes.map(Int32.init)
28472833
return logSumExp(squeezingAxes: Tensor<Int32>(axes, on: device))
28482834
}
28492835

@@ -2907,8 +2893,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
29072893
@inlinable
29082894
@differentiable(wrt: self)
29092895
public func logSumExp(alongAxes axes: [Int]) -> Tensor {
2910-
// TODO(TF-433): Remove workaround for differentiating `map`.
2911-
let axes = withoutDerivative(at: axes) { $0.map(Int32.init) }
2896+
let axes = axes.map(Int32.init)
29122897
return logSumExp(alongAxes: Tensor<Int32>(axes, on: device))
29132898
}
29142899

0 commit comments

Comments
 (0)