Skip to content

Commit ac72084

Browse files
authored
Merge pull request #63683 from apple/egorzhdan/cxx-optional
[cxx-interop] Add `CxxOptional` protocol for `std::optional` ergonomics
2 parents eb3cbde + a12986a commit ac72084

File tree

12 files changed

+141
-12
lines changed

12 files changed

+141
-12
lines changed

Diff for: include/swift/AST/KnownProtocols.def

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ PROTOCOL(DistributedTargetInvocationResultHandler)
108108
// C++ Standard Library Overlay:
109109
PROTOCOL(CxxConvertibleToCollection)
110110
PROTOCOL(CxxDictionary)
111+
PROTOCOL(CxxOptional)
111112
PROTOCOL(CxxPair)
112113
PROTOCOL(CxxSet)
113114
PROTOCOL(CxxRandomAccessCollection)

Diff for: lib/AST/ASTContext.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
11331133
case KnownProtocolKind::CxxConvertibleToCollection:
11341134
case KnownProtocolKind::CxxDictionary:
11351135
case KnownProtocolKind::CxxPair:
1136+
case KnownProtocolKind::CxxOptional:
11361137
case KnownProtocolKind::CxxRandomAccessCollection:
11371138
case KnownProtocolKind::CxxSet:
11381139
case KnownProtocolKind::CxxSequence:

Diff for: lib/ClangImporter/ClangDerivedConformances.cpp

+42-10
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@ static bool isConcreteAndValid(ProtocolConformanceRef conformanceRef,
8686
});
8787
}
8888

89+
static bool isStdDecl(const clang::CXXRecordDecl *clangDecl,
90+
llvm::ArrayRef<StringRef> names) {
91+
if (!clangDecl->isInStdNamespace())
92+
return false;
93+
if (!clangDecl->getIdentifier())
94+
return false;
95+
StringRef name = clangDecl->getName();
96+
return llvm::is_contained(names, name);
97+
}
98+
8999
static clang::TypeDecl *
90100
getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
91101
clang::IdentifierInfo *iteratorCategoryDeclName =
@@ -381,6 +391,38 @@ void swift::conformToCxxIteratorIfNeeded(
381391
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
382392
}
383393

394+
void swift::conformToCxxOptionalIfNeeded(
395+
ClangImporter::Implementation &impl, NominalTypeDecl *decl,
396+
const clang::CXXRecordDecl *clangDecl) {
397+
PrettyStackTraceDecl trace("conforming to CxxOptional", decl);
398+
399+
assert(decl);
400+
assert(clangDecl);
401+
ASTContext &ctx = decl->getASTContext();
402+
403+
if (!isStdDecl(clangDecl, {"optional"}))
404+
return;
405+
406+
ProtocolDecl *cxxOptionalProto =
407+
ctx.getProtocol(KnownProtocolKind::CxxOptional);
408+
// If the Cxx module is missing, or does not include one of the necessary
409+
// protocol, bail.
410+
if (!cxxOptionalProto)
411+
return;
412+
413+
auto pointeeId = ctx.getIdentifier("pointee");
414+
auto pointees = lookupDirectWithoutExtensions(decl, pointeeId);
415+
if (pointees.size() != 1)
416+
return;
417+
auto pointee = dyn_cast<VarDecl>(pointees.front());
418+
if (!pointee)
419+
return;
420+
auto pointeeTy = pointee->getInterfaceType();
421+
422+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Wrapped"), pointeeTy);
423+
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxOptional});
424+
}
425+
384426
void swift::conformToCxxSequenceIfNeeded(
385427
ClangImporter::Implementation &impl, NominalTypeDecl *decl,
386428
const clang::CXXRecordDecl *clangDecl) {
@@ -523,16 +565,6 @@ void swift::conformToCxxSequenceIfNeeded(
523565
}
524566
}
525567

