Skip to content

Commit c30ab6c

Browse files
committed
[mlir] Transform scf.parallel to scf.for + async.execute
Depends On D89958 1. Adds `async.group`/`async.await_all` to group together multiple async tokens/values 2. Rewrites scf.parallel operation into multiple concurrent async.execute operations over non-overlapping subranges of the original loop. Example: ``` scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) { "do_some_compute"(%i, %j): () -> () } ``` Converted to: ``` %c0 = constant 0 : index %c1 = constant 1 : index // Compute block sizes for each induction variable. %num_blocks_i = ... : index %num_blocks_j = ... : index %block_size_i = ... : index %block_size_j = ... : index // Create an async group to track async execute ops. %group = async.create_group scf.for %bi = %c0 to %num_blocks_i step %c1 { %block_start_i = ... : index %block_end_i = ... : index scf.for %bj = %c0 to %num_blocks_j step %c1 { %block_start_j = ... : index %block_end_j = ... : index // Execute the body of the original parallel operation for the current // block. %token = async.execute { scf.for %i = %block_start_i to %block_end_i step %si { scf.for %j = %block_start_j to %block_end_j step %sj { "do_some_compute"(%i, %j): () -> () } } } // Add produced async token to the group. async.add_to_group %token, %group } } // Await completion of all async.execute operations. async.await_all %group ``` In this example the outer loop launches inner block-level loops as separate async execute operations which will be executed concurrently. At the end it waits for the completion of all async execute operations. Reviewed By: ftynse, mehdi_amini Differential Revision: https://reviews.llvm.org/D89963
1 parent 7da0d0a commit c30ab6c

File tree

22 files changed

+1066
-23
lines changed

22 files changed

+1066
-23
lines changed
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
11
add_subdirectory(IR)
2+
3+
set(LLVM_TARGET_DEFINITIONS Passes.td)
4+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Async)
5+
add_public_tablegen_target(MLIRAsyncPassIncGen)
6+
7+
add_mlir_doc(Passes -gen-pass-doc AsyncPasses ./)

mlir/include/mlir/Dialect/Async/IR/Async.h

+6
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ class ValueType
4747
Type getValueType();
4848
};
4949

50+
/// The group type to represent async tokens or values grouped together.
51+
class GroupType : public Type::TypeBase<GroupType, Type, TypeStorage> {
52+
public:
53+
using Base::Base;
54+
};
55+
5056
} // namespace async
5157
} // namespace mlir
5258

mlir/include/mlir/Dialect/Async/IR/AsyncBase.td

+10
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ class Async_ValueType<Type type>
5656
Type valueType = type;
5757
}
5858

59+
def Async_GroupType : DialectType<AsyncDialect,
60+
CPred<"$_self.isa<::mlir::async::GroupType>()">, "group type">,
61+
BuildableType<"$_builder.getType<::mlir::async::GroupType>()"> {
62+
let typeDescription = [{
63+
`async.group` represents a set of async tokens or values and allows
64+
executing async operations on all of them together (e.g. waiting for the
65+
completion of all/any of them).
66+
}];
67+
}
68+
5969
def Async_AnyValueType : DialectType<AsyncDialect,
6070
CPred<"$_self.isa<::mlir::async::ValueType>()">,
6171
"async value type">;

mlir/include/mlir/Dialect/Async/IR/AsyncOps.td

+90-4
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,20 @@ def Async_ExecuteOp :
8181
let printer = [{ return ::print(p, *this); }];
8282
let parser = [{ return ::parse$cppClass(parser, result); }];
8383
let verifier = [{ return ::verify(*this); }];
84+
85+
let skipDefaultBuilders = 1;
86+
let builders = [
87+
OpBuilderDAG<(ins "TypeRange":$resultTypes, "ValueRange":$dependencies,
88+
"ValueRange":$operands,
89+
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
90+
"nullptr">:$bodyBuilder)>,
91+
];
92+
93+
let extraClassDeclaration = [{
94+
using BodyBuilderFn =
95+
function_ref<void(OpBuilder &, Location, ValueRange)>;
96+
97+
}];
8498
}
8599

86100
def Async_YieldOp :
@@ -93,12 +107,12 @@ def Async_YieldOp :
93107

94108
let arguments = (ins Variadic<AnyType>:$operands);
95109

96-
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
110+
let assemblyFormat = "($operands^ `:` type($operands))? attr-dict";
97111

98112
let verifier = [{ return ::verify(*this); }];
99113
}
100114

