Skip to content

Commit 1a4ea67

Browse files
committed
[AutoDiff] Plumb witness derivative generic signatures through SILGen.
When SILGenWitnessTable creates a decl ref for the witness of a derivative function requirement, it is using the requirement's derivative generic signature in the resulting witness decl ref. This is wrong because the witness may have a different derivative generic signature than the requirement, leading to a crash. This bug was never discovered because GSB's dark magic made it "just work", until requirement machine. The fix is to store the matched witness derivative generic signature in `Witness` during type checking, and during witness table generation, use the witness' generic signature to create a witness decl ref. Resolves rdar://84716758, rdar://84213107 and rdar://84987079.
1 parent b37f77f commit 1a4ea67

8 files changed

+117
-52
lines changed

include/swift/AST/Witness.h

+17-2
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ class Witness {
9494
ConcreteDeclRef declRef;
9595
GenericEnvironment *syntheticEnvironment;
9696
SubstitutionMap reqToSyntheticEnvSubs;
97+
/// The derivative generic signature, when the requirement is a derivative
98+
/// function.
99+
GenericSignature derivativeGenSig;
97100
};
98101

99102
llvm::PointerUnion<ValueDecl *, StoredWitness *> storage;
@@ -124,7 +127,8 @@ class Witness {
124127
static Witness forDeserialized(ValueDecl *decl,
125128
SubstitutionMap substitutions) {
126129
// TODO: It's probably a good idea to have a separate 'deserialized' bit.
127-
return Witness(decl, substitutions, nullptr, SubstitutionMap());
130+
return Witness(
131+
decl, substitutions, nullptr, SubstitutionMap(), CanGenericSignature());
128132
}
129133

130134
/// Create a witness that requires substitutions.
@@ -138,10 +142,14 @@ class Witness {
138142
///
139143
/// \param reqToSyntheticEnvSubs The mapping from the interface types of the
140144
/// requirement into the interface types of the synthetic environment.
145+
///
146+
/// \param derivativeGenSig The derivative generic signature, when the
147+
/// requirement is a derivative function.
141148
Witness(ValueDecl *decl,
142149
SubstitutionMap substitutions,
143150
GenericEnvironment *syntheticEnv,
144-
SubstitutionMap reqToSyntheticEnvSubs);
151+
SubstitutionMap reqToSyntheticEnvSubs,
152+
GenericSignature derivativeGenSig);
145153

146154
/// Retrieve the witness declaration reference, which includes the
147155
/// substitutions needed to use the witness from the synthetic environment
@@ -183,6 +191,13 @@ class Witness {
183191
return {};
184192
}
185193

194+
/// Retrieve the derivative generic signature.
195+
GenericSignature getDerivativeGenericSignature() const {
196+
if (auto *storedWitness = storage.dyn_cast<StoredWitness *>())
197+
return storedWitness->derivativeGenSig;
198+
return GenericSignature();
199+
}
200+
186201
SWIFT_DEBUG_DUMP;
187202

188203
void dump(llvm::raw_ostream &out) const;

lib/AST/ProtocolConformance.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ using namespace swift;
4242

4343
Witness::Witness(ValueDecl *decl, SubstitutionMap substitutions,
4444
GenericEnvironment *syntheticEnv,
45-
SubstitutionMap reqToSynthesizedEnvSubs) {
45+
SubstitutionMap reqToSynthesizedEnvSubs,
46+
GenericSignature derivativeGenSig) {
4647
if (!syntheticEnv && substitutions.empty() &&
4748
reqToSynthesizedEnvSubs.empty()) {
4849
storage = decl;
@@ -53,7 +54,8 @@ Witness::Witness(ValueDecl *decl, SubstitutionMap substitutions,
5354
auto declRef = ConcreteDeclRef(decl, substitutions);
5455
auto storedMem = ctx.Allocate(sizeof(StoredWitness), alignof(StoredWitness));
5556
auto stored = new (storedMem) StoredWitness{declRef, syntheticEnv,
56-
reqToSynthesizedEnvSubs};
57+
reqToSynthesizedEnvSubs,
58+
derivativeGenSig};
5759

5860
storage = stored;
5961
}
@@ -892,7 +894,8 @@ NormalProtocolConformance::getWitnessUncached(ValueDecl *requirement) const {
892894
}
893895

894896
Witness SelfProtocolConformance::getWitness(ValueDecl *requirement) const {
895-
return Witness(requirement, SubstitutionMap(), nullptr, SubstitutionMap());
897+
return Witness(requirement, SubstitutionMap(), nullptr, SubstitutionMap(),
898+
GenericSignature());
896899
}
897900

898901
ConcreteDeclRef

lib/SIL/IR/TypeLowering.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -2631,7 +2631,8 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
26312631
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
26322632
auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
26332633
derivativeId->getParameterIndices(), derivativeId->getKind(),
2634-
LookUpConformanceInModule(&M));
2634+
LookUpConformanceInModule(&M),
2635+
derivativeId->getDerivativeGenericSignature());
26352636
return cast<AnyFunctionType>(derivativeFnTy->getCanonicalType());
26362637
}
26372638

