@@ -98,7 +98,8 @@ struct swift::RequirementCheck {
98
98
swift::Witness RequirementMatch::getWitness(ASTContext &ctx) const {
99
99
auto syntheticEnv = ReqEnv->getSyntheticEnvironment();
100
100
return swift::Witness(this->Witness, WitnessSubstitutions,
101
- syntheticEnv, ReqEnv->getRequirementToSyntheticMap());
101
+ syntheticEnv, ReqEnv->getRequirementToSyntheticMap(),
102
+ DerivativeGenSig);
102
103
}
103
104
104
105
AssociatedTypeDecl *
@@ -306,17 +307,16 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
306
307
/// Given a witness, a requirement, and an existing `RequirementMatch` result,
307
308
/// check if the requirement's `@differentiable` attributes are met by the
308
309
/// witness.
309
- /// - If requirement's `@differentiable` attributes are met, or if `result` is
310
- /// not viable, returns `result`.
310
+ /// - If `result` is not viable, do nothing.
311
+ /// - If requirement's `@differentiable` attributes are met, update `result`
312
+ /// with the matched derivative generic signature.
311
313
/// - Otherwise, returns a "missing `@differentiable` attribute"
312
314
/// `RequirementMatch`.
313
- // Note: the `result` argument is only necessary for using
314
- // `RequirementMatch::WitnessSubstitutions`.
315
- static RequirementMatch
315
+ static void
316
316
matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
317
- ValueDecl *witness, RequirementMatch result) {
317
+ ValueDecl *witness, RequirementMatch & result) {
318
318
if (!result.isViable())
319
- return result ;
319
+ return;
320
320
321
321
// Get the requirement and witness attributes.
322
322
const auto &reqAttrs = req->getAttrs();
@@ -377,6 +377,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
377
377
if (witnessConfig.parameterIndices ==
378
378
reqDiffAttr->getParameterIndices()) {
379
379
foundExactConfig = true;
380
+ // Store the matched witness derivative generic signature.
381
+ result.DerivativeGenSig = witnessConfig.derivativeGenericSignature;
380
382
break;
381
383
}
382
384
if (witnessConfig.parameterIndices->isSupersetOf(
@@ -407,12 +409,12 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
407
409
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
408
410
// appear if associated type inference is involved.
409
411
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
410
- return RequirementMatch(
412
+ result = RequirementMatch(
411
413
getStandinForAccessor(vdWitness, AccessorKind::Get),
412
414
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
413
415
} else {
414
- return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
415
- reqDiffAttr);
416
+ result = RequirementMatch(
417
+ witness, MatchKind::MissingDifferentiableAttr, reqDiffAttr);
416
418
}
417
419
}
418
420
@@ -461,6 +463,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
461
463
witnessAFD->addDerivativeFunctionConfiguration(
462
464
{newAttr->getParameterIndices(), resultIndices,
463
465
newAttr->getDerivativeGenericSignature()});
466
+ // Store the witness derivative generic signature.
467
+ result.DerivativeGenSig = newAttr->getDerivativeGenericSignature();
464
468
}
465
469
}
466
470
if (!success) {
@@ -475,17 +479,16 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
475
479
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
476
480
// appear if associated type inference is involved.
477
481
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
478
- return RequirementMatch(
482
+ result = RequirementMatch(
479
483
getStandinForAccessor(vdWitness, AccessorKind::Get),
480
484
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
481
485
} else {
482
- return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
483
- reqDiffAttr);
486
+ result = RequirementMatch(
487
+ witness, MatchKind::MissingDifferentiableAttr, reqDiffAttr);
484
488
}
485
489
}
486
490
}
487
491
}
488
- return result;
489
492
}
490
493
491
494
/// A property or subscript witness must have the same or fewer
@@ -817,7 +820,7 @@ swift::matchWitness(
817
820
auto result = finalize(anyRenaming, optionalAdjustments);
818
821
// Check if the requirement's `@differentiable` attributes are satisfied by
819
822
// the witness.
820
- result = matchWitnessDifferentiableAttr(dc, req, witness, result);
823
+ matchWitnessDifferentiableAttr(dc, req, witness, result);
821
824
return result;
822
825
}
823
826
0 commit comments