Skip to content

Commit e5cb871

Browse files
authoredMar 25, 2020
[AutoDiff upstream] Add flag-gated AdditiveArithmetic derivation. (#30628)
Add `AdditiveArithmetic` derived conformances for structs, gated by the `-enable-experimential-additive-arithmetic-derivation` flag. Structs whose stored properties all conform to `AdditiveArithmetic` can derive `AdditiveArithmetic`: - `static var zero: Self` - `static func +(lhs: Self, rhs: Self) -> Self` - `static func -(lhs: Self, rhs: Self) -> Self` - An "effective memberwise initializer": - Either a synthesized memberwise initializer or a user-defined initializer with the same type. Effective memberwise initializers are used only by derived conformances for `Self`-returning protocol requirements like `AdditiveArithmetic.+`, which require memberwise initialization. Resolves TF-844. Unblocks TF-845: upstream `Differentiable` derived conformances.
1 parent 07596cb commit e5cb871

20 files changed

+681
-7
lines changed
 

‎include/swift/AST/Decl.h

+21
Original file line numberDiff line numberDiff line change
@@ -3516,6 +3516,27 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
35163516
/// or \c nullptr if it does not have one.
35173517
ConstructorDecl *getMemberwiseInitializer() const;
35183518

3519+
/// Retrieves the effective memberwise initializer for this declaration, or
3520+
/// \c nullptr if it does not have one.
3521+
///
3522+
/// An effective memberwise initializer is either a synthesized memberwise
3523+
/// initializer or a user-defined initializer with the same type.
3524+
///
3525+
/// The access level of the memberwise initializer is set to the minimum of:
3526+
/// - Public, by default. This enables public nominal types to have public
3527+
/// memberwise initializers.
3528+
/// - The `public` default is important for synthesized member types, e.g.
3529+
/// `TangentVector` structs synthesized during `Differentiable` derived
3530+
/// conformances. Manually extending these types to define a public
3531+
/// memberwise initializer causes a redeclaration error.
3532+
/// - The minimum access level of memberwise-initialized properties in the
3533+
/// nominal type declaration.
3534+
///
3535+
/// Effective memberwise initializers are used only by derived conformances
3536+
/// for `Self`-returning protocol requirements like `AdditiveArithmetic.+`.
3537+
/// Such derived conformances require memberwise initialization.
3538+
ConstructorDecl *getEffectiveMemberwiseInitializer();
3539+
35193540
/// Whether this declaration has a synthesized zero parameter default
35203541
/// initializer.
35213542
bool hasDefaultInitializer() const;

‎include/swift/AST/DiagnosticsSema.def

+2
Original file line numberDiff line numberDiff line change
@@ -2674,6 +2674,8 @@ ERROR(cannot_synthesize_in_crossfile_extension,none,
26742674
"implementation of %0 cannot be automatically synthesized in an extension "
26752675
"in a different file to the type", (Type))
26762676

2677+
ERROR(broken_additive_arithmetic_requirement,none,
2678+
"AdditiveArithmetic protocol is broken: unexpected requirement", ())
26772679
ERROR(broken_case_iterable_requirement,none,
26782680
"CaseIterable protocol is broken: unexpected requirement", ())
26792681
ERROR(broken_raw_representable_requirement,none,

‎include/swift/AST/KnownIdentifiers.def

+3-2
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,6 @@ IDENTIFIER(withKeywordArguments)
133133
IDENTIFIER(wrapped)
134134
IDENTIFIER(wrappedValue)
135135
IDENTIFIER(wrapperValue)
136-
IDENTIFIER(differential)
137-
IDENTIFIER(pullback)
138136

139137
// Kinds of layout constraints
140138
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")
@@ -206,7 +204,10 @@ IDENTIFIER_(nsError)
206204
IDENTIFIER(OSLogMessage)
207205

208206
// Differentiable programming
207+
IDENTIFIER(differential)
208+
IDENTIFIER(pullback)
209209
IDENTIFIER(TangentVector)
210+
IDENTIFIER(zero)
210211

211212
#undef IDENTIFIER
212213
#undef IDENTIFIER_

‎include/swift/AST/KnownProtocols.def

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ PROTOCOL_(SwiftNewtypeWrapper)
7878
PROTOCOL(CodingKey)
7979
PROTOCOL(Encodable)
8080
PROTOCOL(Decodable)
81-
PROTOCOL(AdditiveArithmetic)
8281

8382
PROTOCOL_(ObjectiveCBridgeable)
8483
PROTOCOL_(DestructorSafeContainer)
8584

8685
PROTOCOL(StringInterpolationProtocol)
8786

87+
PROTOCOL(AdditiveArithmetic)
8888
PROTOCOL(Differentiable)
8989

9090
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false)

‎include/swift/AST/TypeCheckRequests.h

+25
Original file line numberDiff line numberDiff line change
@@ -1600,6 +1600,31 @@ class SynthesizeMemberwiseInitRequest
16001600
bool isCached() const { return true; }
16011601
};
16021602

