Skip to content

Commit dfb6f20

Browse files
committed
[Docs] [AutoDiff] Allow implicitly inherited '@differentiable' on non-public declarations.
Currently, when a conforming type implements a `@differentiable` protocol requirement, the corresponding conforming implemnetation is required to have at least a `@differentiable` that covers all differentiability parameters in the protocol requirement. However, this is not a great design for usability because developers almost always start with missing `@differentiable` and getting an compilation error. This also makes ML models built with libraries that use differentiable programming more verbose than those built with other ML frameworks. We agreed during this Friday's design review to allow `@differentiable` to be implicitly inherited from protocols when the conforming implementation is non-public.
1 parent 1b0eee9 commit dfb6f20

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

docs/DifferentiableProgramming.md

+7-6
Original file line numberDiff line numberDiff line change
@@ -1376,11 +1376,12 @@ inheritance must maintain the differentiability.
13761376

13771377
The `@differentiable` attribute can be used on protocol requirements. A
13781378
`@differentiable` protocol requirement requires that all conforming types
1379-
implement this protocol requirement with a differentiable body with respect to
1380-
the specified parameters.
1379+
implement this requirement with a differentiable body with respect to the
1380+
specified parameters. Conforming implementations are not required to be marked
1381+
with `@differentiable` attribute unless they are `public`.
13811382

13821383
```swift
1383-
protocol Layer: Differentiable {
1384+
public protocol Layer: Differentiable {
13841385
associatedtype Input: Differentiable
13851386
associatedtype Output: Differentiable
13861387
@differentiable // w.r.t. `input` and `self`
@@ -1389,7 +1390,7 @@ protocol Layer: Differentiable {
13891390
struct Perceptron: @memberwise Differentiable, Layer {
13901391
var weight: SIMD4<Float>
13911392
var bias: Float
1392-
@differentiable // w.r.t. `input` and `self`
1393+
13931394
func callAsFunction(_ input: SIMD4<Float>) -> Float {
13941395
(weight * input).sum() + b
13951396
}
@@ -1401,14 +1402,14 @@ with a `@differentiable` attribute that declares differentiability with respect
14011402
to more parameters.
14021403

14031404
```swift
1404-
protocol Module: Differentiable {
1405+
public protocol Module: Differentiable {
14051406
associatedtype Input
14061407
associatedtype Output: Differentiable
14071408
@differentiable(wrt: self)
14081409
func callAsFunction(_: Input) -> Output
14091410
}
14101411

1411-
protocol Layer: Module where Input: Differentiable {
1412+
public protocol Layer: Module where Input: Differentiable {
14121413
@differentiable(wrt: (self, input))
14131414
func callAsFunction(_: Input) -> Output
14141415
}

0 commit comments

Comments
 (0)