526-
static bool isStdDecl(const clang::CXXRecordDecl *clangDecl,
527-
llvm::ArrayRef<StringRef> names) {
528-
if (!clangDecl->isInStdNamespace())
529-
return false;
530-
if (!clangDecl->getIdentifier())
531-
return false;
532-
StringRef name = clangDecl->getName();
533-
return llvm::is_contained(names, name);
534-
}
535-
536568
void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
537569
NominalTypeDecl *decl,
538570
const clang::CXXRecordDecl *clangDecl) {

Diff for: lib/ClangImporter/ClangDerivedConformances.h

+6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ void conformToCxxIteratorIfNeeded(ClangImporter::Implementation &impl,
2626
NominalTypeDecl *decl,
2727
const clang::CXXRecordDecl *clangDecl);
2828

29+
/// If the decl is an instantiation of C++ `std::optional`, synthesize a
30+
/// conformance to CxxOptional protocol, which is defined in the Cxx module.
31+
void conformToCxxOptionalIfNeeded(ClangImporter::Implementation &impl,
32+
NominalTypeDecl *decl,
33+
const clang::CXXRecordDecl *clangDecl);
34+
2935
/// If the decl is a C++ sequence, synthesize a conformance to the CxxSequence
3036
/// protocol, which is defined in the Cxx module.
3137
void conformToCxxSequenceIfNeeded(ClangImporter::Implementation &impl,

Diff for: lib/ClangImporter/ImportDecl.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -2633,6 +2633,7 @@ namespace {
26332633
conformToCxxSetIfNeeded(Impl, nominalDecl, decl);
26342634
conformToCxxDictionaryIfNeeded(Impl, nominalDecl, decl);
26352635
conformToCxxPairIfNeeded(Impl, nominalDecl, decl);
2636+
conformToCxxOptionalIfNeeded(Impl, nominalDecl, decl);
26362637
}
26372638

26382639
if (auto *ntd = dyn_cast<NominalTypeDecl>(result))

Diff for: lib/ClangImporter/ImportName.cpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ static bool isErrorOutParameter(const clang::ParmVarDecl *param,
155155

156156
static bool isBoolType(clang::ASTContext &ctx, clang::QualType type) {
157157
do {
158+
if (type->isBooleanType())
159+
return true;
160+
158161
// Check whether we have a typedef for "BOOL" or "Boolean".
159162
if (auto typedefType = dyn_cast<clang::TypedefType>(type.getTypePtr())) {
160163
auto typedefDecl = typedefType->getDecl();
@@ -1848,7 +1851,20 @@ ImportedName NameImporter::importNameImpl(const clang::NamedDecl *D,
18481851
break;
18491852
}
18501853

1851-
case clang::DeclarationName::CXXConversionFunctionName:
1854+
case clang::DeclarationName::CXXConversionFunctionName: {
1855+
auto conversionDecl = dyn_cast<clang::CXXConversionDecl>(D);
1856+
if (!conversionDecl)
1857+
return ImportedName();
1858+
auto toType = conversionDecl->getConversionType();
1859+
// Only import `operator bool()` for now.
1860+
if (isBoolType(clangSema.Context, toType)) {
1861+
isFunction = true;
1862+
baseName = "__convertToBool";
1863+
addEmptyArgNamesForClangFunction(conversionDecl, argumentNames);
1864+
break;
1865+
}
1866+
return ImportedName();
1867+
}
18521868
case clang::DeclarationName::CXXDestructorName:
18531869
case clang::DeclarationName::CXXLiteralOperatorName:
18541870
case clang::DeclarationName::CXXUsingDirective:

Diff for: lib/IRGen/GenMeta.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -6166,6 +6166,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
61666166
case KnownProtocolKind::CxxConvertibleToCollection:
61676167
case KnownProtocolKind::CxxDictionary:
61686168
case KnownProtocolKind::CxxPair:
6169+
case KnownProtocolKind::CxxOptional:
61696170
case KnownProtocolKind::CxxRandomAccessCollection:
61706171
case KnownProtocolKind::CxxSet:
61716172
case KnownProtocolKind::CxxSequence:

Diff for: stdlib/public/Cxx/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_swift_target_library(swiftCxx ${SWIFT_CXX_LIBRARY_KIND} NO_LINK_NAME IS_STDL
77
CxxConvertibleToCollection.swift
88
CxxDictionary.swift
99
CxxPair.swift
10+
CxxOptional.swift
1011
CxxSet.swift
1112
CxxRandomAccessCollection.swift
1213
CxxSequence.swift

Diff for: stdlib/public/Cxx/CxxOptional.swift

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
public protocol CxxOptional<Wrapped> {
14+
associatedtype Wrapped
15+
16+
func __convertToBool() -> Bool
17+
18+
var pointee: Wrapped { get }
19+
}
20+
21+
extension CxxOptional {
22+
@inlinable
23+
public var hasValue: Bool {
24+
get {
25+
return __convertToBool()
26+
}
27+
}
28+
29+
@inlinable
30+
public var value: Wrapped {
31+
get {
32+
return pointee
33+
}
34+
}
35+
}
36+
37+
extension Optional {
38+
@inlinable
39+
public init(fromCxx value: some CxxOptional<Wrapped>) {
40+
guard value.__convertToBool() else {
41+
self = nil
42+
return
43+
}
44+
self = value.pointee
45+
}
46+
}

Diff for: test/Interop/Cxx/operators/Inputs/member-inline.h

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ struct LoadableIntWrapper {
2222
return value + x * y;
2323
}
2424

25+
operator int() const { return value; }
26+
2527
LoadableIntWrapper &operator++() {
2628
value++;
2729
return *this;
@@ -48,6 +50,7 @@ struct LoadableBoolWrapper {
4850
LoadableBoolWrapper operator!() {
4951
return LoadableBoolWrapper{.value = !value};
5052
}
53+
operator bool() const { return value; }
5154
};
5255

5356
template<class T>

Diff for: test/Interop/Cxx/operators/member-inline-module-interface.swift

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
// CHECK: struct LoadableBoolWrapper {
1313
// CHECK: prefix static func ! (lhs: inout LoadableBoolWrapper) -> LoadableBoolWrapper
14+
// CHECK: func __convertToBool() -> Bool
1415
// CHECK: }
1516

1617
// CHECK: struct AddressOnlyIntWrapper {

Diff for: test/Interop/Cxx/stdlib/use-std-optional.swift

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-I %S/Inputs -Xfrontend -enable-experimental-cxx-interop -Xcc -std=c++17)
1+
// RUN: %target-run-simple-swift(-I %S/Inputs -Xfrontend -enable-experimental-cxx-interop -Xcc -std=c++17 -Xfrontend -validate-tbd-against-ir=none)
22
//
33
// REQUIRES: executable_test
44
// REQUIRES: OS=macosx
@@ -15,4 +15,24 @@ StdOptionalTestSuite.test("pointee") {
1515
expectEqual(123, pointee)
1616
}
1717

18+
StdOptionalTestSuite.test("std::optional => Swift.Optional") {
19+
let nonNilOpt = getNonNilOptional()
20+
let swiftOptional = Optional(fromCxx: nonNilOpt)
21+
expectNotNil(swiftOptional)
22+
expectEqual(123, swiftOptional!)
23+
24+
let nilOpt = getNilOptional()
25+
let swiftNil = Optional(fromCxx: nilOpt)
26+
expectNil(swiftNil)
27+
}
28+
29+
StdOptionalTestSuite.test("std::optional hasValue/value") {
30+
let nonNilOpt = getNonNilOptional()
31+
expectTrue(nonNilOpt.hasValue)
32+
expectEqual(123, nonNilOpt.value)
33+
34+
let nilOpt = getNilOptional()
35+
expectFalse(nilOpt.hasValue)
36+
}
37+
1838
runAllTests()

0 commit comments

Comments
 (0)