Skip to content

Commit 43e6ad6

Browse files
eaplataniosrxwei
authored andcommitted
[AutoDiff] Enables conditional differentiability on protocol requirements. (#25974)
This PR enables the functionality shown in this example: ```swift public protocol Distribution { associatedtype Value func logProbability(of value: Value) -> Float } public protocol DifferentiableDistribution: Differentiable, Distribution { @differentiable(wrt: self) func logProbability(of value: Value) -> Float } struct Foo: DifferentiableDistribution { @differentiable(wrt: self) func logProbability(of value: Float) -> Float { .zero } } @differentiable func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic { x.logProbability(of: .zero) } ``` The fix is based on the fact that the Swift compiler does not add entries to the witness tables of protocols for overridden functions, to avoid redundancy. However, the `@differentiable` attribute being added should not be interpreted as an override as it adds new functionality, and should result in entries being added to the witness table. This PR adds this check to the override checking code, thus enabling support for the aforementioned feature.
1 parent 6b3d874 commit 43e6ad6

8 files changed

+252
-39
lines changed

Diff for: include/swift/AST/DiagnosticsSema.def

+2
Original file line numberDiff line numberDiff line change
@@ -2784,6 +2784,8 @@ ERROR(differentiable_attr_invalid_access,none,
27842784
ERROR(differentiable_attr_result_not_differentiable,none,
27852785
"can only differentiate functions with results that conform to "
27862786
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
2787+
ERROR(differentiable_attr_protocol_where_clause,none,
2788+
"'where' clauses cannot be used in a '@differentiable' attribute on a protocol requirement", ())
27872789
ERROR(differentiable_attr_empty_where_clause,none,
27882790
"empty 'where' clause in '@differentiable' attribute", ())
27892791
ERROR(differentiable_attr_nongeneric_trailing_where,none,

Diff for: lib/AST/Attr.cpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,16 @@ static void printDifferentiableAttrArguments(
499499
stream << "vjp: " << vjp->Name;
500500
}
501501
// Print 'where' clause, if any.
502-
if (!attr->getRequirements().empty()) {
502+
// First, filter out requirements satisfied by the original function's
503+
// generic signature. They should not be printed.
504+
auto requirementsToPrint =
505+
makeFilterRange(attr->getRequirements(), [&](Requirement req) {
506+
if (auto *originalGenSig = original->getGenericSignature())
507+
if (originalGenSig->isRequirementSatisfied(req))
508+
return false;
509+
return true;
510+
});
511+
if (!requirementsToPrint.empty()) {
503512
if (!isLeadingClause)
504513
stream << ' ';
505514
stream << "where ";
@@ -515,15 +524,6 @@ static void printDifferentiableAttrArguments(
515524
return genericEnv->getSugaredType(Ty);
516525
};
517526
}
518-
// Filter out requirements satisfied by original function's generic
519-
// signature. They should not be printed.
520-
auto requirementsToPrint =
521-
makeFilterRange(attr->getRequirements(), [&](Requirement req) {
522-
if (auto *originalGenSig = original->getGenericSignature())
523-
if (originalGenSig->isRequirementSatisfied(req))
524-
return false;
525-
return true;
526-
});
527527
interleave(requirementsToPrint, [&](Requirement req) {
528528
if (auto *originalGenSig = original->getGenericSignature())
529529
if (originalGenSig->isRequirementSatisfied(req))

Diff for: lib/Sema/TypeCheckAttr.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -3276,6 +3276,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
32763276
GenericSignature *whereClauseGenSig = nullptr;
32773277
GenericEnvironment *whereClauseGenEnv = nullptr;
32783278
if (auto whereClause = attr->getWhereClause()) {
3279+
// 'where' clauses in '@differentiable' attributes of protocol
3280+
// requirements are not supported.
3281+
if (isa<ProtocolDecl>(original->getDeclContext()) &&
3282+
original->isProtocolRequirement()) {
3283+
TC.diagnose(attr->getLocation(),
3284+
diag::differentiable_attr_protocol_where_clause);
3285+
attr->setInvalid();
3286+
return;
3287+
}
32793288
if (whereClause->getRequirements().empty()) {
32803289
// Where clause must not be empty.
32813290
TC.diagnose(attr->getLocation(),

Diff for: lib/Sema/TypeCheckDeclOverride.cpp

+92
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,92 @@ static bool parameterTypesMatch(const ValueDecl *derivedDecl,
591591
return true;
592592
}
593593

594+
// SWIFT_ENABLE_TENSORFLOW
595+
static bool overridesDifferentiableAttribute(ValueDecl *derivedDecl,
596+
ValueDecl *baseDecl) {
597+
ASTContext &ctx = derivedDecl->getASTContext();
598+
auto &diags = ctx.Diags;
599+
600+
auto *derivedAFD = dyn_cast<AbstractFunctionDecl>(derivedDecl);
601+
auto *baseAFD = dyn_cast<AbstractFunctionDecl>(baseDecl);
602+
603+
if (!derivedAFD || !baseAFD)
604+
return false;
605+
606+
auto derivedDAs = derivedAFD->getAttrs().getAttributes<DifferentiableAttr>();
607+
auto baseDAs = baseAFD->getAttrs().getAttributes<DifferentiableAttr>();
608+
609+
// Make sure all the differentiable attributes in `baseDecl` are
610+
// also declared in `derivedDecl`.
611+
for (auto baseDA : baseDAs) {
612+
auto baseParameters = baseDA->getParameterIndices();
613+
auto defined = false;
614+
for (auto derivedDA : derivedDAs) {
615+
auto derivedParameters = derivedDA->getParameterIndices();
616+
if (derivedParameters &&
617+
baseParameters &&
618+
AutoDiffIndexSubset::get(
619+
ctx, baseParameters->parameters)
620+
->isSubsetOf(AutoDiffIndexSubset::get(
621+
ctx, derivedParameters->parameters))) {
622+
defined = true;
623+
break;
624+
}
625+
}
626+
if (!defined) {
627+
// Omit printing wrt clause if attribute differentiation parameters match
628+
// inferred differentiation parameters.
629+
auto *inferredParameters = TypeChecker::inferDifferentiableParameters(
630+
derivedAFD, nullptr);
631+
bool omitWrtClause = !baseParameters ||
632+
baseParameters->parameters.count() ==
633+
inferredParameters->parameters.count();
634+
// Get `@differentiable` attribute description.
635+
std::string baseDAString;
636+
llvm::raw_string_ostream stream(baseDAString);
637+
baseDA->print(stream, derivedDecl, omitWrtClause);
638+
diags.diagnose(
639+
derivedDecl,
640+
diag::protocol_witness_missing_differentiable_attr,
641+
StringRef(stream.str()).trim());
642+
return false;
643+
}
644+
}
645+
646+
// If there is no differentiable attribute in `derivedDecl`, then
647+
// overriding is not allowed.
648+
if (derivedDAs.empty())
649+
return false;
650+
651+
// Finally, go through all differentiable attributes in
652+
// `derivedDecl` and check if they subsume any of the
653+
// differentiable attributes in `baseDecl`.
654+
for (auto derivedDA : derivedDAs) {
655+
auto derivedParameters = derivedDA->getParameterIndices();
656+
auto overrides = true;
657+
for (auto baseDA : baseDAs) {
658+
auto baseParameters = baseDA->getParameterIndices();
659+
// If the differentiable indices of `derivedDA` are a
660+
// subset of those of `baseDA`, then `baseDA` subsumes
661+
// `derivedDA` and the function is marked as overridden.
662+
if (derivedParameters &&
663+
baseParameters &&
664+
AutoDiffIndexSubset::get(
665+
ctx, derivedParameters->parameters)
666+
->isSubsetOf(AutoDiffIndexSubset::get(
667+
ctx, baseParameters->parameters))) {
668+
overrides = false;
669+
break;
670+
}
671+
}
672+
if (overrides)
673+
return true;
674+
}
675+
676+
return false;
677+
}
678+
// SWIFT_ENABLE_TENSORFLOW END
679+
594680
/// Returns true if the given declaration is for the `NSObject.hashValue`
595681
/// property.
596682
static bool isNSObjectHashValue(ValueDecl *baseDecl) {
@@ -746,6 +832,12 @@ SmallVector<OverrideMatch, 2> OverrideMatcher::match(
746832
if (!areOverrideCompatibleSimple(decl, parentDecl))
747833
continue;
748834

835+
// SWIFT_ENABLE_TENSORFLOW
836+
// Check whether the differentiable attribute allows overriding.
837+
if (overridesDifferentiableAttribute(decl, parentDecl))
838+
continue;
839+
// SWIFT_ENABLE_TENSORFLOW END
840+
749841
auto parentMethod = dyn_cast<AbstractFunctionDecl>(parentDecl);
750842
auto parentStorage = dyn_cast<AbstractStorageDecl>(parentDecl);
751843
assert(parentMethod || parentStorage);

Diff for: lib/Sema/TypeCheckProtocol.cpp

+42-8
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,11 @@ swift::matchWitness(
569569
// SWIFT_ENABLE_TENSORFLOW
570570
auto result = finalize(anyRenaming, optionalAdjustments);
571571
if (result.isViable()) {
572-
// '@differentiable' attributes must match completely.
572+
// '@differentiable' attributes must match completely. If there exists a
573+
// '@differentiable' attribute with a superset of the "wrt" parameters of
574+
// a requirement, then an '@differentiable' attribute is added
575+
// automatically.
576+
ASTContext &ctx = witness->getASTContext();
573577
for (auto *reqDiffAttr : reqAttrs.getAttributes<DifferentiableAttr>()) {
574578
auto witnessDiffAttrs = witnessAttrs
575579
.getAttributes<DifferentiableAttr, /*AllowInvalid*/ true>();
@@ -579,14 +583,44 @@ swift::matchWitness(
579583
reqDiffAttr->getParameterIndices() &&
580584
witnessDiffAttr->parametersMatch(*reqDiffAttr);
581585
});
586+
bool reqDiffAttrSupersetMatch = llvm::any_of(
587+
witnessDiffAttrs, [&](const DifferentiableAttr *witnessDiffAttr) {
588+
return witnessDiffAttr->getParameterIndices() &&
589+
reqDiffAttr->getParameterIndices() &&
590+
AutoDiffIndexSubset::get(
591+
ctx, witnessDiffAttr->getParameterIndices()->parameters)
592+
->isSupersetOf(AutoDiffIndexSubset::get(
593+
ctx, reqDiffAttr->getParameterIndices()->parameters));
594+
});
582595
if (!reqDiffAttrMatch) {
583-
if (auto *vdWitness = dyn_cast<VarDecl>(witness))
584-
return RequirementMatch(
585-
getStandinForAccessor(vdWitness, AccessorKind::Get),
586-
MatchKind::DifferentiableConflict, reqDiffAttr);
587-
else
588-
return RequirementMatch(witness, MatchKind::DifferentiableConflict,
589-
reqDiffAttr);
596+
auto implicitDiffAttr = false;
597+
if (reqDiffAttrSupersetMatch) {
598+
auto *newAttr = DifferentiableAttr::create(
599+
ctx, /*implicit*/ true, reqDiffAttr->AtLoc,
600+
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
601+
reqDiffAttr->getParameterIndices(), /*jvp*/ None,
602+
/*vjp*/ None, reqDiffAttr->getRequirements());
603+
auto insertion = ctx.DifferentiableAttrs.try_emplace(
604+
{witness, newAttr->getParameterIndices()}, newAttr);
605+
// Valid `@differentiable` attributes are uniqued by their parameter
606+
// indices. Reject duplicate attributes for the same decl and parameter
607+
// indices pair.
608+
if (!insertion.second) {
609+
newAttr->setInvalid();
610+
} else {
611+
witness->getAttrs().add(newAttr);
612+
implicitDiffAttr = true;
613+
}
614+
}
615+
if (!implicitDiffAttr) {
616+
if (auto *vdWitness = dyn_cast<VarDecl>(witness))
617+
return RequirementMatch(
618+
getStandinForAccessor(vdWitness, AccessorKind::Get),
619+
MatchKind::DifferentiableConflict, reqDiffAttr);
620+
else
621+
return RequirementMatch(witness, MatchKind::DifferentiableConflict,
622+
reqDiffAttr);
623+
}
590624
}
591625
}
592626
}

Diff for: stdlib/public/core/AutoDiff.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ public extension VectorProtocol {
100100
}
101101
}
102102

103-
/* Note: These default-implemented opreators will slow down type-checking
103+
/* Note: These default-implemented operators will slow down type-checking
104104
performance and break existing code.
105105

106106
public extension VectorProtocol {

Diff for: test/AutoDiff/differentiable_attr_type_checking.swift

+40-20
Original file line numberDiff line numberDiff line change
@@ -588,17 +588,6 @@ extension FloatingPoint {
588588
}
589589
}
590590

591-
protocol MethodDiffReq {
592-
@differentiable(wrt: self, vjp: vjpFoo where Self : Differentiable)
593-
func foo() -> Self
594-
}
595-
596-
extension MethodDiffReq where Self : Differentiable {
597-
func vjpFoo(x: Self) -> (Self, (Self.TangentVector) -> Self.TangentVector) {
598-
return (self, { $0 })
599-
}
600-
}
601-
602591
// expected-error @+1 {{'vjpNonvariadic' does not have expected type '(Float, Int32...) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float, Int32...) -> (Float, (Float) -> Float)')}}
603592
@differentiable(wrt: x, vjp: vjpNonvariadic)
604593
func variadic(_ x: Float, indices: Int32...) -> Float {
@@ -647,10 +636,6 @@ protocol DifferentiableAttrRequirements : Differentiable {
647636
// expected-note @+2 {{protocol requires function 'f2'}}
648637
@differentiable(wrt: (self, x, y))
649638
func f2(_ x: Float, _ y: Float) -> Float
650-
651-
// expected-note @+2 {{protocol requires function 'generic'}}
652-
@differentiable(where T : Differentiable)
653-
func generic<T>(_ x: T) -> T
654639
}
655640

656641
// expected-error @+1 {{does not conform to protocol 'DifferentiableAttrRequirements'}}
@@ -696,11 +681,6 @@ struct DiffAttrConformanceErrors : DifferentiableAttrRequirements {
696681
func f2(_ x: Float, _ y: Float) -> Float {
697682
return x + y
698683
}
699-
700-
// expected-note @+1 {{candidate is missing attribute '@differentiable(where T : Differentiable)'}}
701-
func generic<T>(_ x: T) -> T {
702-
return x
703-
}
704684
}
705685

706686
protocol NotRefiningDiffable {
@@ -865,3 +845,43 @@ func inout1(x: Float, y: inout Float) -> Void {
865845
func inout2(x: Float, y: inout Float) -> Float {
866846
let _ = x + y
867847
}
848+
849+
850+
// Missing `@differentiable` attribute, without printing the 'wrt' arguments.
851+
852+
protocol DifferentiableWhereClause: Differentiable {
853+
associatedtype Scalar
854+
855+
@differentiable(where Scalar: Differentiable) // expected-error {{'where' clauses cannot be used in a '@differentiable' attribute on a protocol requirement}}
856+
func test(value: Scalar) -> Float
857+
}
858+
859+
// Missing a `@differentiable` attribute.
860+
861+
public protocol Distribution {
862+
associatedtype Value
863+
func logProbability(of value: Value) -> Float
864+
}
865+
866+
public protocol DifferentiableDistribution: Differentiable, Distribution {
867+
@differentiable(wrt: self)
868+
func logProbability(of value: Value) -> Float
869+
}
870+
871+
public protocol MissingDifferentiableDistribution: DifferentiableDistribution
872+
where Value: Differentiable {
873+
func logProbability(of value: Value) -> Float // expected-note {{candidate is missing attribute '@differentiable(wrt: self)'}}
874+
}
875+
876+
// Missing `@differentiable` attribute, without printing the 'wrt' arguments.
877+
878+
protocol Example: Differentiable {
879+
associatedtype Scalar: Differentiable
880+
881+
@differentiable
882+
func test(value: Scalar) -> Float
883+
}
884+
885+
protocol MissingDifferentiableTest: Example {
886+
func test(value: Scalar) -> Float // expected-note {{candidate is missing attribute '@differentiable'}}
887+
}

Diff for: test/AutoDiff/protocol_requirement_autodiff.swift

+56
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,60 @@ struct S : P {
137137
}
138138
}
139139

140+
// MARK: - Overridden protocol method adding differentiable attribute.
141+
142+
public protocol Distribution {
143+
associatedtype Value
144+
func logProbability(of value: Value) -> Float
145+
}
146+
147+
public protocol DifferentiableDistribution: Differentiable, Distribution {
148+
@differentiable(wrt: self)
149+
func logProbability(of value: Value) -> Float
150+
}
151+
152+
struct Foo: DifferentiableDistribution {
153+
@differentiable(wrt: self)
154+
func logProbability(of value: Float) -> Float {
155+
.zero
156+
}
157+
}
158+
159+
@differentiable
160+
func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
161+
x.logProbability(of: .zero)
162+
}
163+
164+
// Adding a more general `@differentiable` attribute.
165+
public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
166+
where Value: Differentiable {
167+
@differentiable(wrt: self)
168+
@differentiable(wrt: (self, value))
169+
func logProbability(of value: Value) -> Float
170+
}
171+
172+
@differentiable
173+
func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Float
174+
where T.Value: AdditiveArithmetic {
175+
x.logProbability(of: value)
176+
}
177+
178+
protocol DifferentiableFoo {
179+
associatedtype T: Differentiable
180+
@differentiable(wrt: x)
181+
func foo(_ x: T) -> Float
182+
}
183+
184+
protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo {
185+
@differentiable(wrt: (self, x))
186+
func foo(_ x: T) -> Float
187+
}
188+
189+
struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
190+
@differentiable(wrt: (self, x))
191+
func foo(_ x: Float) -> Float {
192+
x
193+
}
194+
}
195+
140196
runAllTests()

0 commit comments

Comments
 (0)