lib/SILGen/SILGenType.cpp

+18-2
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
416416
if (!reqAccessor) {
417417
if (auto witness = asDerived().getWitness(reqDecl)) {
418418
return addMethodImplementation(
419-
requirementRef, requirementRef.withDecl(witness.getDecl()),
419+
requirementRef, getWitnessRef(requirementRef, witness),
420420
witness);
421421
}
422422

@@ -444,7 +444,8 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
444444
witnessStorage->getSynthesizedAccessor(reqAccessor->getAccessorKind());
445445

446446
return addMethodImplementation(
447-
requirementRef, requirementRef.withDecl(witnessAccessor), witness);
447+
requirementRef, getWitnessRef(requirementRef, witnessAccessor),
448+
witness);
448449
}
449450

450451
private:
@@ -458,6 +459,21 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
458459
asDerived().addMethodImplementation(requirementRef, witnessRef,
459460
isFree, witness);
460461
}
462+
463+
SILDeclRef getWitnessRef(SILDeclRef requirementRef, Witness witness) {
464+
auto witnessRef = requirementRef.withDecl(witness.getDecl());
465+
// If the requirement/witness is a derivative function, we need to
466+
// substitute the witness's derivative generic signature in its derivative
467+
// function identifier.
468+
if (requirementRef.isAutoDiffDerivativeFunction()) {
469+
auto *reqrRerivativeId = requirementRef.getDerivativeFunctionIdentifier();
470+
auto *witnessDerivativeId = AutoDiffDerivativeFunctionIdentifier::get(
471+
reqrRerivativeId->getKind(), reqrRerivativeId->getParameterIndices(),
472+
witness.getDerivativeGenericSignature(), witnessRef.getASTContext());
473+
witnessRef = witnessRef.asAutoDiffDerivativeFunction(witnessDerivativeId);
474+
}
475+
return witnessRef;
476+
}
461477
};
462478

463479
static IsSerialized_t isConformanceSerialized(RootProtocolConformance *conf) {

lib/Sema/TypeCheckProtocol.cpp

+19-16
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ struct swift::RequirementCheck {
9898
swift::Witness RequirementMatch::getWitness(ASTContext &ctx) const {
9999
auto syntheticEnv = ReqEnv->getSyntheticEnvironment();
100100
return swift::Witness(this->Witness, WitnessSubstitutions,
101-
syntheticEnv, ReqEnv->getRequirementToSyntheticMap());
101+
syntheticEnv, ReqEnv->getRequirementToSyntheticMap(),
102+
DerivativeGenSig);
102103
}
103104

104105
AssociatedTypeDecl *
@@ -306,17 +307,16 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
306307
/// Given a witness, a requirement, and an existing `RequirementMatch` result,
307308
/// check if the requirement's `@differentiable` attributes are met by the
308309
/// witness.
309-
/// - If requirement's `@differentiable` attributes are met, or if `result` is
310-
/// not viable, returns `result`.
310+
/// - If `result` is not viable, do nothing.
311+
/// - If requirement's `@differentiable` attributes are met, update `result`
312+
/// with the matched derivative generic signature.
311313
/// - Otherwise, returns a "missing `@differentiable` attribute"
312314
/// `RequirementMatch`.
313-
// Note: the `result` argument is only necessary for using
314-
// `RequirementMatch::WitnessSubstitutions`.
315-
static RequirementMatch
315+
static void
316316
matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
317-
ValueDecl *witness, RequirementMatch result) {
317+
ValueDecl *witness, RequirementMatch &result) {
318318
if (!result.isViable())
319-
return result;
319+
return;
320320

321321
// Get the requirement and witness attributes.
322322
const auto &reqAttrs = req->getAttrs();
@@ -377,6 +377,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
377377
if (witnessConfig.parameterIndices ==
378378
reqDiffAttr->getParameterIndices()) {
379379
foundExactConfig = true;
380+
// Store the matched witness derivative generic signature.
381+
result.DerivativeGenSig = witnessConfig.derivativeGenericSignature;
380382
break;
381383
}
382384
if (witnessConfig.parameterIndices->isSupersetOf(
@@ -407,12 +409,12 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
407409
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
408410
// appear if associated type inference is involved.
409411
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
410-
return RequirementMatch(
412+
result = RequirementMatch(
411413
getStandinForAccessor(vdWitness, AccessorKind::Get),
412414
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
413415
} else {
414-
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
415-
reqDiffAttr);
416+
result = RequirementMatch(
417+
witness, MatchKind::MissingDifferentiableAttr, reqDiffAttr);
416418
}
417419
}
418420

@@ -461,6 +463,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
461463
witnessAFD->addDerivativeFunctionConfiguration(
462464
{newAttr->getParameterIndices(), resultIndices,
463465
newAttr->getDerivativeGenericSignature()});
466+
// Store the witness derivative generic signature.
467+
result.DerivativeGenSig = newAttr->getDerivativeGenericSignature();
464468
}
465469
}
466470
if (!success) {
@@ -475,17 +479,16 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
475479
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
476480
// appear if associated type inference is involved.
477481
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
478-
return RequirementMatch(
482+
result = RequirementMatch(
479483
getStandinForAccessor(vdWitness, AccessorKind::Get),
480484
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
481485
} else {
482-
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
483-
reqDiffAttr);
486+
result = RequirementMatch(
487+
witness, MatchKind::MissingDifferentiableAttr, reqDiffAttr);
484488
}
485489
}
486490
}
487491
}
488-
return result;
489492
}
490493

