Skip to content

Commit 7191c9c

Browse files
author
Marc Rasi
committed
[AutoDiff] remove all-concrete gen sig from more places
1 parent fb74ba2 commit 7191c9c

File tree

8 files changed

+89
-50
lines changed

8 files changed

+89
-50
lines changed

include/swift/AST/AutoDiff.h

+23
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,29 @@ bool getBuiltinDifferentiableOrLinearFunctionConfig(
649649
bool getBuiltinDifferentiableOrLinearFunctionConfig(
650650
StringRef operationName, unsigned &arity, bool &throws);
651651

652+
/// Returns the SIL differentiability witness generic signature given the
653+
/// original declaration's generic signature and the derivative generic
654+
/// signature.
655+
///
656+
/// In general, the differentiability witness generic signature is equal to the
657+
/// derivative generic signature.
658+
///
659+
/// Edge case, if two conditions are satisfied:
660+
/// 1. The derivative generic signature is equal to the original generic
661+
/// signature.
662+
/// 2. The derivative generic signature has *all concrete* generic parameters
663+
/// (i.e. all generic parameters are bound to concrete types via same-type
664+
/// requirements).
665+
///
666+
/// Then the differentiability witness generic signature is `nullptr`.
667+
///
668+
/// Both the original and derivative declarations are lowered to SIL functions
669+
/// with a fully concrete type and no generic signature, so the
670+
/// differentiability witness should similarly have no generic signature.
671+
GenericSignature
672+
getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
673+
GenericSignature derivativeGenSig);
674+
652675
} // end namespace autodiff
653676

654677
} // end namespace swift

lib/AST/AutoDiff.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,23 @@ bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
372372
return operationName.empty();
373373
}
374374

375+
GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature(
376+
GenericSignature origGenSig, GenericSignature derivativeGenSig) {
377+
// If there is no derivative generic signature, return the original generic
378+
// signature.
379+
if (!derivativeGenSig)
380+
return origGenSig;
381+
// If derivative generic signature has all concrete generic parameters and is
382+
// equal to the original generic signature, return `nullptr`.
383+
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
384+
auto origCanGenSig = origGenSig.getCanonicalSignature();
385+
if (origCanGenSig == derivativeCanGenSig &&
386+
derivativeCanGenSig->areAllParamsConcrete())
387+
return GenericSignature();
388+
// Otherwise, return the derivative generic signature.
389+
return derivativeGenSig;
390+
}
391+
375392
Type TangentSpace::getType() const {
376393
switch (kind) {
377394
case Kind::TangentVector:

lib/SILGen/SILGen.cpp

+7-42
Original file line numberDiff line numberDiff line change
@@ -935,43 +935,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
935935
emitDifferentiabilityWitnessesForFunction(constant, F);
936936
}
937937

938-
/// Returns the SIL differentiability witness generic signature given the
939-
/// original declaration's generic signature and the derivative generic
940-
/// signature.
941-
///
942-
/// In general, the differentiability witness generic signature is equal to the
943-
/// derivative generic signature.
944-
///
945-
/// Edge case, if two conditions are satisfied:
946-
/// 1. The derivative generic signature is equal to the original generic
947-
/// signature.
948-
/// 2. The derivative generic signature has *all concrete* generic parameters
949-
/// (i.e. all generic parameters are bound to concrete types via same-type
950-
/// requirements).
951-
///
952-
/// Then the differentiability witness generic signature is `nullptr`.
953-
///
954-
/// Both the original and derivative declarations are lowered to SIL functions
955-
/// with a fully concrete type and no generic signature, so the
956-
/// differentiability witness should similarly have no generic signature.
957-
static GenericSignature
958-
getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
959-
GenericSignature derivativeGenSig) {
960-
// If there is no derivative generic signature, return the original generic
961-
// signature.
962-
if (!derivativeGenSig)
963-
return origGenSig;
964-
// If derivative generic signature has all concrete generic parameters and is
965-
// equal to the original generic signature, return `nullptr`.
966-
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
967-
auto origCanGenSig = origGenSig.getCanonicalSignature();
968-
if (origCanGenSig == derivativeCanGenSig &&
969-
derivativeCanGenSig->areAllParamsConcrete())
970-
return GenericSignature();
971-
// Otherwise, return the derivative generic signature.
972-
return derivativeGenSig;
973-
}
974-
975938
void SILGenModule::emitDifferentiabilityWitnessesForFunction(
976939
SILDeclRef constant, SILFunction *F) {
977940
// Visit `@derivative` attributes and generate SIL differentiability
@@ -992,9 +955,10 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
992955
diffAttr->getDerivativeGenericSignature()) &&
993956
"Type-checking should resolve derivative generic signatures for "
994957
"all original SIL functions with generic signatures");
995-
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
996-
AFD->getGenericSignature(),
997-
diffAttr->getDerivativeGenericSignature());
958+
auto witnessGenSig =
959+
autodiff::getDifferentiabilityWitnessGenericSignature(
960+
AFD->getGenericSignature(),
961+
diffAttr->getDerivativeGenericSignature());
998962
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
999963
witnessGenSig);
1000964
emitDifferentiabilityWitness(AFD, F, config, /*jvp*/ nullptr,
@@ -1015,8 +979,9 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
1015979
auto origDeclRef =
1016980
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
1017981
auto *origFn = getFunction(origDeclRef, NotForDefinition);
1018-
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
1019-
origAFD->getGenericSignature(), AFD->getGenericSignature());
982+
auto witnessGenSig =
983+
autodiff::getDifferentiabilityWitnessGenericSignature(
984+
origAFD->getGenericSignature(), AFD->getGenericSignature());
1020985
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
1021986
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
1022987
witnessGenSig);

