Skip to content

Commit 39b8b3c

Browse files
authored
Merge pull request #76106 from swiftlang/egorzhdan/cxx-mutable-rac
[cxx-interop] Add `CxxMutableRandomAccessCollection` protocol
2 parents 44dbebd + 0ab6815 commit 39b8b3c

12 files changed

+171
-25
lines changed

include/swift/AST/KnownProtocols.def

+1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ PROTOCOL(CxxOptional)
132132
PROTOCOL(CxxPair)
133133
PROTOCOL(CxxSet)
134134
PROTOCOL(CxxRandomAccessCollection)
135+
PROTOCOL(CxxMutableRandomAccessCollection)
135136
PROTOCOL(CxxSequence)
136137
PROTOCOL(CxxUniqueSet)
137138
PROTOCOL(CxxVector)

lib/AST/ASTContext.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
14211421
case KnownProtocolKind::CxxPair:
14221422
case KnownProtocolKind::CxxOptional:
14231423
case KnownProtocolKind::CxxRandomAccessCollection:
1424+
case KnownProtocolKind::CxxMutableRandomAccessCollection:
14241425
case KnownProtocolKind::CxxSet:
14251426
case KnownProtocolKind::CxxSequence:
14261427
case KnownProtocolKind::CxxUniqueSet:

lib/ClangImporter/ClangDerivedConformances.cpp

+38-2
Original file line numberDiff line numberDiff line change
@@ -780,8 +780,44 @@ void swift::conformToCxxSequenceIfNeeded(
780780
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Indices"), indicesTy);
781781
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("SubSequence"),
782782
sliceTy);
783-
impl.addSynthesizedProtocolAttrs(
784-
decl, {KnownProtocolKind::CxxRandomAccessCollection});
783+
784+
auto tryToConformToMutatingRACollection = [&]() -> bool {
785+
auto rawMutableIteratorProto = ctx.getProtocol(
786+
KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator);
787+
if (!rawMutableIteratorProto)
788+
return false;
789+
790+
// Check if present: `func __beginMutatingUnsafe() -> RawMutableIterator`
791+
auto beginMutatingId = ctx.getIdentifier("__beginMutatingUnsafe");
792+
auto beginMutating =
793+
lookupDirectSingleWithoutExtensions<FuncDecl>(decl, beginMutatingId);
794+
if (!beginMutating)
795+
return false;
796+
auto rawMutableIteratorTy = beginMutating->getResultInterfaceType();
797+
798+
// Check if present: `func __endMutatingUnsafe() -> RawMutableIterator`
799+
auto endMutatingId = ctx.getIdentifier("__endMutatingUnsafe");
800+
auto endMutating =
801+
lookupDirectSingleWithoutExtensions<FuncDecl>(decl, endMutatingId);
802+
if (!endMutating)
803+
return false;
804+
805+
if (!checkConformance(rawMutableIteratorTy, rawMutableIteratorProto))
806+
return false;
807+
808+
impl.addSynthesizedTypealias(
809+
decl, ctx.getIdentifier("RawMutableIterator"), rawMutableIteratorTy);
810+
impl.addSynthesizedProtocolAttrs(
811+
decl, {KnownProtocolKind::CxxMutableRandomAccessCollection});
812+
return true;
813+
};
814+
815+
bool conformedToMutableRAC = tryToConformToMutatingRACollection();
816+
817+
if (!conformedToMutableRAC)
818+
impl.addSynthesizedProtocolAttrs(
819+
decl, {KnownProtocolKind::CxxRandomAccessCollection});
820+
785821
return true;
786822
};
787823

lib/IRGen/GenMeta.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -6883,6 +6883,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
68836883
case KnownProtocolKind::CxxPair:
68846884
case KnownProtocolKind::CxxOptional:
68856885
case KnownProtocolKind::CxxRandomAccessCollection:
6886+
case KnownProtocolKind::CxxMutableRandomAccessCollection:
68866887
case KnownProtocolKind::CxxSet:
68876888
case KnownProtocolKind::CxxSequence:
68886889
case KnownProtocolKind::CxxUniqueSet:

stdlib/public/Cxx/CxxRandomAccessCollection.swift