491494
/// A property or subscript witness must have the same or fewer
@@ -817,7 +820,7 @@ swift::matchWitness(
817820
auto result = finalize(anyRenaming, optionalAdjustments);
818821
// Check if the requirement's `@differentiable` attributes are satisfied by
819822
// the witness.
820-
result = matchWitnessDifferentiableAttr(dc, req, witness, result);
823+
matchWitnessDifferentiableAttr(dc, req, witness, result);
821824
return result;
822825
}
823826

lib/Sema/TypeCheckProtocol.h

+11-4
Original file line numberDiff line numberDiff line change
@@ -435,23 +435,27 @@ struct RequirementMatch {
435435
RequirementMatch(ValueDecl *witness, MatchKind kind,
436436
Type witnessType,
437437
Optional<RequirementEnvironment> env = None,
438-
ArrayRef<OptionalAdjustment> optionalAdjustments = {})
438+
ArrayRef<OptionalAdjustment> optionalAdjustments = {},
439+
GenericSignature derivativeGenSig = GenericSignature())
439440
: Witness(witness), Kind(kind), WitnessType(witnessType),
440441
ReqEnv(std::move(env)),
441442
OptionalAdjustments(optionalAdjustments.begin(),
442-
optionalAdjustments.end())
443+
optionalAdjustments.end()),
444+
DerivativeGenSig(derivativeGenSig)
443445
{
444446
assert(hasWitnessType() == !witnessType.isNull() &&
445447
"Should (or should not) have witness type");
446448
}
447449

448450
RequirementMatch(ValueDecl *witness, MatchKind kind, Requirement requirement,
449451
Optional<RequirementEnvironment> env = None,
450-
ArrayRef<OptionalAdjustment> optionalAdjustments = {})
452+
ArrayRef<OptionalAdjustment> optionalAdjustments = {},
453+
GenericSignature derivativeGenSig = GenericSignature())
451454
: Witness(witness), Kind(kind), WitnessType(requirement.getFirstType()),
452455
MissingRequirement(requirement), ReqEnv(std::move(env)),
453456
OptionalAdjustments(optionalAdjustments.begin(),
454-
optionalAdjustments.end()) {
457+
optionalAdjustments.end()),
458+
DerivativeGenSig(derivativeGenSig) {
455459
assert(hasWitnessType() && hasRequirement() &&
456460
"Should have witness type and requirement");
457461
}
@@ -481,6 +485,9 @@ struct RequirementMatch {
481485
/// environment.
482486
SubstitutionMap WitnessSubstitutions;
483487

488+
/// The matched derivative generic signature.
489+
GenericSignature DerivativeGenSig;
490+
484491
/// Determine whether this match is viable.
485492
bool isViable() const {
486493
switch(Kind) {

test/AutoDiff/compiler_crashers_fixed/rdar83894546-missing-generic-requirement.swift

-24
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: %target-build-swift %s
2+
3+
import _Differentiation
4+
5+
public protocol Layer: Differentiable {
6+
associatedtype Input: Differentiable
7+
associatedtype Output: Differentiable
8+
@differentiable(reverse, wrt: (self, input))
9+
@differentiable(reverse, wrt: input)
10+
func callAsFunction(_ input: Input) -> Output
11+
}
12+
13+
// Test for explicitly declared `@differentiable` attributes.
14+
public class Function1<Input: Differentiable, Output: Differentiable>: Layer {
15+
@noDerivative public let body: @differentiable(reverse) (Input) -> Output
16+
17+
public init(_ body: @escaping @differentiable(reverse) (Input) -> Output) {
18+
self.body = body
19+
}
20+
21+
@differentiable(reverse, wrt: (self, input))
22+
@differentiable(reverse, wrt: input)
23+
public func callAsFunction(_ input: Input) -> Output {
24+
body(input)
25+
}
26+
27+
@differentiable(reverse, wrt: x where T: Differentiable)
28+
public func foo<T>(x: T) -> T {
29+
x
30+
}
31+
}
32+
33+
// Test for implicitly generated `@differentiable` attributes.
34+
class Function2<Input: Differentiable, Output: Differentiable>: Layer {
35+
@noDerivative let body: @differentiable(reverse) (Input) -> Output
36+
37+
init(_ body: @escaping @differentiable(reverse) (Input) -> Output) {
38+
self.body = body
39+
}
40+
41+
func callAsFunction(_ input: Input) -> Output {
42+
body(input)
43+
}
44+
}

0 commit comments

Comments
 (0)