101-
def Async_AwaitOp : Async_Op<"await", [NoSideEffect]> {
115+
def Async_AwaitOp : Async_Op<"await"> {
102116
let summary = "waits for the argument to become ready";
103117
let description = [{
104118
The `async.await` operation waits until the argument becomes ready, and for
@@ -133,12 +147,84 @@ def Async_AwaitOp : Async_Op<"await", [NoSideEffect]> {
133147
}];
134148

135149
let assemblyFormat = [{
136-
attr-dict $operand `:` custom<AwaitResultType>(
150+
$operand `:` custom<AwaitResultType>(
137151
type($operand), type($result)
138-
)
152+
) attr-dict
139153
}];
140154

141155
let verifier = [{ return ::verify(*this); }];
142156
}
143157

158+
def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
159+
let summary = "creates an empty async group";
160+
let description = [{
161+
The `async.create_group` allocates an empty async group. Async tokens or
162+
values can be added to this group later.
163+
164+
Example:
165+
166+
```mlir
167+
%0 = async.create_group
168+
...
169+
async.await_all %0
170+
```
171+
}];
172+
173+
let arguments = (ins );
174+
let results = (outs Async_GroupType:$result);
175+
176+
let assemblyFormat = "attr-dict";
177+
}
178+
179+
def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
180+
let summary = "adds an async token or value to the group";
181+
let description = [{
182+
The `async.add_to_group` adds an async token or value to the async group.
183+
Returns the rank of the added element in the group. This rank is fixed
184+
for the group lifetime.
185+
186+
Example:
187+
188+
```mlir
189+
%0 = async.create_group
190+
%1 = ... : !async.token
191+
%2 = async.add_to_group %1, %0 : !async.token
192+
```
193+
}];
194+
195+
let arguments = (ins Async_AnyValueOrTokenType:$operand,
196+
Async_GroupType:$group);
197+
let results = (outs Index:$rank);
198+
199+
let assemblyFormat = "$operand `,` $group `:` type($operand) attr-dict";
200+
}
201+
202+
def Async_AwaitAllOp : Async_Op<"await_all", []> {
203+
let summary = "waits for all async tokens or values in the group to "
204+
"become ready";
205+
let description = [{
206+
The `async.await_all` operation waits until all the tokens or values in the
207+
group become ready.
208+
209+
Example:
210+
211+
```mlir
212+
%0 = async.create_group
213+
214+
%1 = ... : !async.token
215+
%2 = async.add_to_group %1, %0 : !async.token
216+
217+
%3 = ... : !async.token
218+
%4 = async.add_to_group %3, %0 : !async.token
219+
220+
async.await_all %0
221+
```
222+
}];
223+
224+
let arguments = (ins Async_GroupType:$operand);
225+
let results = (outs);
226+
227+
let assemblyFormat = "$operand attr-dict";
228+
}
229+
144230
#endif // ASYNC_OPS
+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===- Passes.h - Async pass entry points -----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines prototypes that expose pass constructors.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_ASYNC_PASSES_H_
14+
#define MLIR_DIALECT_ASYNC_PASSES_H_
15+
16+
#include "mlir/Pass/Pass.h"
17+
18+
namespace mlir {
19+
20+
std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
21+
22+
//===----------------------------------------------------------------------===//
23+
// Registration
24+
//===----------------------------------------------------------------------===//
25+
26+
/// Generate the code for registering passes.
27+
#define GEN_PASS_REGISTRATION
28+
#include "mlir/Dialect/Async/Passes.h.inc"
29+
30+
} // namespace mlir
31+
32+
#endif // MLIR_DIALECT_ASYNC_PASSES_H_
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===-- Passes.td - Async pass definition file -------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ASYNC_PASSES
10+
#define MLIR_DIALECT_ASYNC_PASSES
11+
12+
include "mlir/Pass/PassBase.td"
13+
14+
def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
15+
let summary = "Convert scf.parallel operations to multiple async regions "
16+
"executed concurrently for non-overlapping iteration ranges";
17+
let constructor = "mlir::createAsyncParallelForPass()";
18+
let options = [
19+
Option<"numConcurrentAsyncExecute", "num-concurrent-async-execute",
20+
"int32_t", /*default=*/"4",
21+
"The number of async.execute operations that will be used for concurrent "
22+
"loop execution.">
23+
];
24+
let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
25+
}
26+
27+
#endif // MLIR_DIALECT_ASYNC_PASSES

mlir/include/mlir/ExecutionEngine/AsyncRuntime.h

+20
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#ifndef MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
1515
#define MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
1616

17+
#include <stdint.h>
18+
1719
#ifdef _WIN32
1820
#ifndef MLIR_ASYNCRUNTIME_EXPORT
1921
#ifdef mlir_async_runtime_EXPORTS
@@ -37,6 +39,9 @@
3739
// Runtime implementation of `async.token` data type.
3840
typedef struct AsyncToken MLIR_AsyncToken;
3941