lib/SILOptimizer/Differentiation/Common.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,11 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
452452
silParameterIndices->getNumIndices() <
453453
minimalConfig->parameterIndices->getNumIndices())) {
454454
minimalASTParameterIndices = config.parameterIndices;
455-
minimalConfig = AutoDiffConfig(silParameterIndices, config.resultIndices,
456-
config.derivativeGenericSignature);
455+
minimalConfig =
456+
AutoDiffConfig(silParameterIndices, config.resultIndices,
457+
autodiff::getDifferentiabilityWitnessGenericSignature(
458+
original->getGenericSignature(),
459+
config.derivativeGenericSignature));
457460
}
458461
}
459462
return minimalConfig;

lib/TBDGen/TBDGen.cpp

+11-5
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,10 @@ void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
530530
config.parameterIndices,
531531
original->getInterfaceType()->castTo<AnyFunctionType>());
532532
Mangle::ASTMangler mangler;
533-
AutoDiffConfig silConfig{loweredParamIndices, config.resultIndices,
534-
config.derivativeGenericSignature};
533+
AutoDiffConfig silConfig{
534+
loweredParamIndices, config.resultIndices,
535+
autodiff::getDifferentiabilityWitnessGenericSignature(
536+
original->getGenericSignature(), config.derivativeGenericSignature)};
535537
std::string linearMapName =
536538
mangler.mangleAutoDiffLinearMapHelper(declRef.mangle(), kind, silConfig);
537539
addSymbol(linearMapName);
@@ -542,7 +544,9 @@ void TBDGenVisitor::addAutoDiffDerivativeFunction(
542544
GenericSignature derivativeGenericSignature,
543545
AutoDiffDerivativeFunctionKind kind) {
544546
auto *assocFnId = AutoDiffDerivativeFunctionIdentifier::get(
545-
kind, parameterIndices, derivativeGenericSignature,
547+
kind, parameterIndices,
548+
autodiff::getDifferentiabilityWitnessGenericSignature(
549+
original->getGenericSignature(), derivativeGenericSignature),
546550
original->getASTContext());
547551
auto declRef =
548552
SILDeclRef(original).asForeign(requiresForeignEntryPoint(original));
@@ -569,8 +573,10 @@ void TBDGenVisitor::addDifferentiabilityWitness(
569573
original->getInterfaceType()->castTo<AnyFunctionType>());
570574

571575
auto originalMangledName = declRef.mangle();
572-
AutoDiffConfig config{silParamIndices, resultIndices,
573-
derivativeGenericSignature};
576+
AutoDiffConfig config{
577+
silParamIndices, resultIndices,
578+
autodiff::getDifferentiabilityWitnessGenericSignature(
579+
original->getGenericSignature(), derivativeGenericSignature)};
574580
SILDifferentiabilityWitnessKey key(originalMangledName, config);
575581

576582
Mangle::ASTMangler mangler;

test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift

+11
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,14 @@ class Class: Differentiable {
4141
set {}
4242
}
4343
}
44+
45+
struct S: Differentiable {
46+
var value: Float
47+
}
48+
49+
extension Array where Element == S {
50+
@differentiable
51+
func sum() -> Float {
52+
return 0
53+
}
54+
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift

+7
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,10 @@ func classRequirementSetters(_ x: inout Class, _ newValue: Float) {
5757
x.property = newValue
5858
x[] = newValue
5959
}
60+
61+
// Test cross-file lookup of a derivative function with all-concrete derivative generic signature.
62+
@differentiable
63+
func allConcreteDerivativeGenericSignature(_ a: [S]) -> Float {
64+
// No error expected.
65+
return a.sum()
66+
}

test/AutoDiff/TBD/derivative_symbols.swift

+8-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public func topLevelDerivative<T: Differentiable>(_ x: T) -> (
1919
fatalError()
2020
}
2121

22-
struct Struct: Differentiable {
22+
public struct Struct: Differentiable {
2323
var stored: Float
2424

2525
// Test property.
@@ -54,3 +54,10 @@ struct Struct: Differentiable {
5454
fatalError()
5555
}
5656
}
57+
58+
extension Array where Element == Struct {
59+
@differentiable
60+
public func sum() -> Float {
61+
return 0
62+
}
63+
}

0 commit comments

Comments
 (0)