Skip to content

Commit d214a91

Browse files
authored
[Sema] Differentiable conformance derivation for class types. (#25914)
`Differentiable` derived conformances now supports class types. Synthesis works just like for struct types, except `TangentVector = Self` is never synthesized even if `Self` conforms to `AdditiveArithmetic`. Class differentiation support requires further differentiation transform changes. Resolves TF-630.
1 parent dd2067a commit d214a91

10 files changed

+650
-38
lines changed

Diff for: include/swift/AST/DiagnosticsSema.def

+3-3
Original file line numberDiff line numberDiff line change
@@ -2869,9 +2869,9 @@ ERROR(compiler_evaluable_ref_non_compiler_evaluable,none,
28692869
"@compilerEvaluable functions may not reference non-@compilerEvaluable functions", ())
28702870

28712871
// @noDerivative attribute
2872-
ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
2873-
"'@noDerivative' is only allowed on stored properties in structure types "
2874-
"that declare a conformance to 'Differentiable'", ())
2872+
ERROR(noderivative_only_on_differentiable_struct_or_class_fields,none,
2873+
"'@noDerivative' is only allowed on stored properties in structure or "
2874+
"class types that declare a conformance to 'Differentiable'", ())
28752875

28762876
//------------------------------------------------------------------------------
28772877
// MARK: Type Check Expressions

Diff for: lib/Sema/DerivedConformanceDifferentiable.cpp

+13-16
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// SWIFT_ENABLE_TENSORFLOW
1414
//
1515
// This file implements explicit derivation of the Differentiable protocol for
16-
// struct types.
16+
// struct and class types.
1717
//
1818
//===----------------------------------------------------------------------===//
1919

