Skip to content

Commit 8d92241

Browse files
committed
[Refactoring] Support async for function extraction
Adapt the `ThrowingEntityAnalyzer` to pick up any `await` keywords and add an `async` to the extracted function if necessary along with an `await` for its call. rdar://72199949
1 parent 1372797 commit 8d92241

File tree

8 files changed

+105
-33
lines changed

8 files changed

+105
-33
lines changed

include/swift/IDE/Utils.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "swift/Basic/LLVM.h"
1818
#include "swift/AST/ASTNode.h"
1919
#include "swift/AST/DeclNameLoc.h"
20+
#include "swift/AST/Effects.h"
2021
#include "swift/AST/Module.h"
2122
#include "swift/AST/ASTPrinter.h"
2223
#include "swift/IDE/SourceEntityWalker.h"
@@ -345,7 +346,7 @@ struct ResolvedRangeInfo {
345346
ArrayRef<Token> TokensInRange;
346347
CharSourceRange ContentRange;
347348
bool HasSingleEntry;
348-
bool ThrowingUnhandledError;
349+
PossibleEffects UnhandledEffects;
349350
OrphanKind Orphan;
350351

351352
// The topmost ast nodes contained in the given range.
@@ -359,15 +360,15 @@ struct ResolvedRangeInfo {
359360
ArrayRef<Token> TokensInRange,
360361
DeclContext* RangeContext,
361362
Expr *CommonExprParent, bool HasSingleEntry,
362-
bool ThrowingUnhandledError,
363+
PossibleEffects UnhandledEffects,
363364
OrphanKind Orphan, ArrayRef<ASTNode> ContainedNodes,
364365
ArrayRef<DeclaredDecl> DeclaredDecls,
365366
ArrayRef<ReferencedDecl> ReferencedDecls): Kind(Kind),
366367
ExitInfo(ExitInfo),
367368
TokensInRange(TokensInRange),
368369
ContentRange(calculateContentRange(TokensInRange)),
369370
HasSingleEntry(HasSingleEntry),
370-
ThrowingUnhandledError(ThrowingUnhandledError),
371+
UnhandledEffects(UnhandledEffects),
371372
Orphan(Orphan), ContainedNodes(ContainedNodes),
372373
DeclaredDecls(DeclaredDecls),
373374
ReferencedDecls(ReferencedDecls),
@@ -376,7 +377,7 @@ struct ResolvedRangeInfo {
376377
ResolvedRangeInfo(ArrayRef<Token> TokensInRange) :
377378
ResolvedRangeInfo(RangeKind::Invalid, {nullptr, ExitState::Unsure},
378379
TokensInRange, nullptr, /*Commom Expr Parent*/nullptr,
379-
/*Single entry*/true, /*unhandled error*/false,
380+
/*Single entry*/true, /*UnhandledEffects*/{},
380381
OrphanKind::None, {}, {}, {}) {}
381382
ResolvedRangeInfo(): ResolvedRangeInfo(ArrayRef<Token>()) {}
382383
void print(llvm::raw_ostream &OS) const;

lib/IDE/IDERequests.cpp

+28-27
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "swift/AST/ASTPrinter.h"
1414
#include "swift/AST/Decl.h"
15+
#include "swift/AST/Effects.h"
1516
#include "swift/AST/NameLookup.h"
1617
#include "swift/AST/ASTDemangler.h"
1718
#include "swift/Basic/SourceManager.h"
@@ -377,45 +378,45 @@ class RangeResolver : public SourceEntityWalker {
377378
ResolvedRangeInfo resolve();
378379
};
379380

380-
static bool hasUnhandledError(ArrayRef<ASTNode> Nodes) {
381-
class ThrowingEntityAnalyzer : public SourceEntityWalker {
382-
bool Throwing;
381+
static PossibleEffects getUnhandledEffects(ArrayRef<ASTNode> Nodes) {
382+
class EffectsAnalyzer : public SourceEntityWalker {
383+
PossibleEffects Effects;
383384
public:
384-
ThrowingEntityAnalyzer(): Throwing(false) {}
385385
bool walkToStmtPre(Stmt *S) override {
386386
if (auto DCS = dyn_cast<DoCatchStmt>(S)) {
387387
if (DCS->isSyntacticallyExhaustive())
388388
return false;
389-
Throwing = true;
389+
Effects |= EffectKind::Throws;
390390
} else if (isa<ThrowStmt>(S)) {
391-
Throwing = true;
391+
Effects |= EffectKind::Throws;
392392
}
393-
return !Throwing;
393+
return true;
394394
}
395395
bool walkToExprPre(Expr *E) override {
396396
// Don't walk into closures, they only produce effects when called.
397397
if (isa<ClosureExpr>(E))
398398
return false;
399-
400-
if (isa<TryExpr>(E)) {
401-
Throwing = true;
402-
}
403-
return !Throwing;
399+
400+
if (isa<TryExpr>(E))
401+
Effects |= EffectKind::Throws;
402+
if (isa<AwaitExpr>(E))
403+
Effects |= EffectKind::Async;
404+
405+
return true;
404406
}
405407
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
406408
return false;
407409
}
408-
bool walkToDeclPost(Decl *D) override { return !Throwing; }
409-
bool walkToStmtPost(Stmt *S) override { return !Throwing; }
410-
bool walkToExprPost(Expr *E) override { return !Throwing; }
411-
bool isThrowing() { return Throwing; }
410+
PossibleEffects getEffects() const { return Effects; }
412411
};
413412

414-
return Nodes.end() != std::find_if(Nodes.begin(), Nodes.end(), [](ASTNode N) {
415-
ThrowingEntityAnalyzer Analyzer;
413+
PossibleEffects Effects;
414+
for (auto N : Nodes) {
415+
EffectsAnalyzer Analyzer;
416416
Analyzer.walk(N);
417-
return Analyzer.isThrowing();
418-
});
417+
Effects |= Analyzer.getEffects();
418+
}
419+
return Effects;
419420
}
420421

421422
struct RangeResolver::Implementation {
@@ -553,7 +554,7 @@ struct RangeResolver::Implementation {
553554
assert(ContainedASTNodes.size() == 1);
554555
// Single node implies single entry point, or is it?
555556
bool SingleEntry = true;
556-
bool UnhandledError = hasUnhandledError({Node});
557+
auto UnhandledEffects = getUnhandledEffects({Node});
557558
OrphanKind Kind = getOrphanKind(ContainedASTNodes);
558559
if (Node.is<Expr*>())
559560
return ResolvedRangeInfo(RangeKind::SingleExpression,
@@ -562,7 +563,7 @@ struct RangeResolver::Implementation {
562563
getImmediateContext(),
563564
/*Common Parent Expr*/nullptr,
564565
SingleEntry,
565-
UnhandledError, Kind,
566+
UnhandledEffects, Kind,
566567
llvm::makeArrayRef(ContainedASTNodes),
567568
llvm::makeArrayRef(DeclaredDecls),
568569
llvm::makeArrayRef(ReferencedDecls));
@@ -573,7 +574,7 @@ struct RangeResolver::Implementation {
573574
getImmediateContext(),
574575
/*Common Parent Expr*/nullptr,
575576
SingleEntry,
576-
UnhandledError, Kind,
577+
UnhandledEffects, Kind,
577578
llvm::makeArrayRef(ContainedASTNodes),
578579
llvm::makeArrayRef(DeclaredDecls),
579580
llvm::makeArrayRef(ReferencedDecls));
@@ -585,7 +586,7 @@ struct RangeResolver::Implementation {
585586
getImmediateContext(),
586587
/*Common Parent Expr*/nullptr,
587588
SingleEntry,
588-
UnhandledError, Kind,
589+
UnhandledEffects, Kind,
589590
llvm::makeArrayRef(ContainedASTNodes),
590591
llvm::makeArrayRef(DeclaredDecls),
591592
llvm::makeArrayRef(ReferencedDecls));
@@ -646,7 +647,7 @@ struct RangeResolver::Implementation {
646647
getImmediateContext(),
647648
Parent,
648649
hasSingleEntryPoint(ContainedASTNodes),
649-
hasUnhandledError(ContainedASTNodes),
650+
getUnhandledEffects(ContainedASTNodes),
650651
getOrphanKind(ContainedASTNodes),
651652
llvm::makeArrayRef(ContainedASTNodes),
652653
llvm::makeArrayRef(DeclaredDecls),
@@ -893,7 +894,7 @@ struct RangeResolver::Implementation {
893894
TokensInRange,
894895
getImmediateContext(), nullptr,
895896
hasSingleEntryPoint(ContainedASTNodes),
896-
hasUnhandledError(ContainedASTNodes),
897+
getUnhandledEffects(ContainedASTNodes),
897898
getOrphanKind(ContainedASTNodes),
898899
llvm::makeArrayRef(ContainedASTNodes),
899900
llvm::makeArrayRef(DeclaredDecls),
@@ -908,7 +909,7 @@ struct RangeResolver::Implementation {
908909
getImmediateContext(),
909910
/*Common Parent Expr*/ nullptr,
910911
/*SinleEntry*/ true,
911-
hasUnhandledError(ContainedASTNodes),
912+
getUnhandledEffects(ContainedASTNodes),
912913
getOrphanKind(ContainedASTNodes),
913914
llvm::makeArrayRef(ContainedASTNodes),
914915
llvm::makeArrayRef(DeclaredDecls),

lib/IDE/Refactoring.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,9 @@ bool RefactoringActionExtractFunction::performChange() {
13041304
}
13051305
OS << ")";
13061306

1307-
if (RangeInfo.ThrowingUnhandledError)
1307+
if (RangeInfo.UnhandledEffects.contains(EffectKind::Async))
1308+
OS << " async";
1309+
if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws))
13081310
OS << " " << tok::kw_throws;
13091311

13101312
bool InsertedReturnType = false;
@@ -1335,6 +1337,8 @@ bool RefactoringActionExtractFunction::performChange() {
13351337

13361338
if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws))
13371339
OS << tok::kw_try << " ";
1340+
if (RangeInfo.UnhandledEffects.contains(EffectKind::Async))
1341+
OS << "await ";
13381342

13391343
CallNameOffset = Buffer.size() - ReplaceBegin;
13401344
OS << PreferredName << "(";

lib/IDE/SwiftSourceDocInfo.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -724,9 +724,12 @@ void ResolvedRangeInfo::print(llvm::raw_ostream &OS) const {
724724
OS << "<Entry>Multi</Entry>\n";
725725
}
726726

727-
if (ThrowingUnhandledError) {
727+
if (UnhandledEffects.contains(EffectKind::Throws)) {
728728
OS << "<Error>Throwing</Error>\n";
729729
}
730+
if (UnhandledEffects.contains(EffectKind::Async)) {
731+
OS << "<Effect>Async</Effect>\n";
732+
}
730733

731734
if (Orphan != OrphanKind::None) {
732735
OS << "<Orphan>";
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
func longLongLongJourney() async -> Int { 0 }
2+
func longLongLongAwryJourney() async throws -> Int { 0 }
3+
func consumesAsync(_ fn: () async throws -> Void) rethrows {}
4+
5+
fileprivate func new_name() async -> Int {
6+
return await longLongLongJourney()
7+
}
8+
9+
func testThrowingClosure() async throws -> Int {
10+
let x = await new_name()
11+
let y = try await longLongLongAwryJourney() + 1
12+
try consumesAsync { try await longLongLongAwryJourney() }
13+
return x + y
14+
}
15+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
func longLongLongJourney() async -> Int { 0 }
2+
func longLongLongAwryJourney() async throws -> Int { 0 }
3+
func consumesAsync(_ fn: () async throws -> Void) rethrows {}
4+
5+
fileprivate func new_name() async throws -> Int {
6+
return try await longLongLongAwryJourney() + 1
7+
}
8+
9+
func testThrowingClosure() async throws -> Int {
10+
let x = await longLongLongJourney()
11+
let y = try await new_name()
12+
try consumesAsync { try await longLongLongAwryJourney() }
13+
return x + y
14+
}
15+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
func longLongLongJourney() async -> Int { 0 }
2+
func longLongLongAwryJourney() async throws -> Int { 0 }
3+
func consumesAsync(_ fn: () async throws -> Void) rethrows {}
4+
5+
fileprivate func new_name() throws {
6+
try consumesAsync { try await longLongLongAwryJourney() }
7+
}
8+
9+
func testThrowingClosure() async throws -> Int {
10+
let x = await longLongLongJourney()
11+
let y = try await longLongLongAwryJourney() + 1
12+
try new_name()
13+
return x + y
14+
}
15+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
func longLongLongJourney() async -> Int { 0 }
2+
func longLongLongAwryJourney() async throws -> Int { 0 }
3+
func consumesAsync(_ fn: () async throws -> Void) rethrows {}
4+
5+
func testThrowingClosure() async throws -> Int {
6+
let x = await longLongLongJourney()
7+
let y = try await longLongLongAwryJourney() + 1
8+
try consumesAsync { try await longLongLongAwryJourney() }
9+
return x + y
10+
}
11+
12+
// RUN: %empty-directory(%t.result)
13+
// RUN: %refactor -extract-function -source-filename %s -pos=6:11 -end-pos=6:38 >> %t.result/async1.swift
14+
// RUN: diff -u %S/Outputs/await/async1.swift.expected %t.result/async1.swift
15+
// RUN: %refactor -extract-function -source-filename %s -pos=7:11 -end-pos=7:50 >> %t.result/async2.swift
16+
// RUN: diff -u %S/Outputs/await/async2.swift.expected %t.result/async2.swift
17+
// RUN: %refactor -extract-function -source-filename %s -pos=8:1 -end-pos=8:60 >> %t.result/consumes_async.swift
18+
// RUN: diff -u %S/Outputs/await/consumes_async.swift.expected %t.result/consumes_async.swift

0 commit comments

Comments
 (0)