File tree 2 files changed +24
-27
lines changed
2 files changed +24
-27
lines changed Original file line number Diff line number Diff line change @@ -156,24 +156,22 @@ public func withRecomputationInPullbacks<T, U>(
156
156
}
157
157
}
158
158
159
- // FIXME: The method variant produces a zero cotangent. Need to investigate.
160
- //
161
- // public extension Differentiable {
162
- // @inlinable
163
- // @differentiable(wrt: self, vjp: _vjp_withRecomputationInPullbacks)
164
- // func withRecomputationInPullbacks<Result : Differentiable>(
165
- // _ body: @escaping @differentiable (Self) -> Result
166
- // ) -> Result {
167
- // return body(self)
168
- // }
169
- //
170
- // @usableFromInline
171
- // internal func _vjp_withRecomputationInPullbacks<Result : Differentiable>(
172
- // _ body: @escaping @differentiable (Self) -> Result
173
- // ) -> (Result, (Result.CotangentVector) -> CotangentVector) {
174
- // return valueWithPullback(in: Swift.withRecomputationInPullbacks(body))
175
- // }
176
- // }
159
+ public extension Differentiable {
160
+ @inlinable
161
+ @differentiable ( wrt: self , vjp: _vjp_withRecomputationInPullbacks)
162
+ func withRecomputationInPullbacks< Result : Differentiable > (
163
+ _ body: @escaping @differentiable ( Self ) -> Result
164
+ ) -> Result {
165
+ return body ( self )
166
+ }
167
+
168
+ @usableFromInline
169
+ internal func _vjp_withRecomputationInPullbacks< Result : Differentiable > (
170
+ _ body: @escaping @differentiable ( Self ) -> Result
171
+ ) -> ( Result , ( Result . CotangentVector ) -> CotangentVector ) {
172
+ return valueWithPullback ( in: Swift . withRecomputationInPullbacks ( body) )
173
+ }
174
+ }
177
175
178
176
//===----------------------------------------------------------------------===//
179
177
// Method-style differential operators
Original file line number Diff line number Diff line change @@ -60,15 +60,14 @@ CustomDerivativesTests.test("Checkpointing") {
60
60
} )
61
61
expectEqual ( 2 , count)
62
62
// Reset and test the method variant.
63
- // FIXME: The method variant produces a zero cotangent. Need to investigate.
64
- // count = 0
65
- // expectEqual(324, gradient(at: 3) { (x: Float) -> Float in
66
- // expectEqual(0, count)
67
- // let y = x.withRecomputationInPullbacks(f)
68
- // expectEqual(1, count)
69
- // return y * 3 * x
70
- // })
71
- // expectEqual(2, count)
63
+ count = 0
64
+ expectEqual ( 324 , gradient ( at: 3 ) { ( x: Float ) -> Float in
65
+ expectEqual ( 0 , count)
66
+ let y = x. withRecomputationInPullbacks ( f)
67
+ expectEqual ( 1 , count)
68
+ return y * 3 * x
69
+ } )
70
+ expectEqual ( 2 , count)
72
71
}
73
72
74
73
runAllTests ( )
You can’t perform that action at this time.
0 commit comments