Skip to content

Commit 39606e6

Browse files
committed
[refactoring] Rework "add codable implementation" refactoring
* Support extensions including conditional conformance * Correct access modifiers * More correct lookup for the synthesized declarations * Avoid printing decls in nested types (rdar://98025945)
1 parent 990c870 commit 39606e6

22 files changed

+644
-145
lines changed

Diff for: lib/Refactoring/AddExplicitCodableImplementation.cpp

+155-129
Original file line numberDiff line numberDiff line change
@@ -12,183 +12,209 @@
1212

1313
#include "RefactoringActions.h"
1414
#include "Utils.h"
15+
#include "swift/AST/ProtocolConformance.h"
1516

1617
using namespace swift::refactoring;
1718

1819
namespace {
1920
class AddCodableContext {
2021

2122
/// Declaration context
22-
DeclContext *DC;
23-
24-
/// Start location of declaration context brace
25-
SourceLoc StartLoc;
26-
27-
/// Array of all conformed protocols
28-
SmallVector<swift::ProtocolDecl *, 2> Protocols;
29-
30-
/// Range of internal members in declaration
31-
DeclRange Range;
32-
33-
bool conformsToCodableProtocol() {
34-
for (ProtocolDecl *Protocol : Protocols) {
35-
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Encodable ||
36-
Protocol->getKnownProtocolKind() == KnownProtocolKind::Decodable) {
37-
return true;
38-
}
23+
IterableDeclContext *IDC;
24+
25+
AddCodableContext(NominalTypeDecl *nominal) : IDC(nominal){};
26+
AddCodableContext(ExtensionDecl *extension) : IDC(extension){};
27+
AddCodableContext(std::nullptr_t) : IDC(nullptr){};
28+
29+
const NominalTypeDecl *getNominal() const {
30+
switch (IDC->getIterableContextKind()) {
31+
case IterableDeclContextKind::NominalTypeDecl:
32+
return cast<NominalTypeDecl>(IDC);
33+
case IterableDeclContextKind::ExtensionDecl:
34+
return cast<ExtensionDecl>(IDC)->getExtendedNominal();
3935
}
40-
return false;
36+
assert(false && "unhandled IterableDeclContextKind");
4137
}
4238

43-
public:
44-
AddCodableContext(NominalTypeDecl *Decl)
45-
: DC(Decl), StartLoc(Decl->getBraces().Start),
46-
Protocols(getAllProtocols(Decl)), Range(Decl->getMembers()){};
47-
48-
AddCodableContext(ExtensionDecl *Decl)
49-
: DC(Decl), StartLoc(Decl->getBraces().Start),
50-
Protocols(getAllProtocols(Decl->getExtendedNominal())),
51-
Range(Decl->getMembers()){};
52-
53-
AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr){};
54-
55-
static AddCodableContext
56-
getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info);
57-
58-
void printInsertionText(ResolvedCursorInfoPtr CursorInfo, SourceManager &SM,
59-
llvm::raw_ostream &OS);
60-
61-
bool isValid() { return StartLoc.isValid() && conformsToCodableProtocol(); }
62-
63-
SourceLoc getInsertStartLoc();
64-
};
65-
66-
SourceLoc AddCodableContext::getInsertStartLoc() {
67-
SourceLoc MaxLoc = StartLoc;
68-
for (auto Mem : Range) {
69-
if (Mem->getEndLoc().getOpaquePointerValue() >
70-
MaxLoc.getOpaquePointerValue()) {
71-
MaxLoc = Mem->getEndLoc();
39+
/// Get the left brace location of the type-or-extension decl.
40+
SourceLoc getLeftBraceLoc() const {
41+
switch (IDC->getIterableContextKind()) {
42+
case IterableDeclContextKind::NominalTypeDecl:
43+
return cast<NominalTypeDecl>(IDC)->getBraces().Start;
44+
case IterableDeclContextKind::ExtensionDecl:
45+
return cast<ExtensionDecl>(IDC)->getBraces().Start;
7246
}
47+
assert(false && "unhandled IterableDeclContextKind");
7348
}
74-
return MaxLoc;
75-
}
76-
77-
/// Walks an AST and prints the synthesized Codable implementation.
78-
class SynthesizedCodablePrinter : public ASTWalker {
79-
private:
80-
ASTPrinter &Printer;
8149

82-
public:
83-
SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {}
50+
/// Get the token location where the text should be inserted after.
51+
SourceLoc getInsertStartLoc() const {
52+
// Prefer the end of elements.
53+
for (auto *member : llvm::reverse(IDC->getParsedMembers())) {
54+
if (isa<AccessorDecl>(member) || isa<VarDecl>(member)) {
55+
// These are part of 'PatternBindingDecl' but are hoisted in AST.
56+
continue;
57+
}
58+
return member->getEndLoc();
59+
}
8460

85-
MacroWalking getMacroWalkingBehavior() const override {
86-
return MacroWalking::Arguments;
61+
// After the starting brace if empty.
62+
return getLeftBraceLoc();
8763
}
8864

89-
PreWalkAction walkToDeclPre(Decl *D) override {
90-
auto *VD = dyn_cast<ValueDecl>(D);
91-
if (!VD)
92-
return Action::SkipNode();
93-
94-
if (!VD->isSynthesized()) {
95-
return Action::Continue();
96-
}
97-
SmallString<32> Scratch;
98-
auto name = VD->getName().getString(Scratch);
99-
// Print all synthesized enums,
100-
// since Codable can synthesize multiple enums (for associated values).
101-
auto shouldPrint =
102-
isa<EnumDecl>(VD) || name == "init(from:)" || name == "encode(to:)";
103-
if (!shouldPrint) {
104-
// Some other synthesized decl that we don't want to print.
105-
return Action::SkipNode();
65+
std::string getBaseIndent() const {
66+
SourceManager &SM = IDC->getDecl()->getASTContext().SourceMgr;
67+
SourceLoc startLoc = getInsertStartLoc();
68+
StringRef extraIndent;
69+
StringRef currentIndent =
70+
Lexer::getIndentationForLine(SM, startLoc, &extraIndent);
71+
if (startLoc == getLeftBraceLoc()) {
72+
return (currentIndent + extraIndent).str();
73+
} else {
74+
return currentIndent.str();
10675
}
76+
}
10777

108-
Printer.printNewline();
78+
void printInsertText(llvm::raw_ostream &OS) const {
79+
auto &ctx = IDC->getDecl()->getASTContext();
10980

110-
if (auto enumDecl = dyn_cast<EnumDecl>(D)) {
111-
// Manually print enum here, since we don't want to print synthesized
112-
// functions.
113-
Printer << "enum " << enumDecl->getNameStr();
114-
PrintOptions Options;
115-
Options.PrintSpaceBeforeInheritance = false;
116-
enumDecl->printInherited(Printer, Options);
117-
Printer << " {";
118-
for (Decl *EC : enumDecl->getAllElements()) {
119-
Printer.printNewline();
120-
Printer << " ";
121-
EC->print(Printer, Options);
122-
}
123-
Printer.printNewline();
124-
Printer << "}";
125-
return Action::SkipNode();
126-
}
127-
128-
PrintOptions Options;
81+
PrintOptions Options = PrintOptions::printDeclarations();
12982
Options.SynthesizeSugarOnTypes = true;
13083
Options.FunctionDefinitions = true;
13184
Options.VarInitializers = true;
13285
Options.PrintExprs = true;
133-
Options.TypeDefinitions = true;
86+
Options.TypeDefinitions = false;
87+
Options.PrintSpaceBeforeInheritance = false;
13488
Options.ExcludeAttrList.push_back(DeclAttrKind::HasInitialValue);
89+
Options.PrintInternalAccessKeyword = false;
13590

91+
std::string baseIndent = getBaseIndent();
92+
ExtraIndentStreamPrinter Printer(OS, baseIndent);
93+
94+
// The insertion starts at the end of the last token.
13695
Printer.printNewline();
137-
D->print(Printer, Options);
13896

139-
return Action::SkipNode();
97+
// Synthesized 'CodingKeys' are placed in the main nominal decl.
98+
// Iterate members and look for synthesized enums that conforms to
99+
// 'CodingKey' protocol.
100+
auto *codingKeyProto = ctx.getProtocol(KnownProtocolKind::CodingKey);
101+
for (auto *member : getNominal()->getMembers()) {
102+
auto *enumD = dyn_cast<EnumDecl>(member);
103+
if (!enumD || !enumD->isSynthesized())
104+
continue;
105+
llvm::SmallVector<ProtocolConformance *, 1> codingKeyConformance;
106+
if (!enumD->lookupConformance(codingKeyProto, codingKeyConformance))
107+
continue;
108+
109+
// Print the decl, but without the body.
110+
Printer.printNewline();
111+
enumD->print(Printer, Options);
112+
113+
// Manually print elements because CodingKey enums have some synthesized
114+
// members for the protocol conformance e.g 'init(intValue:)'.
115+
// We don't want to print them here.
116+
Printer << " {";
117+
Printer.printNewline();
118+
Printer.setIndent(2);
119+
for (auto *elementD : enumD->getAllElements()) {
120+
elementD->print(Printer, Options);
121+
Printer.printNewline();
122+
}
123+
Printer.setIndent(0);
124+
Printer << "}";
125+
Printer.printNewline();
126+
}
127+
128+
// Look for synthesized witness decls and print them.
129+
for (auto *conformance : IDC->getLocalConformances()) {
130+
auto protocol = conformance->getProtocol();
131+
auto kind = protocol->getKnownProtocolKind();
132+
if (kind == KnownProtocolKind::Encodable ||
133+
kind == KnownProtocolKind::Decodable) {
134+
for (auto requirement : protocol->getProtocolRequirements()) {
135+
auto witness = conformance->getWitnessDecl(requirement);
136+
if (witness->isSynthesized()) {
137+
Printer.printNewline();
138+
witness->print(Printer, Options);
139+
Printer.printNewline();
140+
}
141+
}
142+
}
143+
}
140144
}
141-
};
142145

143-
void AddCodableContext::printInsertionText(ResolvedCursorInfoPtr CursorInfo,
144-
SourceManager &SM,
145-
llvm::raw_ostream &OS) {
146-
StringRef ExtraIndent;
147-
StringRef CurrentIndent =
148-
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
149-
std::string Indent;
150-
if (getInsertStartLoc() == StartLoc) {
151-
Indent = (CurrentIndent + ExtraIndent).str();
152-
} else {
153-
Indent = CurrentIndent.str();
146+
public:
147+
static AddCodableContext getFromCursorInfo(ResolvedCursorInfoPtr Info);
148+
149+
bool isApplicable() const {
150+
if (!IDC || !getNominal())
151+
return false;
152+
153+
// Check if 'IDC' conforms to 'Encodable' and/or 'Decodable' and any of the
154+
// requirements are synthesized.
155+
for (auto *conformance : IDC->getLocalConformances()) {
156+
auto protocol = conformance->getProtocol();
157+
auto kind = protocol->getKnownProtocolKind();
158+
if (kind == KnownProtocolKind::Encodable ||
159+
kind == KnownProtocolKind::Decodable) {
160+
// Check if any of the protocol requirements are synthesized.
161+
for (auto requirement : protocol->getProtocolRequirements()) {
162+
auto witness = conformance->getWitnessDecl(requirement);
163+
if (!witness || witness->isSynthesized())
164+
return true;
165+
}
166+
}
167+
}
168+
return false;
154169
}
155170

156-
ExtraIndentStreamPrinter Printer(OS, Indent);
157-
Printer.printNewline();
158-
SynthesizedCodablePrinter Walker(Printer);
159-
DC->getAsDecl()->walk(Walker);
160-
}
171+
void getInsertion(SourceLoc &insertLoc, std::string &insertText) const {
172+
insertLoc = getInsertStartLoc();
173+
llvm::raw_string_ostream OS(insertText);
174+
printInsertText(OS);
175+
}
176+
};
161177

162178
AddCodableContext
163-
AddCodableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) {
179+
AddCodableContext::getFromCursorInfo(ResolvedCursorInfoPtr Info) {
164180
auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
165181
if (!ValueRefInfo) {
166-
return AddCodableContext();
182+
return nullptr;
167183
}
184+
185+
if (auto *ext = ValueRefInfo->getExtTyRef()) {
186+
// For 'extension Outer.Inner: Codable {}', only 'Inner' part is valid.
187+
if (ext->getExtendedNominal() == ValueRefInfo->getValueD()) {
188+
return AddCodableContext(ext);
189+
} else {
190+
return nullptr;
191+
}
192+
}
193+
168194
if (!ValueRefInfo->isRef()) {
169-
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) {
170-
return AddCodableContext(NomDecl);
195+
if (auto *nominal = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) {
196+
return AddCodableContext(nominal);
171197
}
172198
}
173-
// TODO: support extensions
174-
// (would need to get synthesized nodes from the main decl,
175-
// and only if it's in the same file?)
176-
return AddCodableContext();
199+
200+
return nullptr;
177201
}
178202
} // namespace
179203

180204
bool RefactoringActionAddExplicitCodableImplementation::isApplicable(
181205
ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
182-
return AddCodableContext::getDeclarationContextFromInfo(Tok).isValid();
206+
return AddCodableContext::getFromCursorInfo(Tok).isApplicable();
183207
}
184208

185209
bool RefactoringActionAddExplicitCodableImplementation::performChange() {
186-
auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo);
210+
auto Context = AddCodableContext::getFromCursorInfo(CursorInfo);
211+
assert(Context.isApplicable() &&
212+
"Should not run performChange when refactoring is not applicable");
187213

188-
SmallString<64> Buffer;
189-
llvm::raw_svector_ostream OS(Buffer);
190-
Context.printInsertionText(CursorInfo, SM, OS);
214+
SourceLoc insertLoc;
215+
std::string insertText;
216+
Context.getInsertion(insertLoc, insertText);
191217

192-
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str());
218+
EditConsumer.insertAfter(SM, insertLoc, insertText);
193219
return false;
194220
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
2+
private struct PrivateS: Codable {
3+
let value: Int
4+
}
5+
6+
7+
public struct PublicS: Codable {
8+
let value: Int
9+
}
10+
11+
12+
open class OpenC: Codable {
13+
let value: Int
14+
15+
private enum CodingKeys: CodingKey {
16+
case value
17+
}
18+
19+
required public init(from decoder: any Decoder) throws {
20+
let container: KeyedDecodingContainer<OpenC.CodingKeys> = try decoder.container(keyedBy: OpenC.CodingKeys.self)
21+
22+
self.value = try container.decode(Int.self, forKey: OpenC.CodingKeys.value)
23+
24+
}
25+
26+
open func encode(to encoder: any Encoder) throws {
27+
var container: KeyedEncodingContainer<OpenC.CodingKeys> = encoder.container(keyedBy: OpenC.CodingKeys.self)
28+
29+
try container.encode(self.value, forKey: OpenC.CodingKeys.value)
30+
}
31+
}
32+

0 commit comments

Comments
 (0)