|
12 | 12 |
|
13 | 13 | #include "RefactoringActions.h"
|
14 | 14 | #include "Utils.h"
|
| 15 | +#include "swift/AST/ProtocolConformance.h" |
15 | 16 |
|
16 | 17 | using namespace swift::refactoring;
|
17 | 18 |
|
18 | 19 | namespace {
|
19 | 20 | class AddCodableContext {
|
20 | 21 |
|
21 | 22 | /// Declaration context
|
22 |
| - DeclContext *DC; |
23 |
| - |
24 |
| - /// Start location of declaration context brace |
25 |
| - SourceLoc StartLoc; |
26 |
| - |
27 |
| - /// Array of all conformed protocols |
28 |
| - SmallVector<swift::ProtocolDecl *, 2> Protocols; |
29 |
| - |
30 |
| - /// Range of internal members in declaration |
31 |
| - DeclRange Range; |
32 |
| - |
33 |
| - bool conformsToCodableProtocol() { |
34 |
| - for (ProtocolDecl *Protocol : Protocols) { |
35 |
| - if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Encodable || |
36 |
| - Protocol->getKnownProtocolKind() == KnownProtocolKind::Decodable) { |
37 |
| - return true; |
38 |
| - } |
| 23 | + IterableDeclContext *IDC; |
| 24 | + |
| 25 | + AddCodableContext(NominalTypeDecl *nominal) : IDC(nominal){}; |
| 26 | + AddCodableContext(ExtensionDecl *extension) : IDC(extension){}; |
| 27 | + AddCodableContext(std::nullptr_t) : IDC(nullptr){}; |
| 28 | + |
| 29 | + const NominalTypeDecl *getNominal() const { |
| 30 | + switch (IDC->getIterableContextKind()) { |
| 31 | + case IterableDeclContextKind::NominalTypeDecl: |
| 32 | + return cast<NominalTypeDecl>(IDC); |
| 33 | + case IterableDeclContextKind::ExtensionDecl: |
| 34 | + return cast<ExtensionDecl>(IDC)->getExtendedNominal(); |
39 | 35 | }
|
40 |
| - return false; |
| 36 | + assert(false && "unhandled IterableDeclContextKind"); |
41 | 37 | }
|
42 | 38 |
|
43 |
| -public: |
44 |
| - AddCodableContext(NominalTypeDecl *Decl) |
45 |
| - : DC(Decl), StartLoc(Decl->getBraces().Start), |
46 |
| - Protocols(getAllProtocols(Decl)), Range(Decl->getMembers()){}; |
47 |
| - |
48 |
| - AddCodableContext(ExtensionDecl *Decl) |
49 |
| - : DC(Decl), StartLoc(Decl->getBraces().Start), |
50 |
| - Protocols(getAllProtocols(Decl->getExtendedNominal())), |
51 |
| - Range(Decl->getMembers()){}; |
52 |
| - |
53 |
| - AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr){}; |
54 |
| - |
55 |
| - static AddCodableContext |
56 |
| - getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info); |
57 |
| - |
58 |
| - void printInsertionText(ResolvedCursorInfoPtr CursorInfo, SourceManager &SM, |
59 |
| - llvm::raw_ostream &OS); |
60 |
| - |
61 |
| - bool isValid() { return StartLoc.isValid() && conformsToCodableProtocol(); } |
62 |
| - |
63 |
| - SourceLoc getInsertStartLoc(); |
64 |
| -}; |
65 |
| - |
66 |
| -SourceLoc AddCodableContext::getInsertStartLoc() { |
67 |
| - SourceLoc MaxLoc = StartLoc; |
68 |
| - for (auto Mem : Range) { |
69 |
| - if (Mem->getEndLoc().getOpaquePointerValue() > |
70 |
| - MaxLoc.getOpaquePointerValue()) { |
71 |
| - MaxLoc = Mem->getEndLoc(); |
| 39 | + /// Get the left brace location of the type-or-extension decl. |
| 40 | + SourceLoc getLeftBraceLoc() const { |
| 41 | + switch (IDC->getIterableContextKind()) { |
| 42 | + case IterableDeclContextKind::NominalTypeDecl: |
| 43 | + return cast<NominalTypeDecl>(IDC)->getBraces().Start; |
| 44 | + case IterableDeclContextKind::ExtensionDecl: |
| 45 | + return cast<ExtensionDecl>(IDC)->getBraces().Start; |
72 | 46 | }
|
| 47 | + assert(false && "unhandled IterableDeclContextKind"); |
73 | 48 | }
|
74 |
| - return MaxLoc; |
75 |
| -} |
76 |
| - |
77 |
| -/// Walks an AST and prints the synthesized Codable implementation. |
78 |
| -class SynthesizedCodablePrinter : public ASTWalker { |
79 |
| -private: |
80 |
| - ASTPrinter &Printer; |
81 | 49 |
|
82 |
| -public: |
83 |
| - SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {} |
| 50 | + /// Get the token location where the text should be inserted after. |
| 51 | + SourceLoc getInsertStartLoc() const { |
| 52 | + // Prefer the end of elements. |
| 53 | + for (auto *member : llvm::reverse(IDC->getParsedMembers())) { |
| 54 | + if (isa<AccessorDecl>(member) || isa<VarDecl>(member)) { |
| 55 | + // These are part of 'PatternBindingDecl' but are hoisted in AST. |
| 56 | + continue; |
| 57 | + } |
| 58 | + return member->getEndLoc(); |
| 59 | + } |
84 | 60 |
|
85 |
| - MacroWalking getMacroWalkingBehavior() const override { |
86 |
| - return MacroWalking::Arguments; |
| 61 | + // After the starting brace if empty. |
| 62 | + return getLeftBraceLoc(); |
87 | 63 | }
|
88 | 64 |
|
89 |
| - PreWalkAction walkToDeclPre(Decl *D) override { |
90 |
| - auto *VD = dyn_cast<ValueDecl>(D); |
91 |
| - if (!VD) |
92 |
| - return Action::SkipNode(); |
93 |
| - |
94 |
| - if (!VD->isSynthesized()) { |
95 |
| - return Action::Continue(); |
96 |
| - } |
97 |
| - SmallString<32> Scratch; |
98 |
| - auto name = VD->getName().getString(Scratch); |
99 |
| - // Print all synthesized enums, |
100 |
| - // since Codable can synthesize multiple enums (for associated values). |
101 |
| - auto shouldPrint = |
102 |
| - isa<EnumDecl>(VD) || name == "init(from:)" || name == "encode(to:)"; |
103 |
| - if (!shouldPrint) { |
104 |
| - // Some other synthesized decl that we don't want to print. |
105 |
| - return Action::SkipNode(); |
| 65 | + std::string getBaseIndent() const { |
| 66 | + SourceManager &SM = IDC->getDecl()->getASTContext().SourceMgr; |
| 67 | + SourceLoc startLoc = getInsertStartLoc(); |
| 68 | + StringRef extraIndent; |
| 69 | + StringRef currentIndent = |
| 70 | + Lexer::getIndentationForLine(SM, startLoc, &extraIndent); |
| 71 | + if (startLoc == getLeftBraceLoc()) { |
| 72 | + return (currentIndent + extraIndent).str(); |
| 73 | + } else { |
| 74 | + return currentIndent.str(); |
106 | 75 | }
|
| 76 | + } |
107 | 77 |
|
108 |
| - Printer.printNewline(); |
| 78 | + void printInsertText(llvm::raw_ostream &OS) const { |
| 79 | + auto &ctx = IDC->getDecl()->getASTContext(); |
109 | 80 |
|
110 |
| - if (auto enumDecl = dyn_cast<EnumDecl>(D)) { |
111 |
| - // Manually print enum here, since we don't want to print synthesized |
112 |
| - // functions. |
113 |
| - Printer << "enum " << enumDecl->getNameStr(); |
114 |
| - PrintOptions Options; |
115 |
| - Options.PrintSpaceBeforeInheritance = false; |
116 |
| - enumDecl->printInherited(Printer, Options); |
117 |
| - Printer << " {"; |
118 |
| - for (Decl *EC : enumDecl->getAllElements()) { |
119 |
| - Printer.printNewline(); |
120 |
| - Printer << " "; |
121 |
| - EC->print(Printer, Options); |
122 |
| - } |
123 |
| - Printer.printNewline(); |
124 |
| - Printer << "}"; |
125 |
| - return Action::SkipNode(); |
126 |
| - } |
127 |
| - |
128 |
| - PrintOptions Options; |
| 81 | + PrintOptions Options = PrintOptions::printDeclarations(); |
129 | 82 | Options.SynthesizeSugarOnTypes = true;
|
130 | 83 | Options.FunctionDefinitions = true;
|
131 | 84 | Options.VarInitializers = true;
|
132 | 85 | Options.PrintExprs = true;
|
133 |
| - Options.TypeDefinitions = true; |
| 86 | + Options.TypeDefinitions = false; |
| 87 | + Options.PrintSpaceBeforeInheritance = false; |
134 | 88 | Options.ExcludeAttrList.push_back(DeclAttrKind::HasInitialValue);
|
| 89 | + Options.PrintInternalAccessKeyword = false; |
135 | 90 |
|
| 91 | + std::string baseIndent = getBaseIndent(); |
| 92 | + ExtraIndentStreamPrinter Printer(OS, baseIndent); |
| 93 | + |
| 94 | + // The insertion starts at the end of the last token. |
136 | 95 | Printer.printNewline();
|
137 |
| - D->print(Printer, Options); |
138 | 96 |
|
139 |
| - return Action::SkipNode(); |
| 97 | + // Synthesized 'CodingKeys' are placed in the main nominal decl. |
| 98 | + // Iterate members and look for synthesized enums that conforms to |
| 99 | + // 'CodingKey' protocol. |
| 100 | + auto *codingKeyProto = ctx.getProtocol(KnownProtocolKind::CodingKey); |
| 101 | + for (auto *member : getNominal()->getMembers()) { |
| 102 | + auto *enumD = dyn_cast<EnumDecl>(member); |
| 103 | + if (!enumD || !enumD->isSynthesized()) |
| 104 | + continue; |
| 105 | + llvm::SmallVector<ProtocolConformance *, 1> codingKeyConformance; |
| 106 | + if (!enumD->lookupConformance(codingKeyProto, codingKeyConformance)) |
| 107 | + continue; |
| 108 | + |
| 109 | + // Print the decl, but without the body. |
| 110 | + Printer.printNewline(); |
| 111 | + enumD->print(Printer, Options); |
| 112 | + |
| 113 | + // Manually print elements because CodingKey enums have some synthesized |
| 114 | + // members for the protocol conformance e.g 'init(intValue:)'. |
| 115 | + // We don't want to print them here. |
| 116 | + Printer << " {"; |
| 117 | + Printer.printNewline(); |
| 118 | + Printer.setIndent(2); |
| 119 | + for (auto *elementD : enumD->getAllElements()) { |
| 120 | + elementD->print(Printer, Options); |
| 121 | + Printer.printNewline(); |
| 122 | + } |
| 123 | + Printer.setIndent(0); |
| 124 | + Printer << "}"; |
| 125 | + Printer.printNewline(); |
| 126 | + } |
| 127 | + |
| 128 | + // Look for synthesized witness decls and print them. |
| 129 | + for (auto *conformance : IDC->getLocalConformances()) { |
| 130 | + auto protocol = conformance->getProtocol(); |
| 131 | + auto kind = protocol->getKnownProtocolKind(); |
| 132 | + if (kind == KnownProtocolKind::Encodable || |
| 133 | + kind == KnownProtocolKind::Decodable) { |
| 134 | + for (auto requirement : protocol->getProtocolRequirements()) { |
| 135 | + auto witness = conformance->getWitnessDecl(requirement); |
| 136 | + if (witness->isSynthesized()) { |
| 137 | + Printer.printNewline(); |
| 138 | + witness->print(Printer, Options); |
| 139 | + Printer.printNewline(); |
| 140 | + } |
| 141 | + } |
| 142 | + } |
| 143 | + } |
140 | 144 | }
|
141 |
| -}; |
142 | 145 |
|
143 |
| -void AddCodableContext::printInsertionText(ResolvedCursorInfoPtr CursorInfo, |
144 |
| - SourceManager &SM, |
145 |
| - llvm::raw_ostream &OS) { |
146 |
| - StringRef ExtraIndent; |
147 |
| - StringRef CurrentIndent = |
148 |
| - Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent); |
149 |
| - std::string Indent; |
150 |
| - if (getInsertStartLoc() == StartLoc) { |
151 |
| - Indent = (CurrentIndent + ExtraIndent).str(); |
152 |
| - } else { |
153 |
| - Indent = CurrentIndent.str(); |
| 146 | +public: |
| 147 | + static AddCodableContext getFromCursorInfo(ResolvedCursorInfoPtr Info); |
| 148 | + |
| 149 | + bool isApplicable() const { |
| 150 | + if (!IDC || !getNominal()) |
| 151 | + return false; |
| 152 | + |
| 153 | + // Check if 'IDC' conforms to 'Encodable' and/or 'Decodable' and any of the |
| 154 | + // requirements are synthesized. |
| 155 | + for (auto *conformance : IDC->getLocalConformances()) { |
| 156 | + auto protocol = conformance->getProtocol(); |
| 157 | + auto kind = protocol->getKnownProtocolKind(); |
| 158 | + if (kind == KnownProtocolKind::Encodable || |
| 159 | + kind == KnownProtocolKind::Decodable) { |
| 160 | + // Check if any of the protocol requirements are synthesized. |
| 161 | + for (auto requirement : protocol->getProtocolRequirements()) { |
| 162 | + auto witness = conformance->getWitnessDecl(requirement); |
| 163 | + if (!witness || witness->isSynthesized()) |
| 164 | + return true; |
| 165 | + } |
| 166 | + } |
| 167 | + } |
| 168 | + return false; |
154 | 169 | }
|
155 | 170 |
|
156 |
| - ExtraIndentStreamPrinter Printer(OS, Indent); |
157 |
| - Printer.printNewline(); |
158 |
| - SynthesizedCodablePrinter Walker(Printer); |
159 |
| - DC->getAsDecl()->walk(Walker); |
160 |
| -} |
| 171 | + void getInsertion(SourceLoc &insertLoc, std::string &insertText) const { |
| 172 | + insertLoc = getInsertStartLoc(); |
| 173 | + llvm::raw_string_ostream OS(insertText); |
| 174 | + printInsertText(OS); |
| 175 | + } |
| 176 | +}; |
161 | 177 |
|
162 | 178 | AddCodableContext
|
163 |
| -AddCodableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) { |
| 179 | +AddCodableContext::getFromCursorInfo(ResolvedCursorInfoPtr Info) { |
164 | 180 | auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
|
165 | 181 | if (!ValueRefInfo) {
|
166 |
| - return AddCodableContext(); |
| 182 | + return nullptr; |
167 | 183 | }
|
| 184 | + |
| 185 | + if (auto *ext = ValueRefInfo->getExtTyRef()) { |
| 186 | + // For 'extension Outer.Inner: Codable {}', only 'Inner' part is valid. |
| 187 | + if (ext->getExtendedNominal() == ValueRefInfo->getValueD()) { |
| 188 | + return AddCodableContext(ext); |
| 189 | + } else { |
| 190 | + return nullptr; |
| 191 | + } |
| 192 | + } |
| 193 | + |
168 | 194 | if (!ValueRefInfo->isRef()) {
|
169 |
| - if (auto *NomDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) { |
170 |
| - return AddCodableContext(NomDecl); |
| 195 | + if (auto *nominal = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) { |
| 196 | + return AddCodableContext(nominal); |
171 | 197 | }
|
172 | 198 | }
|
173 |
| - // TODO: support extensions |
174 |
| - // (would need to get synthesized nodes from the main decl, |
175 |
| - // and only if it's in the same file?) |
176 |
| - return AddCodableContext(); |
| 199 | + |
| 200 | + return nullptr; |
177 | 201 | }
|
178 | 202 | } // namespace
|
179 | 203 |
|
180 | 204 | bool RefactoringActionAddExplicitCodableImplementation::isApplicable(
|
181 | 205 | ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
|
182 |
| - return AddCodableContext::getDeclarationContextFromInfo(Tok).isValid(); |
| 206 | + return AddCodableContext::getFromCursorInfo(Tok).isApplicable(); |
183 | 207 | }
|
184 | 208 |
|
185 | 209 | bool RefactoringActionAddExplicitCodableImplementation::performChange() {
|
186 |
| - auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo); |
| 210 | + auto Context = AddCodableContext::getFromCursorInfo(CursorInfo); |
| 211 | + assert(Context.isApplicable() && |
| 212 | + "Should not run performChange when refactoring is not applicable"); |
187 | 213 |
|
188 |
| - SmallString<64> Buffer; |
189 |
| - llvm::raw_svector_ostream OS(Buffer); |
190 |
| - Context.printInsertionText(CursorInfo, SM, OS); |
| 214 | + SourceLoc insertLoc; |
| 215 | + std::string insertText; |
| 216 | + Context.getInsertion(insertLoc, insertText); |
191 | 217 |
|
192 |
| - EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str()); |
| 218 | + EditConsumer.insertAfter(SM, insertLoc, insertText); |
193 | 219 | return false;
|
194 | 220 | }
|
0 commit comments