@@ -108,8 +108,7 @@ static StructDecl *getAssociatedStructDecl(DeclContext *DC, Identifier id) {
108108
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
109109
assert(diffableProto && "`Differentiable` protocol not found");
110110
auto conf = TypeChecker::conformsToProtocol(DC->getSelfTypeInContext(),
111-
diffableProto,
112-
DC, None);
111+
diffableProto, DC, None);
113112
assert(conf && "Nominal must conform to `Differentiable`");
114113
Type assocType = conf->getTypeWitnessByName(DC->getSelfTypeInContext(), id);
115114
assert(assocType && "`Differentiable` protocol associated type not found");
@@ -120,9 +119,8 @@ static StructDecl *getAssociatedStructDecl(DeclContext *DC, Identifier id) {
120119

121120
bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
122121
DeclContext *DC) {
123-
// Nominal type must be a struct. (Zero stored properties is okay.)
124-
auto *structDecl = dyn_cast<StructDecl>(nominal);
125-
if (!structDecl)
122+
// Nominal type must be a struct or class. (No stored properties is okay.)
123+
if (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal))
126124
return false;
127125
auto &C = nominal->getASTContext();
128126
auto *lazyResolver = C.getLazyResolver();
@@ -153,8 +151,7 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
153151
// `X == X.TangentVector`.
154152
if (nominal->isImplicit() && structDecl == nominal->getDeclContext() &&
155153
TypeChecker::conformsToProtocol(structDecl->getDeclaredInterfaceType(),
156-
diffableProto, DC,
157-
None))
154+
diffableProto, DC, None))
158155
return structDecl;
159156
// 3. Equal nominal (and conform to `AdditiveArithmetic` if flag is true).
160157
if (structDecl == nominal) {
@@ -199,7 +196,7 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
199196
// initializers that initialize all stored properties, including initial
200197
// value information.
201198
SmallVector<VarDecl *, 16> diffProperties;
202-
getStoredPropertiesForDifferentiation(structDecl, DC, diffProperties);
199+
getStoredPropertiesForDifferentiation(nominal, DC, diffProperties);
203200
return llvm::all_of(diffProperties, [&](VarDecl *v) {
204201
if (!v->hasInterfaceType())
205202
lazyResolver->resolveDeclSignature(v);
@@ -325,7 +322,8 @@ static ValueDecl *deriveDifferentiable_method(
325322
/*Throws*/ false, SourceLoc(),
326323
/*GenericParams=*/nullptr, params,
327324
TypeLoc::withoutLoc(returnType), parentDC);
328-
funcDecl->setSelfAccessKind(SelfAccessKind::Mutating);
325+
if (!nominal->getSelfClassDecl())
326+
funcDecl->setSelfAccessKind(SelfAccessKind::Mutating);
329327
funcDecl->setImplicit();
330328
funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);
331329

@@ -804,6 +802,8 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
804802
DeclContext* DC) {
805803
auto *diffableProto =
806804
TC.Context.getProtocol(KnownProtocolKind::Differentiable);
805+
bool nominalCanDeriveAdditiveArithmetic =
806+
DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
807807
for (auto *vd : nominal->getStoredProperties()) {
808808
if (!vd->hasInterfaceType())
809809
TC.resolveDeclSignature(vd);
@@ -814,8 +814,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
814814
continue;
815815
// Check whether to diagnose stored property.
816816
bool conformsToDifferentiable =
817-
TC.conformsToProtocol(varType, diffableProto, nominal,
818-
None).hasValue();
817+
TC.conformsToProtocol(varType, diffableProto, nominal, None).hasValue();
819818
// If stored property should not be diagnosed, continue.
820819
if (conformsToDifferentiable && !vd->isLet())
821820
continue;
@@ -829,8 +828,6 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
829828
// `Differentiable` protocol requirements all have default implementations
830829
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
831830
// derived conformances will no longer be necessary.
832-
bool nominalCanDeriveAdditiveArithmetic =
833-
DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
834831
if (!conformsToDifferentiable) {
835832
TC.diagnose(loc,
836833
diag::differentiable_nondiff_type_implicit_noderivative_fixit,
@@ -844,7 +841,6 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
844841
vd->getName(), nominal->getName(),
845842
nominalCanDeriveAdditiveArithmetic)
846843
.fixItInsert(loc, "@noDerivative ");
847-
848844
}
849845
}
850846

@@ -954,6 +950,7 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
954950
bool hasNoDerivativeStoredProp = diffProperties.size() != numStoredProperties;
955951

956952
// Check conditions for returning `Self`.
953+
// - `Self` is not a class type.
957954
// - No `@noDerivative` stored properties exist.
958955
// - All stored properties must have specified associated type equal to
959956
// `Self`.
@@ -971,7 +968,7 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
971968
parentDC, None);
972969

973970
// Return `Self` if conditions are met.
974-
if (!hasNoDerivativeStoredProp &&
971+
if (!hasNoDerivativeStoredProp && !nominal->getSelfClassDecl() &&
975972
(id == C.Id_AllDifferentiableVariables ||
976973
(allMembersAssocTypeEqualsSelf && nominalConformsToAddArith))) {
977974
auto selfType = parentDC->getSelfTypeInContext();

Diff for: lib/Sema/DerivedConformanceRingMathProtocols.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,10 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
178178
// If conformance reference is concrete, then use concrete witness
179179
// declaration for the operator.
180180
if (confRef->isConcrete())
181-
memberOpDecl = confRef->getConcrete()->getWitnessDecl(
182-
operatorReq, C.getLazyResolver());
181+
if (auto *concreteMemberMethodDecl =
182+
confRef->getConcrete()->getWitnessDecl(operatorReq,
183+
C.getLazyResolver()))
184+
memberOpDecl = concreteMemberMethodDecl;
183185
assert(memberOpDecl && "Member operator declaration must exist");
184186
auto memberOpDRE =
185187
new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true);

Diff for: lib/Sema/DerivedConformances.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,9 @@ DerivedConformance::declareDerivedPropertySetter(TypeChecker &tc,
554554
/*GenericParams*/ nullptr, params, TypeLoc(), parentDC);
555555
setterDecl->setImplicit();
556556
setterDecl->setStatic(isStatic);
557-
setterDecl->setSelfAccessKind(SelfAccessKind::Mutating);
557+
// Set mutating if parent is not a class.
558+
if (!parentDC->getSelfClassDecl())
559+
setterDecl->setSelfAccessKind(SelfAccessKind::Mutating);
558560

559561
// If this is supposed to be a final method, mark it as such.
560562
assert(isFinal || !parentDC->getSelfClassDecl());
@@ -584,6 +586,10 @@ DerivedConformance::declareDerivedProperty(Identifier name,
584586
VarDecl *propDecl = new (C) VarDecl(/*IsStatic*/isStatic, VarDecl::Specifier::Var,
585587
/*IsCaptureList*/false, SourceLoc(), name,
586588
parentDC);
589+
// SWIFT_ENABLE_TENSORFLOW
590+
// TODO: Upstream this change to master.
591+
if (isFinal && parentDC->getSelfClassDecl())
592+
propDecl->getAttrs().add(new (C) FinalAttr(/*Implicit*/ true));
587593
propDecl->setImplicit();
588594
propDecl->copyFormalAccessFrom(Nominal, /*sourceIsParentContext*/ true);
589595
propDecl->setInterfaceType(propertyInterfaceType);

Diff for: lib/Sema/TypeCheckAttr.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -3759,20 +3759,20 @@ void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) {
37593759
return;
37603760
if (!vd || vd->isStatic()) {
37613761
diagnoseAndRemoveAttr(attr,
3762-
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
3762+
diag::noderivative_only_on_differentiable_struct_or_class_fields);
37633763
return;
37643764
}
3765-
auto *structDecl = dyn_cast<StructDecl>(vd->getDeclContext());
3766-
if (!structDecl) {
3765+
auto *nominal = vd->getDeclContext()->getSelfNominalTypeDecl();
3766+
if (!nominal || (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal))) {
37673767
diagnoseAndRemoveAttr(attr,
3768-
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
3768+
diag::noderivative_only_on_differentiable_struct_or_class_fields);
37693769
return;
37703770
}
37713771
if (!conformsToDifferentiable(
3772-
structDecl->getDeclaredInterfaceType(),
3773-
structDecl->getDeclContext())) {
3772+
nominal->getDeclaredInterfaceType(),
3773+
nominal->getDeclContext())) {
37743774
diagnoseAndRemoveAttr(attr,
3775-
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
3775+
diag::noderivative_only_on_differentiable_struct_or_class_fields);
37763776
return;
37773777
}
37783778
}

Diff for: test/AutoDiff/derived_differentiable_properties.swift renamed to test/AutoDiff/derived_differentiable.swift

+30-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public struct Foo : Differentiable {
1818
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
1919

2020
// CHECK-SILGEN-LABEL: // Foo.a.getter
21-
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float
21+
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s22derived_differentiable3FooV1aSfvg : $@convention(method) (Foo) -> Float
2222

2323
struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
2424
var a: Float
@@ -82,9 +82,6 @@ struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic
8282
var x: T.TangentVector
8383
}
8484

85-
// TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized.
86-
// `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized.
87-
8885
// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
8986
// CHECK-AST: internal var x: T.TangentVector
9087
// CHECK-AST: internal init(x: T.TangentVector)
@@ -105,3 +102,32 @@ extension ConditionallyDifferentiable : Differentiable where T : Differentiable
105102
// CHECK-AST: public var x: T
106103
// CHECK-AST: internal init(x: T)
107104
// CHECK-AST: }
105+
106+
// Verify that `TangentVector` is not synthesized to be `Self` for
107+
// `AdditiveArithmetic`-conforming classes.
108+
final class AdditiveArithmeticClass<T : AdditiveArithmetic & Differentiable> : AdditiveArithmetic, Differentiable {
109+
var x, y: T
110+
init(x: T, y: T) {
111+
self.x = x
112+
self.y = y
113+
}
114+
115+
// Dummy `AdditiveArithmetic` requirements.
116+
static func == (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Bool {
117+
fatalError()
118+
}
119+
static var zero: AdditiveArithmeticClass {
120+
fatalError()
121+
}
122+
static func + (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Self {
123+
fatalError()
124+
}
125+
static func - (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Self {
126+
fatalError()
127+
}
128+
}
129+
130+
// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable {
131+
// CHECK-AST: final internal var x: T, y: T
132+
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic
133+
// CHECK-AST: }

Diff for: test/AutoDiff/noderivative-attr.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: %target-swift-frontend -typecheck -verify %s
22

3-
// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}}
3+
// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}}
44
@noDerivative var flag: Bool
55

66
struct Foo {
7-
// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}}
7+
// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}}
88
@noDerivative var flag: Bool
99
}
1010

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
3+
// expected-note @+1 {{type declared here}}
4+
class OtherFileNonconforming {}
5+
6+
// expected-note @+1 {{type declared here}}
7+
class GenericOtherFileNonconforming<T : Differentiable> {
8+
var x: T
9+
}

0 commit comments

Comments
 (0)