+41-5
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,52 @@ extension CxxRandomAccessCollection {
3636
return Int(__endUnsafe() - __beginUnsafe())
3737
}
3838

39+
@inlinable
40+
@inline(__always)
41+
internal func _getRawIterator(at index: Int) -> RawIterator {
42+
var rawIterator = self.__beginUnsafe()
43+
rawIterator += RawIterator.Distance(index)
44+
precondition(self.__endUnsafe() - rawIterator > 0,
45+
"C++ iterator access out of bounds")
46+
return rawIterator
47+
}
48+
3949
/// A C++ implementation of the subscript might be more performant. This
4050
/// overload should only be used if the C++ type does not define `operator[]`.
4151
@inlinable
4252
public subscript(_ index: Int) -> Element {
4353
_read {
44-
// Not using CxxIterator here to avoid making a copy of the collection.
45-
var rawIterator = __beginUnsafe()
46-
rawIterator += RawIterator.Distance(index)
47-
precondition(__endUnsafe() - rawIterator > 0, "C++ iterator access out of bounds")
48-
yield rawIterator.pointee
54+
yield self._getRawIterator(at: index).pointee
55+
}
56+
}
57+
}
58+
59+
public protocol CxxMutableRandomAccessCollection<Element>:
60+
CxxRandomAccessCollection, MutableCollection {
61+
associatedtype RawMutableIterator: UnsafeCxxMutableRandomAccessIterator
62+
where RawMutableIterator.Pointee == Element
63+
64+
/// Do not implement this function manually in Swift.
65+
mutating func __beginMutatingUnsafe() -> RawMutableIterator
66+
67+
/// Do not implement this function manually in Swift.
68+
mutating func __endMutatingUnsafe() -> RawMutableIterator
69+
}
70+
71+
extension CxxMutableRandomAccessCollection {
72+
/// A C++ implementation of the subscript might be more performant. This
73+
/// overload should only be used if the C++ type does not define `operator[]`.
74+
@inlinable
75+
public subscript(_ index: Int) -> Element {
76+
_read {
77+
yield self._getRawIterator(at: index).pointee
78+
}
79+
_modify {
80+
var rawIterator = self.__beginMutatingUnsafe()
81+
rawIterator += RawMutableIterator.Distance(index)
82+
precondition(self.__endMutatingUnsafe() - rawIterator > 0,
83+
"C++ iterator access out of bounds")
84+
yield &rawIterator.pointee
4985
}
5086
}
5187
}

test/Interop/Cxx/stdlib/libcxx-module-interface.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020

2121
// CHECK-IOSFWD: enum std {
2222
// CHECK-IOSFWD: enum __1 {
23-
// CHECK-IOSFWD: struct basic_string<CChar, char_traits<CChar>, allocator<CChar>> : CxxRandomAccessCollection {
23+
// CHECK-IOSFWD: struct basic_string<CChar, char_traits<CChar>, allocator<CChar>> : CxxMutableRandomAccessCollection {
2424
// CHECK-IOSFWD: typealias value_type = CChar
2525
// CHECK-IOSFWD: }
26-
// CHECK-IOSFWD: struct basic_string<CWideChar, char_traits<CWideChar>, allocator<CWideChar>> : CxxRandomAccessCollection {
26+
// CHECK-IOSFWD: struct basic_string<CWideChar, char_traits<CWideChar>, allocator<CWideChar>> : CxxMutableRandomAccessCollection {
2727
// CHECK-IOSFWD: typealias value_type = CWideChar
2828
// CHECK-IOSFWD: }
2929
// CHECK-IOSFWD: typealias string = std.__1.basic_string<CChar, char_traits<CChar>, allocator<CChar>>

test/Interop/Cxx/stdlib/libstdcxx-module-interface.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
// REQUIRES: OS=linux-gnu
1212

1313
// CHECK-STD: enum std {
14-
// CHECK-STRING: struct basic_string<CChar, char_traits<CChar>, allocator<CChar>> : CxxRandomAccessCollection {
14+
// CHECK-STRING: struct basic_string<CChar, char_traits<CChar>, allocator<CChar>> : CxxMutableRandomAccessCollection {
1515
// CHECK-STRING: typealias value_type = std.char_traits<CChar>.char_type
1616
// CHECK-STRING: }
17-
// CHECK-STRING: struct basic_string<CWideChar, char_traits<CWideChar>, allocator<CWideChar>> : CxxRandomAccessCollection {
17+
// CHECK-STRING: struct basic_string<CWideChar, char_traits<CWideChar>, allocator<CWideChar>> : CxxMutableRandomAccessCollection {
1818
// CHECK-STRING: typealias value_type = std.char_traits<CWideChar>.char_type
1919
// CHECK-STRING: }
2020

