@@ -35,6 +35,10 @@ static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
35
35
static constexpr const char *kCreateGroup = " mlirAsyncRuntimeCreateGroup" ;
36
36
static constexpr const char *kEmplaceToken = " mlirAsyncRuntimeEmplaceToken" ;
37
37
static constexpr const char *kEmplaceValue = " mlirAsyncRuntimeEmplaceValue" ;
38
+ static constexpr const char *kSetTokenError = " mlirAsyncRuntimeSetTokenError" ;
39
+ static constexpr const char *kSetValueError = " mlirAsyncRuntimeSetValueError" ;
40
+ static constexpr const char *kIsTokenError = " mlirAsyncRuntimeIsTokenError" ;
41
+ static constexpr const char *kIsValueError = " mlirAsyncRuntimeIsValueError" ;
38
42
static constexpr const char *kAwaitToken = " mlirAsyncRuntimeAwaitToken" ;
39
43
static constexpr const char *kAwaitValue = " mlirAsyncRuntimeAwaitValue" ;
40
44
static constexpr const char *kAwaitGroup = " mlirAsyncRuntimeAwaitAllInGroup" ;
@@ -101,6 +105,26 @@ struct AsyncAPI {
101
105
return FunctionType::get (ctx, {value}, {});
102
106
}
103
107
108
+ static FunctionType setTokenErrorFunctionType (MLIRContext *ctx) {
109
+ return FunctionType::get (ctx, {TokenType::get (ctx)}, {});
110
+ }
111
+
112
+ static FunctionType setValueErrorFunctionType (MLIRContext *ctx) {
113
+ auto value = opaquePointerType (ctx);
114
+ return FunctionType::get (ctx, {value}, {});
115
+ }
116
+
117
+ static FunctionType isTokenErrorFunctionType (MLIRContext *ctx) {
118
+ auto i1 = IntegerType::get (ctx, 1 );
119
+ return FunctionType::get (ctx, {TokenType::get (ctx)}, {i1});
120
+ }
121
+
122
+ static FunctionType isValueErrorFunctionType (MLIRContext *ctx) {
123
+ auto value = opaquePointerType (ctx);
124
+ auto i1 = IntegerType::get (ctx, 1 );
125
+ return FunctionType::get (ctx, {value}, {i1});
126
+ }
127
+
104
128
static FunctionType awaitTokenFunctionType (MLIRContext *ctx) {
105
129
return FunctionType::get (ctx, {TokenType::get (ctx)}, {});
106
130
}
@@ -173,6 +197,10 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
173
197
addFuncDecl (kCreateGroup , AsyncAPI::createGroupFunctionType (ctx));
174
198
addFuncDecl (kEmplaceToken , AsyncAPI::emplaceTokenFunctionType (ctx));
175
199
addFuncDecl (kEmplaceValue , AsyncAPI::emplaceValueFunctionType (ctx));
200
+ addFuncDecl (kSetTokenError , AsyncAPI::setTokenErrorFunctionType (ctx));
201
+ addFuncDecl (kSetValueError , AsyncAPI::setValueErrorFunctionType (ctx));
202
+ addFuncDecl (kIsTokenError , AsyncAPI::isTokenErrorFunctionType (ctx));
203
+ addFuncDecl (kIsValueError , AsyncAPI::isValueErrorFunctionType (ctx));
176
204
addFuncDecl (kAwaitToken , AsyncAPI::awaitTokenFunctionType (ctx));
177
205
addFuncDecl (kAwaitValue , AsyncAPI::awaitValueFunctionType (ctx));
178
206
addFuncDecl (kAwaitGroup , AsyncAPI::awaitGroupFunctionType (ctx));
@@ -560,17 +588,53 @@ class RuntimeSetAvailableOpLowering
560
588
matchAndRewrite (RuntimeSetAvailableOp op, ArrayRef<Value> operands,
561
589
ConversionPatternRewriter &rewriter) const override {
562
590
Type operandType = op.operand ().getType ();
591
+ rewriter.replaceOpWithNewOp <CallOp>(
592
+ op, operandType.isa <TokenType>() ? kEmplaceToken : kEmplaceValue ,
593
+ TypeRange (), operands);
594
+ return success ();
595
+ }
596
+ };
597
+ } // namespace
563
598
564
- if (operandType.isa <TokenType>() || operandType.isa <ValueType>()) {
565
- rewriter.create <CallOp>(op->getLoc (),
566
- operandType.isa <TokenType>() ? kEmplaceToken
567
- : kEmplaceValue ,
568
- TypeRange (), operands);
569
- rewriter.eraseOp (op);
570
- return success ();
571
- }
599
+ // ===----------------------------------------------------------------------===//
600
+ // Convert async.runtime.set_error to the corresponding runtime API call.
601
+ // ===----------------------------------------------------------------------===//
572
602
573
- return rewriter.notifyMatchFailure (op, " unsupported async type" );
603
+ namespace {
604
+ class RuntimeSetErrorOpLowering
605
+ : public OpConversionPattern<RuntimeSetErrorOp> {
606
+ public:
607
+ using OpConversionPattern::OpConversionPattern;
608
+
609
+ LogicalResult
610
+ matchAndRewrite (RuntimeSetErrorOp op, ArrayRef<Value> operands,
611
+ ConversionPatternRewriter &rewriter) const override {
612
+ Type operandType = op.operand ().getType ();
613
+ rewriter.replaceOpWithNewOp <CallOp>(
614
+ op, operandType.isa <TokenType>() ? kSetTokenError : kSetValueError ,
615
+ TypeRange (), operands);
616
+ return success ();
617
+ }
618
+ };
619
+ } // namespace
620
+
621
+ // ===----------------------------------------------------------------------===//
622
+ // Convert async.runtime.is_error to the corresponding runtime API call.
623
+ // ===----------------------------------------------------------------------===//
624
+
625
+ namespace {
626
+ class RuntimeIsErrorOpLowering : public OpConversionPattern <RuntimeIsErrorOp> {
627
+ public:
628
+ using OpConversionPattern::OpConversionPattern;
629
+
630
+ LogicalResult
631
+ matchAndRewrite (RuntimeIsErrorOp op, ArrayRef<Value> operands,
632
+ ConversionPatternRewriter &rewriter) const override {
633
+ Type operandType = op.operand ().getType ();
634
+ rewriter.replaceOpWithNewOp <CallOp>(
635
+ op, operandType.isa <TokenType>() ? kIsTokenError : kIsValueError ,
636
+ rewriter.getI1Type (), operands);
637
+ return success ();
574
638
}
575
639
};
576
640
} // namespace
@@ -889,7 +953,8 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
889
953
patterns.add <ReturnOpOpConversion>(converter, ctx);
890
954
891
955
// Lower async.runtime operations to the async runtime API calls.
892
- patterns.add <RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
956
+ patterns.add <RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
957
+ RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
893
958
RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
894
959
RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
895
960
RuntimeDropRefOpLowering>(converter, ctx);
0 commit comments