@@ -935,43 +935,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
935
935
emitDifferentiabilityWitnessesForFunction (constant, F);
936
936
}
937
937
938
- // / Returns the SIL differentiability witness generic signature given the
939
- // / original declaration's generic signature and the derivative generic
940
- // / signature.
941
- // /
942
- // / In general, the differentiability witness generic signature is equal to the
943
- // / derivative generic signature.
944
- // /
945
- // / Edge case, if two conditions are satisfied:
946
- // / 1. The derivative generic signature is equal to the original generic
947
- // / signature.
948
- // / 2. The derivative generic signature has *all concrete* generic parameters
949
- // / (i.e. all generic parameters are bound to concrete types via same-type
950
- // / requirements).
951
- // /
952
- // / Then the differentiability witness generic signature is `nullptr`.
953
- // /
954
- // / Both the original and derivative declarations are lowered to SIL functions
955
- // / with a fully concrete type and no generic signature, so the
956
- // / differentiability witness should similarly have no generic signature.
957
- static GenericSignature
958
- getDifferentiabilityWitnessGenericSignature (GenericSignature origGenSig,
959
- GenericSignature derivativeGenSig) {
960
- // If there is no derivative generic signature, return the original generic
961
- // signature.
962
- if (!derivativeGenSig)
963
- return origGenSig;
964
- // If derivative generic signature has all concrete generic parameters and is
965
- // equal to the original generic signature, return `nullptr`.
966
- auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature ();
967
- auto origCanGenSig = origGenSig.getCanonicalSignature ();
968
- if (origCanGenSig == derivativeCanGenSig &&
969
- derivativeCanGenSig->areAllParamsConcrete ())
970
- return GenericSignature ();
971
- // Otherwise, return the derivative generic signature.
972
- return derivativeGenSig;
973
- }
974
-
975
938
void SILGenModule::emitDifferentiabilityWitnessesForFunction (
976
939
SILDeclRef constant, SILFunction *F) {
977
940
// Visit `@derivative` attributes and generate SIL differentiability
@@ -992,9 +955,10 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
992
955
diffAttr->getDerivativeGenericSignature ()) &&
993
956
" Type-checking should resolve derivative generic signatures for "
994
957
" all original SIL functions with generic signatures" );
995
- auto witnessGenSig = getDifferentiabilityWitnessGenericSignature (
996
- AFD->getGenericSignature (),
997
- diffAttr->getDerivativeGenericSignature ());
958
+ auto witnessGenSig =
959
+ autodiff::getDifferentiabilityWitnessGenericSignature (
960
+ AFD->getGenericSignature (),
961
+ diffAttr->getDerivativeGenericSignature ());
998
962
AutoDiffConfig config (diffAttr->getParameterIndices (), resultIndices,
999
963
witnessGenSig);
1000
964
emitDifferentiabilityWitness (AFD, F, config, /* jvp*/ nullptr ,
@@ -1015,8 +979,9 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
1015
979
auto origDeclRef =
1016
980
SILDeclRef (origAFD).asForeign (requiresForeignEntryPoint (origAFD));
1017
981
auto *origFn = getFunction (origDeclRef, NotForDefinition);
1018
- auto witnessGenSig = getDifferentiabilityWitnessGenericSignature (
1019
- origAFD->getGenericSignature (), AFD->getGenericSignature ());
982
+ auto witnessGenSig =
983
+ autodiff::getDifferentiabilityWitnessGenericSignature (
984
+ origAFD->getGenericSignature (), AFD->getGenericSignature ());
1020
985
auto *resultIndices = IndexSubset::get (getASTContext (), 1 , {0 });
1021
986
AutoDiffConfig config (derivAttr->getParameterIndices (), resultIndices,
1022
987
witnessGenSig);
0 commit comments