Skip to content

Commit 82886bf

Browse files
committed
[AutoDiff] Fix mangling of '@noDerivative' in function types.
`@noDerivative` was not mangled in function types, and was resolved incorrectly when there's an ownership specifier. It is fixed by this patch with the following changes: * Add `NoDerivative` demangle node represented by a `k` operator. ``` list-type ::= type identifier? 'k'? 'z'? 'h'? 'n'? 'd'? // type with optional label, '@noDerivative', inout convention, shared convention, owned convention, and variadic specifier ``` * Fix `NoDerivative`'s overflown offset in `ParameterTypeFlags` (`7` -> `6`). * In type decoder and type resolver where attributed type nodes are processed, add support for nested attributed nodes, e.g. `inout @noDerivative T`. * Add `TypeResolverContext::InoutFunctionInput` so that when we resolve an `inout @noDerivative T` parameter, the `@noDerivative T` checking logic won't get a `TypeResolverContext::None` set by the caller. Resolves rdar://75916833.
1 parent 7b5b474 commit 82886bf

File tree

18 files changed

+144
-52
lines changed

18 files changed

+144
-52
lines changed

docs/ABI/Mangling.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ Types
582582
type-list ::= empty-list
583583

584584
// FIXME: Consider replacing 'h' with a two-char code
585-
list-type ::= type identifier? 'z'? 'h'? 'n'? 'd'? // type with optional label, inout convention, shared convention, owned convention, and variadic specifier
585+
list-type ::= type identifier? 'k'? 'z'? 'h'? 'n'? 'd'? // type with optional label, '@noDerivative', inout convention, shared convention, owned convention, and variadic specifier
586586

587587
METATYPE-REPR ::= 't' // Thin metatype representation
588588
METATYPE-REPR ::= 'T' // Thick metatype representation
@@ -666,7 +666,7 @@ mangled in to disambiguate.
666666
COROUTINE-KIND ::= 'A' // yield-once coroutine
667667
COROUTINE-KIND ::= 'G' // yield-many coroutine
668668

669-
SENDABLE ::= 'h' // @Sendable
669+
SENDABLE ::= 'h' // @Sendable
670670
ASYNC ::= 'H' // @async
671671

672672
PARAM-CONVENTION ::= 'i' // indirect in

include/swift/AST/Attr.h