42+
// Runtime implementation of `async.group` data type.
43+
typedef struct AsyncGroup MLIR_AsyncGroup;
44+
4045
// Async runtime uses LLVM coroutines to represent asynchronous tasks. Task
4146
// function is a coroutine handle and a resume function that continue coroutine
4247
// execution from a suspension point.
@@ -46,6 +51,12 @@ using CoroResume = void (*)(void *); // coroutine resume function
4651
// Create a new `async.token` in not-ready state.
4752
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
4853

54+
// Create a new `async.group` in empty state.
55+
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup();
56+
57+
extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
58+
mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
59+
4960
// Switches `async.token` to ready state and runs all awaiters.
5061
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
5162
mlirAsyncRuntimeEmplaceToken(AsyncToken *);
@@ -54,6 +65,10 @@ mlirAsyncRuntimeEmplaceToken(AsyncToken *);
5465
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
5566
mlirAsyncRuntimeAwaitToken(AsyncToken *);
5667

68+
// Blocks the caller thread until the elements in the group become ready.
69+
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
70+
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *);
71+
5772
// Executes the task (coro handle + resume function) in one of the threads
5873
// managed by the runtime.
5974
extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
@@ -64,6 +79,11 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
6479
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
6580
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume);
6681

82+
// Executes the task (coro handle + resume function) in one of the threads
83+
// managed by the runtime after all members of the group become ready.
84+
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
85+
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume);
86+
6787
//===----------------------------------------------------------------------===//
6888
// Small async runtime support library for testing.
6989
//===----------------------------------------------------------------------===//

mlir/include/mlir/InitAllPasses.h

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Conversion/Passes.h"
1818
#include "mlir/Dialect/Affine/Passes.h"
19+
#include "mlir/Dialect/Async/Passes.h"
1920
#include "mlir/Dialect/GPU/Passes.h"
2021
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
2122
#include "mlir/Dialect/Linalg/Passes.h"
@@ -47,6 +48,7 @@ inline void registerAllPasses() {
4748

4849
// Dialect passes
4950
registerAffinePasses();
51+
registerAsyncPasses();
5052
registerGPUPasses();
5153
registerLinalgPasses();
5254
LLVM::registerLLVMPasses();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import sys
2+
3+
# No JIT on win32.
4+
if sys.platform == 'win32':
5+
config.unsupported = True
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: mlir-opt %s -async-parallel-for \
2+
// RUN: -convert-async-to-llvm \
3+
// RUN: -convert-scf-to-std \
4+
// RUN: -convert-std-to-llvm \
5+
// RUN: | mlir-cpu-runner \
6+
// RUN: -e entry -entry-point-result=void -O0 \
7+
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
8+
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
9+
// RUN: | FileCheck %s --dump-input=always
10+
11+
func @entry() {
12+
%c0 = constant 0.0 : f32
13+
%c1 = constant 1 : index
14+
%c2 = constant 2 : index
15+
%c3 = constant 3 : index
16+
17+
%lb = constant 0 : index
18+
%ub = constant 9 : index
19+
20+
%A = alloc() : memref<9xf32>
21+
%U = memref_cast %A : memref<9xf32> to memref<*xf32>
22+
23+
// 1. %i = (0) to (9) step (1)
24+
scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
25+
%0 = index_cast %i : index to i32
26+
%1 = sitofp %0 : i32 to f32
27+
store %1, %A[%i] : memref<9xf32>
28+
}
29+
// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8]
30+
call @print_memref_f32(%U): (memref<*xf32>) -> ()
31+
32+
scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
33+
store %c0, %A[%i] : memref<9xf32>
34+
}
35+
36+
// 2. %i = (0) to (9) step (2)
37+
scf.parallel (%i) = (%lb) to (%ub) step (%c2) {
38+
%0 = index_cast %i : index to i32
39+
%1 = sitofp %0 : i32 to f32
40+
store %1, %A[%i] : memref<9xf32>
41+
}
42+
// CHECK: [0, 0, 2, 0, 4, 0, 6, 0, 8]
43+
call @print_memref_f32(%U): (memref<*xf32>) -> ()
44+
45+
scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
46+
store %c0, %A[%i] : memref<9xf32>
47+
}
48+
49+
// 3. %i = (-20) to (-11) step (3)
50+
%lb0 = constant -20 : index
51+
%ub0 = constant -11 : index
52+
scf.parallel (%i) = (%lb0) to (%ub0) step (%c3) {
53+
%0 = index_cast %i : index to i32
54+
%1 = sitofp %0 : i32 to f32
55+
%2 = constant 20 : index
56+
%3 = addi %i, %2 : index
57+
store %1, %A[%3] : memref<9xf32>
58+
}
59+
// CHECK: [-20, 0, 0, -17, 0, 0, -14, 0, 0]
60+
call @print_memref_f32(%U): (memref<*xf32>) -> ()
61+
62+
dealloc %A : memref<9xf32>
63+
return
64+
}
65+
66+
func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }

0 commit comments

Comments
 (0)