test/Interop/Cxx/stdlib/overlay/Inputs/custom-collection.h

+21-4
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ struct SimpleCollectionNoSubscript {
1111
public:
1212
using iterator = ConstRACIterator;
1313

14-
iterator begin() const { return iterator(*x); }
15-
iterator end() const { return iterator(*x + 5); }
14+
iterator begin() const { return iterator(x); }
15+
iterator end() const { return iterator(x + 5); }
1616
};
1717

1818
struct SimpleCollectionReadOnly {
@@ -22,12 +22,29 @@ struct SimpleCollectionReadOnly {
2222
public:
2323
using iterator = ConstRACIteratorRefPlusEq;
2424

25-
iterator begin() const { return iterator(*x); }
26-
iterator end() const { return iterator(*x + 5); }
25+
iterator begin() const { return iterator(x); }
26+
iterator end() const { return iterator(x + 5); }
2727

2828
const int& operator[](int index) const { return x[index]; }
2929
};
3030

31+
struct SimpleCollectionReadWrite {
32+
private:
33+
int x[5] = {1, 2, 3, 4, 5};
34+
35+
public:
36+
using const_iterator = ConstRACIterator;
37+
using iterator = MutableRACIterator;
38+
39+
const_iterator begin() const { return const_iterator(x); }
40+
const_iterator end() const { return const_iterator(x + 5); }
41+
iterator begin() { return iterator(x); }
42+
iterator end() { return iterator(x + 5); }
43+
44+
const int &operator[](int index) const { return x[index]; }
45+
int &operator[](int index) { return x[index]; }
46+
};
47+
3148
template <typename T>
3249
struct HasInheritedTemplatedConstRACIterator {
3350
public:

test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct ConstIterator {
4242

4343
struct ConstRACIterator {
4444
private:
45-
int value;
45+
const int *value;
4646

4747
public:
4848
using iterator_category = std::random_access_iterator_tag;
@@ -51,10 +51,10 @@ struct ConstRACIterator {
5151
using reference = const int &;
5252
using difference_type = int;
5353

54-
ConstRACIterator(int value) : value(value) {}
54+
ConstRACIterator(const int *value) : value(value) {}
5555
ConstRACIterator(const ConstRACIterator &other) = default;
5656

57-
const int &operator*() const { return value; }
57+
const int &operator*() const { return *value; }
5858

5959
ConstRACIterator &operator++() {
6060
value++;
@@ -97,7 +97,7 @@ struct ConstRACIterator {
9797
// Same as ConstRACIterator, but operator+= returns a reference to this.
9898
struct ConstRACIteratorRefPlusEq {
9999
private:
100-
int value;
100+
const int *value;
101101

102102
public:
103103
using iterator_category = std::random_access_iterator_tag;
@@ -106,10 +106,10 @@ struct ConstRACIteratorRefPlusEq {
106106
using reference = const int &;
107107
using difference_type = int;
108108

109-
ConstRACIteratorRefPlusEq(int value) : value(value) {}
109+
ConstRACIteratorRefPlusEq(const int *value) : value(value) {}
110110
ConstRACIteratorRefPlusEq(const ConstRACIteratorRefPlusEq &other) = default;
111111

112-
const int &operator*() const { return value; }
112+
const int &operator*() const { return *value; }
113113

114114
ConstRACIteratorRefPlusEq &operator++() {
115115
value++;
@@ -918,7 +918,7 @@ struct InputOutputConstIterator {
918918

919919
struct MutableRACIterator {
920920
private:
921-
int value;
921+
int *value;
922922

923923
public:
924924
struct iterator_category : std::random_access_iterator_tag,
@@ -928,11 +928,11 @@ struct MutableRACIterator {
928928
using reference = const int &;
929929
using difference_type = int;
930930

931-
MutableRACIterator(int value) : value(value) {}
931+
MutableRACIterator(int *value) : value(value) {}
932932
MutableRACIterator(const MutableRACIterator &other) = default;
933933

934-
const int &operator*() const { return value; }
935-
int &operator*() { return value; }
934+
const int &operator*() const { return *value; }
935+
int &operator*() { return *value; }
936936

937937
MutableRACIterator &operator++() {
938938
value++;

test/Interop/Cxx/stdlib/overlay/custom-collection-module-interface.swift

+7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
// CHECK: typealias RawIterator = SimpleCollectionReadOnly.iterator
2121
// CHECK: }
2222

23+
// CHECK: struct SimpleCollectionReadWrite : CxxMutableRandomAccessCollection {
24+
// CHECK: typealias Element = ConstRACIterator.Pointee
25+
// CHECK: typealias Iterator = CxxIterator<SimpleCollectionReadWrite>
26+
// CHECK: typealias RawIterator = SimpleCollectionReadWrite.const_iterator
27+
// CHECK: typealias RawMutableIterator = SimpleCollectionReadWrite.iterator
28+
// CHECK: }
29+
2330
// CHECK: struct HasInheritedTemplatedConstRACIterator<CInt> : CxxRandomAccessCollection {
2431
// CHECK: typealias Element = InheritedTemplatedConstRACIterator<CInt>.Pointee
2532
// CHECK: typealias Iterator = CxxIterator<HasInheritedTemplatedConstRACIterator<CInt>>

test/Interop/Cxx/stdlib/overlay/custom-collection.swift

+17
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,23 @@ CxxCollectionTestSuite.test("SimpleCollectionReadOnly as Swift.Collection") {
3737
expectEqual(slice.last, 3)
3838
}
3939

40+
CxxCollectionTestSuite.test("SimpleCollectionReadWrite as Swift.MutableCollection") {
41+
var c = SimpleCollectionReadWrite()
42+
expectEqual(c.first, 1)
43+
expectEqual(c.last, 5)
44+
45+
c.swapAt(0, 4)
46+
expectEqual(c.first, 5)
47+
expectEqual(c.last, 1)
48+
49+
c.reverse()
50+
expectEqual(c[0], 1)
51+
expectEqual(c[1], 4)
52+
expectEqual(c[2], 3)
53+
expectEqual(c[3], 2)
54+
expectEqual(c[4], 5)
55+
}
56+
4057
CxxCollectionTestSuite.test("SimpleArrayWrapper as Swift.Collection") {
4158
let c = SimpleArrayWrapper()
4259
expectEqual(c.first, 10)

test/Interop/Cxx/stdlib/use-std-vector.swift

+30
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,36 @@ StdVectorTestSuite.test("VectorOfInt as ExpressibleByArrayLiteral") {
7070
expectEqual(v2[2], 3)
7171
}
7272

73+
#if !os(Windows) // FIXME: rdar://113704853
74+
StdVectorTestSuite.test("VectorOfInt as MutableCollection") {
75+
var v = Vector([2, 3, 1])
76+
v.sort() // Swift function
77+
expectEqual(v[0], 1)
78+
expectEqual(v[1], 2)
79+
expectEqual(v[2], 3)
80+
81+
v.reverse() // Swift function
82+
expectEqual(v[0], 3)
83+
expectEqual(v[1], 2)
84+
expectEqual(v[2], 1)
85+
}
86+
87+
StdVectorTestSuite.test("VectorOfString as MutableCollection") {
88+
var v = VectorOfString([std.string("xyz"),
89+
std.string("abc"),
90+
std.string("ijk")])
91+
v.swapAt(0, 2) // Swift function
92+
expectEqual(v[0], std.string("ijk"))
93+
expectEqual(v[1], std.string("abc"))
94+
expectEqual(v[2], std.string("xyz"))
95+
96+
v.reverse() // Swift function
97+
expectEqual(v[0], std.string("xyz"))
98+
expectEqual(v[1], std.string("abc"))
99+
expectEqual(v[2], std.string("ijk"))
100+
}
101+
#endif
102+
73103
StdVectorTestSuite.test("VectorOfInt.push_back") {
74104
var v = Vector()
75105
let _42: CInt = 42

0 commit comments

Comments
 (0)