1603+
/// Resolves the effective memberwise initializer for a given type.
1604+
///
1605+
/// An effective memberwise initializer is either a synthesized memberwise
1606+
/// initializer or a user-defined initializer with the same type.
1607+
///
1608+
/// See `NominalTypeDecl::getEffectiveMemberwiseInitializer` for details.
1609+
class ResolveEffectiveMemberwiseInitRequest
1610+
: public SimpleRequest<ResolveEffectiveMemberwiseInitRequest,
1611+
ConstructorDecl *(NominalTypeDecl *),
1612+
CacheKind::Cached> {
1613+
public:
1614+
using SimpleRequest::SimpleRequest;
1615+
1616+
private:
1617+
friend SimpleRequest;
1618+
1619+
// Evaluation.
1620+
llvm::Expected<ConstructorDecl *> evaluate(Evaluator &evaluator,
1621+
NominalTypeDecl *decl) const;
1622+
1623+
public:
1624+
// Caching.
1625+
bool isCached() const { return true; }
1626+
};
1627+
16031628
/// Checks whether this type has a synthesized zero parameter default
16041629
/// initializer.
16051630
class HasDefaultInitRequest

‎include/swift/AST/TypeCheckerTypeIDZone.def

+2
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ SWIFT_REQUEST(TypeChecker, SPIGroupsRequest,
214214
Cached, NoLocationInfo)
215215
SWIFT_REQUEST(TypeChecker, SynthesizeMemberwiseInitRequest,
216216
ConstructorDecl *(NominalTypeDecl *), Cached, NoLocationInfo)
217+
SWIFT_REQUEST(TypeChecker, ResolveEffectiveMemberwiseInitRequest,
218+
ConstructorDecl *(NominalTypeDecl *), Cached, NoLocationInfo)
217219
SWIFT_REQUEST(TypeChecker, HasDefaultInitRequest,
218220
bool(NominalTypeDecl *), Cached, NoLocationInfo)
219221
SWIFT_REQUEST(TypeChecker, SynthesizeDefaultInitRequest,

‎include/swift/Basic/LangOptions.h

+4
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ namespace swift {
327327
/// `@differentiable` declaration attribute, etc.
328328
bool EnableExperimentalDifferentiableProgramming = false;
329329

330+
/// Whether to enable experimental `AdditiveArithmetic` derived
331+
/// conformances.
332+
bool EnableExperimentalAdditiveArithmeticDerivedConformances = false;
333+
330334
/// Enable verification when every SubstitutionMap is constructed.
331335
bool VerifyAllSubstitutionMaps = false;
332336

‎include/swift/Option/Options.td

+7-1
Original file line numberDiff line numberDiff line change
@@ -495,10 +495,16 @@ def disable_bridging_pch : Flag<["-"], "disable-bridging-pch">,
495495
HelpText<"Disable automatic generation of bridging PCH files">;
496496

497497
// Experimental feature options
498-
def enable_experimental_differentiable_programming : Flag<["-"], "enable-experimental-differentiable-programming">,
498+
def enable_experimental_differentiable_programming :
499+
Flag<["-"], "enable-experimental-differentiable-programming">,
499500
Flags<[FrontendOption]>,
500501
HelpText<"Enable experimental differentiable programming features">;
501502

503+
def enable_experimental_additive_arithmetic_derivation :
504+
Flag<["-"], "enable-experimental-additive-arithmetic-derivation">,
505+
Flags<[FrontendOption]>,
506+
HelpText<"Enable experimental 'AdditiveArithmetic' derived conformances">;
507+
502508
def enable_experimental_concise_pound_file : Flag<["-"],
503509
"enable-experimental-concise-pound-file">,
504510
Flags<[FrontendOption]>,

‎lib/AST/Decl.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -4059,6 +4059,14 @@ ConstructorDecl *NominalTypeDecl::getMemberwiseInitializer() const {
40594059
ctx.evaluator, SynthesizeMemberwiseInitRequest{mutableThis}, nullptr);
40604060
}
40614061

4062+
ConstructorDecl *NominalTypeDecl::getEffectiveMemberwiseInitializer() {
4063+
auto &ctx = getASTContext();
4064+
auto *mutableThis = const_cast<NominalTypeDecl *>(this);
4065+
return evaluateOrDefault(ctx.evaluator,
4066+
ResolveEffectiveMemberwiseInitRequest{mutableThis},
4067+
nullptr);
4068+
}
4069+
40624070
bool NominalTypeDecl::hasDefaultInitializer() const {
40634071
// Currently only structs and classes can have default initializers.
40644072
if (!isa<StructDecl>(this) && !isa<ClassDecl>(this))

‎lib/Frontend/CompilerInvocation.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,15 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
442442
if (Args.hasArg(OPT_fine_grained_dependency_include_intrafile))
443443
Opts.FineGrainedDependenciesIncludeIntrafileOnes = true;
444444

445-
if (Args.hasArg(OPT_enable_experimental_differentiable_programming))
445+
if (Args.hasArg(OPT_enable_experimental_additive_arithmetic_derivation))
446+
Opts.EnableExperimentalAdditiveArithmeticDerivedConformances = true;
447+
448+
if (Args.hasArg(OPT_enable_experimental_differentiable_programming)) {
446449
Opts.EnableExperimentalDifferentiableProgramming = true;
450+
// Differentiable programming implies `AdditiveArithmetic` derived
451+
// conformances.
452+
Opts.EnableExperimentalAdditiveArithmeticDerivedConformances = true;
453+
}
447454

448455
Opts.DebuggerSupport |= Args.hasArg(OPT_debugger_support);
449456
if (Opts.DebuggerSupport)

‎lib/IRGen/GenMeta.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4634,8 +4634,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
46344634
case KnownProtocolKind::Encodable:
46354635
case KnownProtocolKind::Decodable:
46364636
case KnownProtocolKind::StringInterpolationProtocol:
4637-
case KnownProtocolKind::Differentiable:
46384637
case KnownProtocolKind::AdditiveArithmetic:
4638+
case KnownProtocolKind::Differentiable:
46394639
return SpecialProtocol::None;
46404640
}
46414641

‎lib/Sema/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
add_swift_host_library(swiftSema STATIC
32
BuilderTransform.cpp
43
CSApply.cpp
@@ -16,6 +15,7 @@ add_swift_host_library(swiftSema STATIC
1615
ConstraintLocator.cpp
1716
ConstraintSystem.cpp
1817
DebuggerTestingTransform.cpp
18+
DerivedConformanceAdditiveArithmetic.cpp
1919
DerivedConformanceCaseIterable.cpp
2020
DerivedConformanceCodable.cpp
2121
DerivedConformanceCodingKey.cpp

‎lib/Sema/CodeSynthesis.cpp

+106
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,93 @@ SynthesizeMemberwiseInitRequest::evaluate(Evaluator &evaluator,
12011201
return ctor;
12021202
}
12031203

1204+
llvm::Expected<ConstructorDecl *>
1205+
ResolveEffectiveMemberwiseInitRequest::evaluate(Evaluator &evaluator,
1206+
NominalTypeDecl *decl) const {
1207+
// Compute the access level for the memberwise initializer. The minimum of:
1208+
// - Public, by default. This enables public nominal types to have public
1209+
// memberwise initializers.
1210+
// - The `public` default is important for synthesized member types, e.g.
1211+
// `TangentVector` structs synthesized during `Differentiable` derived
1212+
// conformances. Manually extending these types to define a public
1213+
// memberwise initializer causes a redeclaration error.
1214+
// - The minimum access level of memberwise-initialized properties in the
1215+
// nominal type declaration.
1216+
auto accessLevel = AccessLevel::Public;
1217+
for (auto *member : decl->getMembers()) {
1218+
auto *var = dyn_cast<VarDecl>(member);
1219+
if (!var ||
1220+
!var->isMemberwiseInitialized(/*preferDeclaredProperties*/ true))
1221+
continue;
1222+
accessLevel = std::min(accessLevel, var->getFormalAccess());
1223+
}
1224+
auto &ctx = decl->getASTContext();
1225+
1226+
// If a memberwise initializer exists, set its access level and return it.
1227+
if (auto *initDecl = decl->getMemberwiseInitializer()) {
1228+
initDecl->overwriteAccess(accessLevel);
1229+
return initDecl;
1230+
}
1231+
1232+
auto isEffectiveMemberwiseInitializer = [&](ConstructorDecl *initDecl) {
1233+
// Check for `nullptr`.
1234+
if (!initDecl)
1235+
return false;
1236+
// Get all stored properties, excluding `let` properties with initial
1237+
// values.
1238+
SmallVector<VarDecl *, 8> storedProperties;
1239+
for (auto *vd : decl->getStoredProperties()) {
1240+
if (vd->isLet() && vd->hasInitialValue())
1241+
continue;
1242+
storedProperties.push_back(vd);
1243+
}
1244+
// Return false if initializer does not have interface type set. It is not
1245+
// possible to determine whether it is a memberwise initializer.
1246+
if (!initDecl->hasInterfaceType())
1247+
return false;
1248+
auto initDeclType =
1249+
initDecl->getMethodInterfaceType()->getAs<AnyFunctionType>();
1250+
// Return false if initializer does not have a valid interface type.
1251+
if (!initDeclType)
1252+
return false;
1253+
// Return false if stored property count does not have parameter count.
1254+
if (storedProperties.size() != initDeclType->getNumParams())
1255+
return false;
1256+
// Return true if all stored property types/names match initializer
1257+
// parameter types/labels.
1258+
return llvm::all_of(
1259+
llvm::zip(storedProperties, initDeclType->getParams()),
1260+
[&](std::tuple<VarDecl *, AnyFunctionType::Param> pair) {
1261+
auto *storedProp = std::get<0>(pair);
1262+
auto param = std::get<1>(pair);
1263+
return storedProp->getInterfaceType()->isEqual(
1264+
param.getPlainType()) &&
1265+
storedProp->getName() == param.getLabel();
1266+
});
1267+
};
1268+
1269+
// Otherwise, look for a user-defined effective memberwise initializer.
1270+
ConstructorDecl *memberwiseInitDecl = nullptr;
1271+
auto initDecls = decl->lookupDirect(DeclBaseName::createConstructor());
1272+
for (auto *decl : initDecls) {
1273+
auto *initDecl = dyn_cast<ConstructorDecl>(decl);
1274+
if (!isEffectiveMemberwiseInitializer(initDecl))
1275+
continue;
1276+
assert(!memberwiseInitDecl && "Memberwise initializer already found");
1277+
memberwiseInitDecl = initDecl;
1278+
}
1279+
1280+
// Otherwise, create a memberwise initializer, set its access level, and
1281+
// return it.
1282+
if (!memberwiseInitDecl) {
1283+
memberwiseInitDecl = createImplicitConstructor(
1284+
decl, ImplicitConstructorKind::Memberwise, ctx);
1285+
memberwiseInitDecl->overwriteAccess(accessLevel);
1286+
decl->addMember(memberwiseInitDecl);
1287+
}
1288+
return memberwiseInitDecl;
1289+
}
1290+
12041291
llvm::Expected<bool>
12051292
HasDefaultInitRequest::evaluate(Evaluator &evaluator,
12061293
NominalTypeDecl *decl) const {
@@ -1263,3 +1350,22 @@ SynthesizeDefaultInitRequest::evaluate(Evaluator &evaluator,
12631350
ctor->setBodySynthesizer(synthesizeSingleReturnFunctionBody);
12641351
return ctor;
12651352
}
1353+
1354+
ValueDecl *swift::getProtocolRequirement(ProtocolDecl *protocol,
1355+
Identifier name) {
1356+
auto lookup = protocol->lookupDirect(name);
1357+
// Erase declarations that are not protocol requirements.
1358+
// This is important for removing default implementations of the same name.
1359+
llvm::erase_if(lookup, [](ValueDecl *v) {
1360+
return !isa<ProtocolDecl>(v->getDeclContext()) ||
1361+
!v->isProtocolRequirement();
1362+
});
1363+
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
1364+
return lookup.front();
1365+
}
1366+
1367+
bool swift::hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
1368+
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
1369+
return v->isLet() && v->hasInitialValue();
1370+
});
1371+
}

‎lib/Sema/CodeSynthesis.h

+6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ Expr *buildSelfReference(VarDecl *selfDecl,
6060
Expr *buildArgumentForwardingExpr(ArrayRef<ParamDecl*> params,
6161
ASTContext &ctx);
6262

63+
/// Returns the protocol requirement with the specified name.
64+
ValueDecl *getProtocolRequirement(ProtocolDecl *protocol, Identifier name);
65+
66+
// Returns true if given nominal type has a `let` stored with an initial value.
67+
bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal);
68+
6369
} // end namespace swift
6470

6571
#endif

0 commit comments

Comments
 (0)