-2
Original file line numberDiff line numberDiff line change
@@ -2324,8 +2324,6 @@ class TypeAttributes {
23242324

23252325
Optional<Convention> ConventionArguments;
23262326

2327-
// Indicates whether the type's '@differentiable' attribute has a 'linear'
2328-
// argument.
23292327
DifferentiabilityKind differentiabilityKind =
23302328
DifferentiabilityKind::NonDifferentiable;
23312329

include/swift/AST/Types.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1929,7 +1929,7 @@ class ParameterTypeFlags {
19291929
NonEphemeral = 1 << 2,
19301930
OwnershipShift = 3,
19311931
Ownership = 7 << OwnershipShift,
1932-
NoDerivative = 1 << 7,
1932+
NoDerivative = 1 << 6,
19331933
NumBits = 7
19341934
};
19351935
OptionSet<ParameterFlags> value;

include/swift/Demangling/DemangleNodes.def

+1
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ NODE(AutoDiffSelfReorderingReabstractionThunk)
312312
NODE(AutoDiffSubsetParametersThunk)
313313
NODE(AutoDiffDerivativeVTableThunk)
314314
NODE(DifferentiabilityWitness)
315+
NODE(NoDerivative)
315316
NODE(IndexSubset)
316317
NODE(AsyncAwaitResumePartialFunction)
317318
NODE(AsyncSuspendResumePartialFunction)

include/swift/Demangling/TypeDecoder.h

+30-18
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class FunctionParam {
7272
void setValueOwnership(ValueOwnership ownership) {
7373
Flags = Flags.withValueOwnership(ownership);
7474
}
75+
void setNoDerivative() { Flags = Flags.withNoDerivative(true); }
7576
void setFlags(ParameterFlags flags) { Flags = flags; };
7677

7778
FunctionParam withLabel(StringRef label) const {
@@ -1375,28 +1376,39 @@ class TypeDecoder {
13751376
node = node->getFirstChild();
13761377
hasParamFlags = true;
13771378
};
1378-
switch (node->getKind()) {
1379-
case NodeKind::InOut:
1380-
setOwnership(ValueOwnership::InOut);
1381-
break;
13821379

1383-
case NodeKind::Shared:
1384-
setOwnership(ValueOwnership::Shared);
1385-
break;
1380+
bool recurse = true;
1381+
while (recurse) {
1382+
switch (node->getKind()) {
1383+
case NodeKind::InOut:
1384+
setOwnership(ValueOwnership::InOut);
1385+
break;
13861386

1387-
case NodeKind::Owned:
1388-
setOwnership(ValueOwnership::Owned);
1389-
break;
1387+
case NodeKind::Shared:
1388+
setOwnership(ValueOwnership::Shared);
1389+
break;
13901390

1391-
case NodeKind::AutoClosureType:
1392-
case NodeKind::EscapingAutoClosureType: {
1393-
param.setAutoClosure();
1394-
hasParamFlags = true;
1395-
break;
1396-
}
1391+
case NodeKind::Owned:
1392+
setOwnership(ValueOwnership::Owned);
1393+
break;
13971394

1398-
default:
1399-
break;
1395+
case NodeKind::NoDerivative:
1396+
param.setNoDerivative();
1397+
node = node->getFirstChild();
1398+
hasParamFlags = true;
1399+
break;
1400+
1401+
case NodeKind::AutoClosureType:
1402+
case NodeKind::EscapingAutoClosureType:
1403+
param.setAutoClosure();
1404+
hasParamFlags = true;
1405+
recurse = false;
1406+
break;
1407+
1408+
default:
1409+
recurse = false;
1410+
break;
1411+
}
14001412
}
14011413

14021414
auto paramType = decodeMangledType(node);

lib/AST/ASTMangler.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -2541,6 +2541,9 @@ void ASTMangler::appendTypeListElement(Identifier name, Type elementType,
25412541
else
25422542
appendType(elementType, forDecl);
25432543

2544+
if (flags.isNoDerivative()) {
2545+
appendOperator("k");
2546+
}
25442547
switch (flags.getValueOwnership()) {
25452548
case ValueOwnership::Default:
25462549
/* nothing */

lib/Demangling/Demangler.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,9 @@ NodePointer Demangler::demangleOperator() {
793793
popTypeAndGetChild()));
794794
case 'i': return demangleSubscript();
795795
case 'j': return demangleDifferentiableFunctionType();
796+
case 'k':
797+
return createType(
798+
createWithChild(Node::Kind::NoDerivative, popTypeAndGetChild()));
796799
case 'l': return demangleGenericSignature(/*hasParamCounts*/ false);
797800
case 'm': return createType(createWithChild(Node::Kind::Metatype,
798801
popNode(Node::Kind::Type)));

lib/Demangling/NodePrinter.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ class NodePrinter {
569569
case Node::Kind::AutoDiffSubsetParametersThunk:
570570
case Node::Kind::AutoDiffFunctionKind:
571571
case Node::Kind::DifferentiabilityWitness:
572+
case Node::Kind::NoDerivative:
572573
case Node::Kind::IndexSubset:
573574
case Node::Kind::AsyncAwaitResumePartialFunction:
574575
case Node::Kind::AsyncSuspendResumePartialFunction:
@@ -1421,6 +1422,10 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
14211422
Printer << "__owned ";
14221423
print(Node->getChild(0));
14231424
return nullptr;
1425+
case Node::Kind::NoDerivative:
1426+
Printer << "@noDerivative ";
1427+
print(Node->getChild(0));
1428+
return nullptr;
14241429
case Node::Kind::NonObjCAttribute:
14251430
Printer << "@nonobjc ";
14261431
return nullptr;

lib/Demangling/OldDemangler.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -2063,6 +2063,14 @@ class OldDemangler {
20632063
inout->addChild(type, Factory);
20642064
return inout;
20652065
}
2066+
if (c == 'k') {
2067+
auto noDerivative = Factory.createNode(Node::Kind::NoDerivative);
2068+
auto type = demangleTypeImpl();
2069+
if (!type)
2070+
return nullptr;
2071+
noDerivative->addChild(type, Factory);
2072+
return noDerivative;
2073+
}
20662074
if (c == 'S') {
20672075
return demangleSubstitutionIndex();
20682076
}

lib/Demangling/OldRemangler.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1485,6 +1485,11 @@ void Remangler::mangleInOut(Node *node) {
14851485
mangleSingleChildNode(node); // type
14861486
}
14871487

1488+
void Remangler::mangleNoDerivative(Node *node) {
1489+
Buffer << 'k';
1490+
mangleSingleChildNode(node); // type
1491+
}
1492+
14881493
void Remangler::mangleTuple(Node *node) {
14891494
size_t NumElems = node->getNumChildren();
14901495
if (NumElems > 0 &&

lib/Demangling/Remangler.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1676,6 +1676,11 @@ void Remangler::mangleOwned(Node *node) {
16761676
Buffer << 'n';
16771677
}
16781678

1679+
void Remangler::mangleNoDerivative(Node *node) {
1680+
mangleSingleChildNode(node);
1681+
Buffer << 'k';
1682+
}
1683+
16791684
void Remangler::mangleInfixOperator(Node *node) {
16801685
mangleIdentifierImpl(node, /*isOperator*/ true);
16811686
Buffer << "oi";

lib/Sema/TypeCheckType.cpp

+28-21
Original file line numberDiff line numberDiff line change
@@ -2524,10 +2524,12 @@ TypeResolver::resolveAttributedType(TypeAttributes &attrs, TypeRepr *repr,
25242524
}
25252525

25262526
if (attrs.has(TAK_noDerivative)) {
2527-
// @noDerivative is only valid on function parameters, or on function
2528-
// results in SIL.
2527+
// @noDerivative is valid on function parameters (AST and SIL) or on
2528+
// function results (SIL-only).
25292529
bool isNoDerivativeAllowed =
2530-
isParam || (isResult && (options & TypeResolutionFlags::SILType));
2530+
isParam ||
2531+
options.is(TypeResolverContext::InoutFunctionInput) ||
2532+
(isResult && (options & TypeResolutionFlags::SILType));
25312533
auto *SF = getDeclContext()->getParentSourceFile();
25322534
if (SF && !isDifferentiableProgrammingEnabled(*SF)) {
25332535
diagnose(
@@ -2645,7 +2647,7 @@ TypeResolver::resolveASTFunctionTypeParams(TupleTypeRepr *inputRepr,
26452647
if (auto *ATR = dyn_cast<AttributedTypeRepr>(eltTypeRepr))
26462648
autoclosure = ATR->getAttrs().has(TAK_autoclosure);
26472649

2648-
ValueOwnership ownership;
2650+
ValueOwnership ownership = ValueOwnership::Default;
26492651

26502652
auto *nestedRepr = eltTypeRepr;
26512653

@@ -2657,31 +2659,35 @@ TypeResolver::resolveASTFunctionTypeParams(TupleTypeRepr *inputRepr,
26572659
nestedRepr = tupleRepr->getElementType(0);
26582660
}
26592661

2660-
switch (nestedRepr->getKind()) {
2661-
case TypeReprKind::Shared:
2662-
ownership = ValueOwnership::Shared;
2663-
break;
2664-
case TypeReprKind::InOut:
2665-
ownership = ValueOwnership::InOut;
2666-
break;
2667-
case TypeReprKind::Owned:
2668-
ownership = ValueOwnership::Owned;
2669-
break;
2670-
default:
2671-
ownership = ValueOwnership::Default;
2672-
break;
2662+
if (auto *specifierRepr = dyn_cast<SpecifierTypeRepr>(nestedRepr)) {
2663+
switch (specifierRepr->getKind()) {
2664+
case TypeReprKind::Shared:
2665+
ownership = ValueOwnership::Shared;
2666+
nestedRepr = specifierRepr->getBase();
2667+
break;
2668+
case TypeReprKind::InOut:
2669+
ownership = ValueOwnership::InOut;
2670+
nestedRepr = specifierRepr->getBase();
2671+
break;
2672+
case TypeReprKind::Owned:
2673+
ownership = ValueOwnership::Owned;
2674+
nestedRepr = specifierRepr->getBase();
2675+
break;
2676+
default:
2677+
break;
2678+
}
26732679
}
26742680

26752681
bool noDerivative = false;
2676-
if (auto *attrTypeRepr = dyn_cast<AttributedTypeRepr>(eltTypeRepr)) {
2682+
if (auto *attrTypeRepr = dyn_cast<AttributedTypeRepr>(nestedRepr)) {
26772683
if (attrTypeRepr->getAttrs().has(TAK_noDerivative)) {
26782684
if (diffKind == DifferentiabilityKind::NonDifferentiable &&
26792685
isDifferentiableProgrammingEnabled(
26802686
*getDeclContext()->getParentSourceFile()))
2681-
diagnose(eltTypeRepr->getLoc(),
2687+
diagnose(nestedRepr->getLoc(),
26822688
diag::attr_only_on_parameters_of_differentiable,
26832689
"@noDerivative")
2684-
.highlight(eltTypeRepr->getSourceRange());
2690+
.highlight(nestedRepr->getSourceRange());
26852691
else
26862692
noDerivative = true;
26872693
}
@@ -3476,7 +3482,7 @@ TypeResolver::resolveSpecifierTypeRepr(SpecifierTypeRepr *repr,
34763482
if (isa<InOutTypeRepr>(repr)
34773483
&& !isa<ImplicitlyUnwrappedOptionalTypeRepr>(repr->getBase())) {
34783484
// Anything within an inout isn't a parameter anymore.
3479-
options.setContext(None);
3485+
options.setContext(TypeResolverContext::InoutFunctionInput);
34803486
}
34813487

34823488
return resolveType(repr->getBase(), options);
@@ -3554,6 +3560,7 @@ NeverNullType TypeResolver::resolveImplicitlyUnwrappedOptionalType(
35543560
bool doDiag = false;
35553561
switch (options.getContext()) {
35563562
case TypeResolverContext::None:
3563+
case TypeResolverContext::InoutFunctionInput:
35573564
if (!isDirect || !(options & allowIUO))
35583565
doDiag = true;
35593566
break;

lib/Sema/TypeCheckType.h

+4
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ enum class TypeResolverContext : uint8_t {
8585
/// Whether this is a variadic function input.
8686
VariadicFunctionInput,
8787

88+
/// Whether this is an 'inout' function input.
89+
InoutFunctionInput,
90+
8891
/// Whether we are in the result type of a function, including multi-level
8992
/// tuple return values. See also: TypeResolutionFlags::Direct
9093
FunctionResult,
@@ -198,6 +201,7 @@ class TypeResolutionOptions {
198201
case Context::None:
199202
case Context::FunctionInput:
200203
case Context::VariadicFunctionInput:
204+
case Context::InoutFunctionInput:
201205
case Context::FunctionResult:
202206
case Context::ExtensionBinding:
203207
case Context::SubscriptDecl:

stdlib/public/Reflection/TypeRef.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,9 @@ class DemanglingForTypeRef
582582
parent->addChild(input, Dem);
583583
input = parent;
584584
};
585+
if (flags.isNoDerivative()) {
586+
wrapInput(Node::Kind::NoDerivative);
587+
}
585588
switch (flags.getValueOwnership()) {
586589
case ValueOwnership::Default:
587590
/* nothing */

stdlib/public/runtime/Demangle.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,9 @@ swift::_swift_buildDemanglingForMetadata(const Metadata *type,
482482
parent->addChild(input, Dem);
483483
input = parent;
484484
};
485+
if (flags.isNoDerivative()) {
486+
wrapInput(Node::Kind::NoDerivative);
487+
}
485488
switch (flags.getValueOwnership()) {
486489
case ValueOwnership::Default:
487490
/* nothing */

test/AutoDiff/validation-test/function_type_metadata.swift

+24-8
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,30 @@ if #available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) {
2626
Swift.Optional<Swift.Float>
2727
""",
2828
String(reflecting: (@differentiable(reverse) (Float?) -> Float?).self))
29-
// FIXME(rdar://75916833): Mangle '@noDerivative' in function types.
30-
// expectEqual(
31-
// """
32-
// @differentiable(reverse) (Swift.Optional<Swift.Float>, \
33-
// @noDerivative Swift.Int) -> Swift.Optional<Swift.Float>
34-
// """,
35-
// String(reflecting: (
36-
// @differentiable(reverse) (Float?, @noDerivative Int) -> Float?).self))
29+
expectEqual(
30+
"""
31+
@differentiable(reverse) (Swift.Optional<Swift.Float>, \
32+
@noDerivative Swift.Int) -> Swift.Optional<Swift.Float>
33+
""",
34+
String(reflecting: (
35+
@differentiable(reverse)
36+
(Float?, @noDerivative Int) -> Float?).self))
37+
expectEqual(
38+
"""
39+
@differentiable(reverse) (Swift.Optional<Swift.Float>, \
40+
__owned @noDerivative Swift.Int) -> Swift.Optional<Swift.Float>
41+
""",
42+
String(reflecting: (
43+
@differentiable(reverse)
44+
(Float?, __owned @noDerivative Int) -> Float?).self))
45+
expectEqual(
46+
"""
47+
@differentiable(reverse) (Swift.Optional<Swift.Float>, \
48+
inout @noDerivative Swift.Int) -> Swift.Optional<Swift.Float>
49+
""",
50+
String(reflecting: (
51+
@differentiable(reverse)
52+
(Float?, inout @noDerivative Int) -> Float?).self))
3753
}
3854
}
3955

test/Demangle/Inputs/manglings.txt

+2
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,5 @@ $s4diff1hyyS2ijfXEF ---> diff.h(@differentiable(_forward) (Swift.Int) -> Swift.I
399399
$s4diff1hyyS2ijrXEF ---> diff.h(@differentiable(reverse) (Swift.Int) -> Swift.Int) -> ()
400400
$s4diff1hyyS2ijdXEF ---> diff.h(@differentiable (Swift.Int) -> Swift.Int) -> ()
401401
$s4diff1hyyS2ijlXEF ---> diff.h(@differentiable(_linear) (Swift.Int) -> Swift.Int) -> ()
402+
$s4test3fooyyS2f_SfkztjrXEF ---> test.foo(@differentiable(reverse) (Swift.Float, inout @noDerivative Swift.Float) -> Swift.Float) -> ()
403+
$s4test3fooyyS2f_SfktjrXEF ---> test.foo(@differentiable(reverse) (Swift.Float, @noDerivative Swift.Float) -> Swift.Float) -> ()

unittests/Reflection/TypeRef.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,23 @@ TEST(TypeRefTest, UniqueFunctionTypeRef) {
250250
FunctionMetadataDifferentiabilityKind::Reverse);
251251
EXPECT_EQ(F17, F18);
252252
EXPECT_NE(F17, F19);
253+
254+
// Test differentiable with @noDerivative.
255+
{
256+
auto parameters = Parameters1;
257+
parameters[1].setNoDerivative();
258+
auto f1 = Builder.createFunctionType(
259+
parameters, Result, FunctionTypeFlags().withDifferentiable(true),
260+
FunctionMetadataDifferentiabilityKind::Reverse);
261+
auto f2 = Builder.createFunctionType(
262+
parameters, Result, FunctionTypeFlags().withDifferentiable(true),
263+
FunctionMetadataDifferentiabilityKind::Reverse);
264+
auto f3 = Builder.createFunctionType(
265+
Parameters1, Result, FunctionTypeFlags().withDifferentiable(true),
266+
FunctionMetadataDifferentiabilityKind::Reverse);
267+
EXPECT_EQ(f1, f2);
268+
EXPECT_NE(f1, f3);
269+
}
253270
}
254271

255272
TEST(TypeRefTest, UniqueProtocolTypeRef) {

0 commit comments

Comments
 (0)