Skip to content

Commit 82f3f5e

Browse files
authored
[AutoDiff] Enable some checkpointing APIs that became unblocked. (#22668)
This completes the set of checkpointing APIs introduced in #22033.
1 parent 56b5e96 commit 82f3f5e

File tree

2 files changed

+24
-27
lines changed

2 files changed

+24
-27
lines changed

stdlib/public/core/AutoDiff.swift

+16-18
Original file line numberDiff line numberDiff line change
@@ -156,24 +156,22 @@ public func withRecomputationInPullbacks<T, U>(
156156
}
157157
}
158158

159-
// FIXME: The method variant produces a zero cotangent. Need to investigate.
160-
//
161-
// public extension Differentiable {
162-
// @inlinable
163-
// @differentiable(wrt: self, vjp: _vjp_withRecomputationInPullbacks)
164-
// func withRecomputationInPullbacks<Result : Differentiable>(
165-
// _ body: @escaping @differentiable (Self) -> Result
166-
// ) -> Result {
167-
// return body(self)
168-
// }
169-
//
170-
// @usableFromInline
171-
// internal func _vjp_withRecomputationInPullbacks<Result : Differentiable>(
172-
// _ body: @escaping @differentiable (Self) -> Result
173-
// ) -> (Result, (Result.CotangentVector) -> CotangentVector) {
174-
// return valueWithPullback(in: Swift.withRecomputationInPullbacks(body))
175-
// }
176-
// }
159+
public extension Differentiable {
160+
@inlinable
161+
@differentiable(wrt: self, vjp: _vjp_withRecomputationInPullbacks)
162+
func withRecomputationInPullbacks<Result : Differentiable>(
163+
_ body: @escaping @differentiable (Self) -> Result
164+
) -> Result {
165+
return body(self)
166+
}
167+
168+
@usableFromInline
169+
internal func _vjp_withRecomputationInPullbacks<Result : Differentiable>(
170+
_ body: @escaping @differentiable (Self) -> Result
171+
) -> (Result, (Result.CotangentVector) -> CotangentVector) {
172+
return valueWithPullback(in: Swift.withRecomputationInPullbacks(body))
173+
}
174+
}
177175

178176
//===----------------------------------------------------------------------===//
179177
// Method-style differential operators

test/AutoDiff/custom_derivatives.swift

+8-9
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,14 @@ CustomDerivativesTests.test("Checkpointing") {
6060
})
6161
expectEqual(2, count)
6262
// Reset and test the method variant.
63-
// FIXME: The method variant produces a zero cotangent. Need to investigate.
64-
// count = 0
65-
// expectEqual(324, gradient(at: 3) { (x: Float) -> Float in
66-
// expectEqual(0, count)
67-
// let y = x.withRecomputationInPullbacks(f)
68-
// expectEqual(1, count)
69-
// return y * 3 * x
70-
// })
71-
// expectEqual(2, count)
63+
count = 0
64+
expectEqual(324, gradient(at: 3) { (x: Float) -> Float in
65+
expectEqual(0, count)
66+
let y = x.withRecomputationInPullbacks(f)
67+
expectEqual(1, count)
68+
return y * 3 * x
69+
})
70+
expectEqual(2, count)
7271
}
7372

7473
runAllTests()

0 commit comments

Comments
 (0)