Skip to content

Commit 39957aa

Browse files
committedMay 27, 2021
[mlir] Add error state and error propagation to async runtime values
Depends On D103102 Not yet implemented: 1. Error handling after synchronous await 2. Error handling for async groups Will be addressed in the followup PRs Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D103109
1 parent 4fbc66c commit 39957aa

File tree

9 files changed

+523
-45
lines changed

9 files changed

+523
-45
lines changed
 

‎mlir/include/mlir/Dialect/Async/IR/AsyncOps.td

+30-4
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def Async_RuntimeCreateOp : Async_Op<"runtime.create"> {
343343
}
344344

345345
def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
346-
let summary = "switches token or value available state";
346+
let summary = "switches token or value to available state";
347347
let description = [{
348348
The `async.runtime.set_available` operation switches async token or value
349349
state to available.
@@ -353,11 +353,37 @@ def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
353353
let assemblyFormat = "$operand attr-dict `:` type($operand)";
354354
}
355355

356+
def Async_RuntimeSetErrorOp : Async_Op<"runtime.set_error"> {
357+
let summary = "switches token or value to error state";
358+
let description = [{
359+
The `async.runtime.set_error` operation switches async token or value
360+
state to error.
361+
}];
362+
363+
let arguments = (ins Async_AnyValueOrTokenType:$operand);
364+
let assemblyFormat = "$operand attr-dict `:` type($operand)";
365+
}
366+
367+
def Async_RuntimeIsErrorOp : Async_Op<"runtime.is_error"> {
368+
let summary = "returns true if token, value or group is in error state";
369+
let description = [{
370+
The `async.runtime.is_error` operation returns true if the token, value or
371+
group (any of the async runtime values) is in the error state. It is the
372+
caller responsibility to check error state after the call to `await` or
373+
resuming after `await_and_resume`.
374+
}];
375+
376+
let arguments = (ins Async_AnyValueOrTokenType:$operand);
377+
let results = (outs I1:$is_error);
378+
379+
let assemblyFormat = "$operand attr-dict `:` type($operand)";
380+
}
381+
356382
def Async_RuntimeAwaitOp : Async_Op<"runtime.await"> {
357383
let summary = "blocks the caller thread until the operand becomes available";
358384
let description = [{
359385
The `async.runtime.await` operation blocks the caller thread until the
360-
operand becomes available.
386+
operand becomes available or error.
361387
}];
362388

363389
let arguments = (ins Async_AnyAsyncType:$operand);
@@ -379,8 +405,8 @@ def Async_RuntimeAwaitAndResumeOp : Async_Op<"runtime.await_and_resume"> {
379405
let summary = "awaits the async operand and resumes the coroutine";
380406
let description = [{
381407
The `async.runtime.await_and_resume` operation awaits for the operand to
382-
become available and resumes the coroutine on a thread managed by the
383-
runtime.
408+
become available or error and resumes the coroutine on a thread managed by
409+
the runtime.
384410
}];
385411

386412
let arguments = (ins Async_AnyAsyncType:$operand,

‎mlir/include/mlir/ExecutionEngine/AsyncRuntime.h

+12
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *);
7676
// Switches `async.value` to ready state and runs all awaiters.
7777
extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *);
7878

79+
// Switches `async.token` to error state and runs all awaiters.
80+
extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *);
81+
82+
// Switches `async.value` to error state and runs all awaiters.
83+
extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *);
84+
85+
// Returns true if token is in the error state.
86+
extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *);
87+
88+
// Returns true if value is in the error state.
89+
extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *);
90+
7991
// Blocks the caller thread until the token becomes ready.
8092
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *);
8193

‎mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

+75-10
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
3535
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
3636
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
3737
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
38+
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
39+
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
40+
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
41+
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
3842
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
3943
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
4044
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
@@ -101,6 +105,26 @@ struct AsyncAPI {
101105
return FunctionType::get(ctx, {value}, {});
102106
}
103107

108+
static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
109+
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
110+
}
111+
112+
static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
113+
auto value = opaquePointerType(ctx);
114+
return FunctionType::get(ctx, {value}, {});
115+
}
116+
117+
static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
118+
auto i1 = IntegerType::get(ctx, 1);
119+
return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
120+
}
121+
122+
static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
123+
auto value = opaquePointerType(ctx);
124+
auto i1 = IntegerType::get(ctx, 1);
125+
return FunctionType::get(ctx, {value}, {i1});
126+
}
127+
104128
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
105129
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
106130
}
@@ -173,6 +197,10 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
173197
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
174198
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
175199
addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
200+
addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
201+
addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
202+
addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
203+
addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
176204
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
177205
addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
178206
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
@@ -560,17 +588,53 @@ class RuntimeSetAvailableOpLowering
560588
matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
561589
ConversionPatternRewriter &rewriter) const override {
562590
Type operandType = op.operand().getType();
591+
rewriter.replaceOpWithNewOp<CallOp>(
592+
op, operandType.isa<TokenType>() ? kEmplaceToken : kEmplaceValue,
593+
TypeRange(), operands);
594+
return success();
595+
}
596+
};
597+
} // namespace
563598

564-
if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) {
565-
rewriter.create<CallOp>(op->getLoc(),
566-
operandType.isa<TokenType>() ? kEmplaceToken
567-
: kEmplaceValue,
568-
TypeRange(), operands);
569-
rewriter.eraseOp(op);
570-
return success();
571-
}
599+
//===----------------------------------------------------------------------===//
600+
// Convert async.runtime.set_error to the corresponding runtime API call.
601+
//===----------------------------------------------------------------------===//
572602

573-
return rewriter.notifyMatchFailure(op, "unsupported async type");
603+
namespace {
604+
class RuntimeSetErrorOpLowering
605+
: public OpConversionPattern<RuntimeSetErrorOp> {
606+
public:
607+
using OpConversionPattern::OpConversionPattern;
608+
609+
LogicalResult
610+
matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
611+
ConversionPatternRewriter &rewriter) const override {
612+
Type operandType = op.operand().getType();
613+
rewriter.replaceOpWithNewOp<CallOp>(
614+
op, operandType.isa<TokenType>() ? kSetTokenError : kSetValueError,
615+
TypeRange(), operands);
616+
return success();
617+
}
618+
};
619+
} // namespace
620+
621+
//===----------------------------------------------------------------------===//
622+
// Convert async.runtime.is_error to the corresponding runtime API call.
623+
//===----------------------------------------------------------------------===//
624+
625+
namespace {
626+
class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
627+
public:
628+
using OpConversionPattern::OpConversionPattern;
629+
630+
LogicalResult
631+
matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
632+
ConversionPatternRewriter &rewriter) const override {
633+
Type operandType = op.operand().getType();
634+
rewriter.replaceOpWithNewOp<CallOp>(
635+
op, operandType.isa<TokenType>() ? kIsTokenError : kIsValueError,
636+
rewriter.getI1Type(), operands);
637+
return success();
574638
}
575639
};
576640
} // namespace
@@ -889,7 +953,8 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
889953
patterns.add<ReturnOpOpConversion>(converter, ctx);
890954

891955
// Lower async.runtime operations to the async runtime API calls.
892-
patterns.add<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
956+
patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
957+
RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
893958
RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
894959
RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
895960
RuntimeDropRefOpLowering>(converter, ctx);

0 commit comments

Comments
 (0)