Skip to content

Commit 3dd4df2

Browse files
committed
[Typed throws] Location based lookup for the thrown error type
Introduce a new API to find the AST node that catches or rethrows an error thrown from the given source location. Use it to determine the thrown error type to use for type checking a `throw` statement, which begins as `any Error` within a `do..catch` and is later refined.
1 parent 7e9013d commit 3dd4df2

File tree

9 files changed

+226
-9
lines changed

9 files changed

+226
-9
lines changed

include/swift/AST/ASTScope.h

+19
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#define SWIFT_AST_AST_SCOPE_H
3030

3131
#include "swift/AST/ASTNode.h"
32+
#include "swift/AST/CatchNode.h"
3233
#include "swift/AST/NameLookup.h"
3334
#include "swift/AST/SimpleRequest.h"
3435
#include "swift/Basic/Compiler.h"
@@ -85,6 +86,7 @@ class SILGenFunction;
8586

8687
namespace ast_scope {
8788
class ASTScopeImpl;
89+
class BraceStmtScope;
8890
class GenericTypeOrExtensionScope;
8991
class IterableTypeScope;
9092
class TypeAliasScope;
@@ -211,6 +213,7 @@ class ASTScopeImpl : public ASTAllocated<ASTScopeImpl> {
211213
#pragma mark common queries
212214
public:
213215
virtual NullablePtr<AbstractClosureExpr> getClosureIfClosureScope() const;
216+
virtual NullablePtr<const BraceStmtScope> getAsBraceStmtScope() const;
214217
virtual ASTContext &getASTContext() const;
215218
virtual NullablePtr<Decl> getDeclIfAny() const { return nullptr; };
216219
virtual NullablePtr<Stmt> getStmtIfAny() const { return nullptr; };
@@ -287,10 +290,18 @@ class ASTScopeImpl : public ASTAllocated<ASTScopeImpl> {
287290
SourceFile *sourceFile, SourceLoc loc,
288291
llvm::function_ref<bool(ASTScope::PotentialMacro)> consume);
289292

293+
static CatchNode lookupCatchNode(ModuleDecl *module, SourceLoc loc);
294+
290295
/// Scopes that cannot bind variables may set this to true to create more
291296
/// compact scope tree in the debug info.
292297
virtual bool ignoreInDebugInfo() const { return false; }
293298

299+
/// If this scope node represents a potential catch node, return body the
300+
/// AST node describing the catch (a function, closure, or do...catch) and
301+
/// the node of it's "body", i.e., the brace statement from which errors
302+
/// thrown will be caught by that node.
303+
virtual std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const;
304+
294305
#pragma mark - - lookup- starting point
295306
private:
296307
static const ASTScopeImpl *findStartingScopeForLookup(SourceFile *,
@@ -824,6 +835,8 @@ class FunctionBodyScope : public ASTScopeImpl {
824835
Decl *getDecl() const { return decl; }
825836
bool ignoreInDebugInfo() const override { return true; }
826837

838+
std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;
839+
827840
protected:
828841
bool lookupLocalsOrMembers(DeclConsumer) const override;
829842

@@ -1069,6 +1082,8 @@ class ClosureParametersScope final : public ASTScopeImpl {
10691082
NullablePtr<AbstractClosureExpr> getClosureIfClosureScope() const override {
10701083
return closureExpr;
10711084
}
1085+
std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;
1086+
10721087
NullablePtr<Expr> getExprIfAny() const override { return closureExpr; }
10731088
Expr *getExpr() const { return closureExpr; }
10741089
bool ignoreInDebugInfo() const override { return true; }
@@ -1440,6 +1455,8 @@ class DoCatchStmtScope final : public AbstractStmtScope {
14401455
void expandAScopeThatDoesNotCreateANewInsertionPoint(ScopeCreator &);
14411456

14421457
public:
1458+
std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;
1459+
14431460
std::string getClassName() const override;
14441461
Stmt *getStmt() const override { return stmt; }
14451462
};
@@ -1648,6 +1665,8 @@ class BraceStmtScope final : public AbstractStmtScope {
16481665
NullablePtr<AbstractClosureExpr> parentClosureIfAny() const; // public??
16491666
Stmt *getStmt() const override { return stmt; }
16501667

1668+
NullablePtr<const BraceStmtScope> getAsBraceStmtScope() const override;
1669+
16511670
protected:
16521671
bool lookupLocalsOrMembers(DeclConsumer) const override;
16531672
};

include/swift/AST/CatchNode.h

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//===--- CatchNode.h - An AST node that catches errors -----------*- C++-*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_AST_CATCHNODE_H
14+
#define SWIFT_AST_CATCHNODE_H
15+
16+
#include "llvm/ADT/Optional.h"
17+
#include "llvm/ADT/PointerUnion.h"
18+
#include "swift/AST/Decl.h"
19+
#include "swift/AST/Expr.h"
20+
#include "swift/AST/Stmt.h"
21+
22+
namespace swift {
23+
24+
/// An AST node that represents a point where a thrown error can be caught and
25+
/// or rethrown, which includes functions do...catch statements.
26+
class CatchNode: public llvm::PointerUnion<
27+
AbstractFunctionDecl *, AbstractClosureExpr *, DoCatchStmt *
28+
> {
29+
public:
30+
using PointerUnion::PointerUnion;
31+
32+
/// Determine the thrown error type within the region of this catch node
33+
/// where it will catch (and possibly rethrow) errors. All of the errors
34+
/// thrown from within that region will be converted to this error type.
35+
///
36+
/// Returns the thrown error type for a throwing context, or \c llvm::None
37+
/// if this is a non-throwing context.
38+
llvm::Optional<Type> getThrownErrorTypeInContext(ASTContext &ctx) const;
39+
};
40+
41+
} // end namespace swift
42+
43+
#endif // SWIFT_AST_CATCHNODE_H

include/swift/AST/NameLookup.h

+20
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#define SWIFT_AST_NAME_LOOKUP_H
1919

2020
#include "swift/AST/ASTVisitor.h"
21+
#include "swift/AST/CatchNode.h"
2122
#include "swift/AST/GenericSignature.h"
2223
#include "swift/AST/Identifier.h"
2324
#include "swift/AST/Module.h"
@@ -833,6 +834,25 @@ class ASTScope : public ASTAllocated<ASTScope> {
833834
SourceFile *sourceFile, SourceLoc loc,
834835
llvm::function_ref<bool(PotentialMacro macro)> consume);
835836

837+
/// Look up the scope tree for the nearest point at which an error thrown from
838+
/// this location can be caught or rethrown.
839+
///
840+
/// For example, given this code:
841+
///
842+
/// \code
843+
/// func f() throws {
844+
/// do {
845+
/// try g() // A
846+
/// } catch {
847+
/// throw ErrorWrapper(error) // B
848+
/// }
849+
/// }
850+
/// \endcode
851+
///
852+
/// At the point marked A, the catch node is the enclosing do...catch
853+
/// statement. At the point marked B, the catch node is the function itself.
854+
static CatchNode lookupCatchNode(ModuleDecl *module, SourceLoc loc);
855+
836856
SWIFT_DEBUG_DUMP;
837857
void print(llvm::raw_ostream &) const;
838858
void dumpOneScopeMapLocation(std::pair<unsigned, unsigned>);

lib/AST/ASTScope.cpp

+59
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ void ASTScope::lookupEnclosingMacroScope(
6666
return ASTScopeImpl::lookupEnclosingMacroScope(sourceFile, loc, body);
6767
}
6868

69+
CatchNode ASTScope::lookupCatchNode(ModuleDecl *module, SourceLoc loc) {
70+
return ASTScopeImpl::lookupCatchNode(module, loc);
71+
}
72+
6973
#if SWIFT_COMPILER_IS_MSVC
7074
#pragma warning(push)
7175
#pragma warning(disable : 4996)
@@ -97,10 +101,65 @@ NullablePtr<AbstractClosureExpr> BraceStmtScope::parentClosureIfAny() const {
97101
return !getParent() ? nullptr : getParent().get()->getClosureIfClosureScope();
98102
}
99103

104+
NullablePtr<const BraceStmtScope> BraceStmtScope::getAsBraceStmtScope() const {
105+
return this;
106+
}
107+
100108
NullablePtr<AbstractClosureExpr> ASTScopeImpl::getClosureIfClosureScope() const {
101109
return nullptr;
102110
}
103111

112+
NullablePtr<const BraceStmtScope> ASTScopeImpl::getAsBraceStmtScope() const {
113+
return nullptr;
114+
}
115+
116+
std::pair<CatchNode, const BraceStmtScope *>
117+
ASTScopeImpl::getCatchNodeBody() const {
118+
return { nullptr, nullptr };
119+
}
120+
121+
std::pair<CatchNode, const BraceStmtScope *>
122+
ClosureParametersScope::getCatchNodeBody() const {
123+
const BraceStmtScope *body = nullptr;
124+
const auto &children = getChildren();
125+
if (!children.empty()) {
126+
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
127+
assert(body && "Not a brace statement?");
128+
}
129+
return { const_cast<AbstractClosureExpr *>(closureExpr), body };
130+
}
131+
132+
std::pair<CatchNode, const BraceStmtScope *>
133+
FunctionBodyScope::getCatchNodeBody() const {
134+
const BraceStmtScope *body = nullptr;
135+
const auto &children = getChildren();
136+
if (!children.empty()) {
137+
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
138+
assert(body && "Not a brace statement?");
139+
}
140+
return { const_cast<AbstractFunctionDecl *>(decl), body };
141+
}
142+
143+
/// Determine whether this is an empty brace statement, which doesn't have a
144+
/// node associated with it.
145+
static bool isEmptyBraceStmt(Stmt *stmt) {
146+
if (auto braceStmt = dyn_cast_or_null<BraceStmt>(stmt))
147+
return braceStmt->empty();
148+
149+
return false;
150+
}
151+
152+
std::pair<CatchNode, const BraceStmtScope *>
153+
DoCatchStmtScope::getCatchNodeBody() const {
154+
const BraceStmtScope *body = nullptr;
155+
const auto &children = getChildren();
156+
if (!children.empty() && !isEmptyBraceStmt(stmt->getBody())) {
157+
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
158+
assert(body && "Not a brace statement?");
159+
}
160+
return { const_cast<DoCatchStmt *>(stmt), body };
161+
}
162+
104163
SourceManager &ASTScopeImpl::getSourceManager() const {
105164
return getASTContext().SourceMgr;
106165
}

lib/AST/ASTScopeLookup.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -712,3 +712,30 @@ void ASTScopeImpl::lookupEnclosingMacroScope(
712712

713713
} while ((scope = scope->getParent().getPtrOrNull()));
714714
}
715+
716+
CatchNode ASTScopeImpl::lookupCatchNode(ModuleDecl *module, SourceLoc loc) {
717+
auto sourceFile = module->getSourceFileContainingLocation(loc);
718+
if (!sourceFile)
719+
return nullptr;
720+
721+
auto *fileScope = sourceFile->getScope().impl;
722+
const auto *innermost = fileScope->findInnermostEnclosingScope(
723+
module, loc, nullptr);
724+
ASTScopeAssert(innermost->getWasExpanded(),
725+
"If looking in a scope, it must have been expanded.");
726+
727+
// Look for a body scope that's the
728+
const BraceStmtScope *innerBodyScope = nullptr;
729+
for (auto scope = innermost; scope; scope = scope->getParent().getPtrOrNull()) {
730+
// If we are at a catch node and in the body of the region from which that
731+
// node catches thrown errors, we have our result.
732+
auto caught = scope->getCatchNodeBody();
733+
if (caught.first && caught.second == innerBodyScope) {
734+
return caught.first;
735+
}
736+
737+
innerBodyScope = scope->getAsBraceStmtScope().getPtrOrNull();
738+
}
739+
740+
return nullptr;
741+
}

lib/AST/Decl.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -11398,3 +11398,32 @@ MacroDiscriminatorContext::getParentOf(FreestandingMacroExpansion *expansion) {
1139811398
return getParentOf(
1139911399
expansion->getPoundLoc(), expansion->getDeclContext());
1140011400
}
11401+
11402+
llvm::Optional<Type>
11403+
CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
11404+
if (auto func = dyn_cast<AbstractFunctionDecl *>()) {
11405+
if (auto thrownError = func->getEffectiveThrownErrorType())
11406+
return func->mapTypeIntoContext(*thrownError);
11407+
11408+
return llvm::None;
11409+
}
11410+
11411+
if (auto closure = dyn_cast<AbstractClosureExpr *>()) {
11412+
if (closure->getType())
11413+
return closure->getEffectiveThrownType();
11414+
11415+
// FIXME: Should we lazily compute this?
11416+
return llvm::None;
11417+
}
11418+
11419+
auto doCatch = get<DoCatchStmt *>();
11420+
if (auto thrownError = doCatch->getCaughtErrorType()) {
11421+
if (thrownError->isNever())
11422+
return llvm::None;
11423+
11424+
return thrownError;
11425+
}
11426+
11427+
// If we haven't computed the error type yet, do so now.
11428+
return ctx.getErrorExistentialType();
11429+
}

lib/AST/Stmt.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -473,12 +473,15 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
473473
}
474474

475475
Type DoCatchStmt::getCaughtErrorType() const {
476-
return getCatches()
476+
auto firstPattern = getCatches()
477477
.front()
478478
->getCaseLabelItems()
479479
.front()
480-
.getPattern()
481-
->getType();
480+
.getPattern();
481+
if (firstPattern->hasType())
482+
return firstPattern->getType();
483+
484+
return Type();
482485
}
483486

484487
void LabeledConditionalStmt::setCond(StmtCondition e) {

lib/Sema/TypeCheckStmt.cpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -1195,11 +1195,16 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
11951195
// Coerce the operand to the exception type.
11961196
auto E = TS->getSubExpr();
11971197

1198+
// Look up the catch node for this "throw" to determine the error type.
1199+
CatchNode catchNode = ASTScope::lookupCatchNode(
1200+
DC->getParentModule(), TS->getThrowLoc());
11981201
Type errorType;
1199-
if (auto TheFunc = AnyFunctionRef::fromDeclContext(DC)) {
1200-
errorType = TheFunc->getThrownErrorType();
1202+
if (catchNode) {
1203+
errorType = catchNode.getThrownErrorTypeInContext(getASTContext())
1204+
.value_or(Type());
12011205
}
12021206

1207+
// If there was no error type, use 'any Error'. We'll check it later.
12031208
if (!errorType) {
12041209
errorType = getASTContext().getErrorExistentialType();
12051210
}

test/stmt/typed_throws.swift

+16-4
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ func testDoCatchMultiErrorType() {
7474
try doSomething()
7575
try doHomework()
7676
} catch .failed { // expected-error{{type 'any Error' has no member 'failed'}}
77-
77+
7878
} catch {
7979
let _: Int = error // expected-error{{cannot convert value of type 'any Error' to specified type 'Int'}}
80-
}
80+
}
8181
}
8282

8383
func testDoCatchRethrowsUntyped() throws {
@@ -96,7 +96,7 @@ func testDoCatchRethrowsTyped() throws(HomeworkError) {
9696
do {
9797
try doSomething()
9898
} catch .failed {
99-
99+
100100
} // expected-error{{thrown expression type 'MyError' cannot be converted to error type 'HomeworkError'}}
101101

102102
do {
@@ -114,8 +114,20 @@ func testDoCatchRethrowsTyped() throws(HomeworkError) {
114114
} // okay, the thrown 'any Error' has been caught
115115
}
116116

117-
func testTryIncompatibleTyped() throws(HomeworkError) {
117+
func testTryIncompatibleTyped(cond: Bool) throws(HomeworkError) {
118118
try doHomework() // okay
119119

120120
try doSomething() // FIXME: should error
121+
122+
do {
123+
if cond {
124+
throw .dogAteIt // expected-error{{type 'any Error' has no member 'dogAteIt'}}
125+
} else {
126+
try doSomething()
127+
}
128+
} catch let error as Never {
129+
// expected-warning@-1{{'catch' block is unreachable because no errors are thrown in 'do' block}}
130+
// expected-warning@-2{{'as' test is always true}}
131+
throw .forgot
132+
}
121133
}

0 commit comments

Comments
 (0)