
Commit 2eff566

Author: Stephan Herhut (authored and committed)
[MLIR] Add and, or, xor, min, max to gpu.all_reduce and the NVVM lowering
Summary: This patch adds built-in operations to the gpu.all_reduce op:
- integer only: `and`, `or`, `xor`
- float and integer: `min`, `max`

This is useful for higher-level dialects like OpenACC or OpenMP that can lower to the GPU dialect.

Differential Revision: https://reviews.llvm.org/D75766
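For illustration, a minimal sketch of how the new built-in reductions would be written in GPU dialect IR (the assembly form follows the gpu.all_reduce documentation updated in this patch; the values %val_f32 and %val_i32 and the surrounding kernel are hypothetical):

```
// Inside a gpu.launch / gpu.func body: floating-point minimum across the workgroup.
%min = "gpu.all_reduce"(%val_f32) ({}) { op = "min" } : (f32) -> (f32)
// Integer-only bitwise reduction across the workgroup.
%mask = "gpu.all_reduce"(%val_i32) ({}) { op = "and" } : (i32) -> (i32)
```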
1 parent 7fb562c commit 2eff566

File tree

13 files changed (+615, -11 lines)

mlir/include/mlir/Dialect/GPU/GPUOps.td

+13-3
@@ -482,15 +482,25 @@ def GPU_YieldOp : GPU_Op<"yield", [Terminator]>,
   }];
 }
 
-// These mirror the XLA ComparisonDirection enum.
+// add, mul mirror the XLA ComparisonDirection enum.
 def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
+def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
+def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
+def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
 def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
+def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
+def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;
 
 def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
     "built-in reduction operations supported by gpu.allreduce.",
     [
       GPU_AllReduceOpAdd,
+      GPU_AllReduceOpAnd,
+      GPU_AllReduceOpMax,
+      GPU_AllReduceOpMin,
       GPU_AllReduceOpMul,
+      GPU_AllReduceOpOr,
+      GPU_AllReduceOpXor
     ]>;
 
 def GPU_AllReduceOp : GPU_Op<"all_reduce",
@@ -514,8 +524,8 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
     ```
     compute the sum of each work item's %0 value. The first version specifies
     the accumulation as operation, whereas the second version specifies the
-    accumulation as code region. The accumulation operation must either be
-    `add` or `mul`.
+    accumulation as code region. The accumulation operation must be one of:
+    `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
 
     Either none or all work items of a workgroup need to execute this op
     in convergence.
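For comparison, the same reduction spelled out with the pre-existing region form; this is roughly what the new `min` built-in saves users from writing (a sketch only: the `cmpf`/`select` spelling is standard-dialect syntax of this era and is assumed, not taken from this diff):

```
%2 = "gpu.all_reduce"(%0) ({
^bb(%lhs : f32, %rhs : f32):
  // Keep the smaller of the two partial results.
  %cmp = cmpf "ult", %lhs, %rhs : f32
  %min = select %cmp, %lhs, %rhs : f32
  "gpu.yield"(%min) : (f32) -> ()
}) : (f32) -> (f32)
```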

mlir/include/mlir/ExecutionEngine/RunnerUtils.h

+2
@@ -211,6 +211,8 @@ _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
 extern "C" MLIR_RUNNERUTILS_EXPORT void
 _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
 
+extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_i32(int64_t rank,
+                                                         void *ptr);
 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_f32(int64_t rank,
                                                          void *ptr);
 

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

+36-3
@@ -123,18 +123,51 @@ struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
       return isFloatingPoint ? getFactory<LLVM::FMulOp>()
                              : getFactory<LLVM::MulOp>();
     }
+    if (opName == "and") {
+      return getFactory<LLVM::AndOp>();
+    }
+    if (opName == "or") {
+      return getFactory<LLVM::OrOp>();
+    }
+    if (opName == "xor") {
+      return getFactory<LLVM::XOrOp>();
+    }
+    if (opName == "max") {
+      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
+                                             LLVM::FCmpPredicate::ugt>()
+                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
+                                             LLVM::ICmpPredicate::ugt>();
+    }
+    if (opName == "min") {
+      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
+                                             LLVM::FCmpPredicate::ult>()
+                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
+                                             LLVM::ICmpPredicate::ult>();
+    }
 
     return AccumulatorFactory();
   }
 
   /// Returns an accumulator factory that creates an op of type T.
-  template <typename T> AccumulatorFactory getFactory() const {
+  template <typename T>
+  AccumulatorFactory getFactory() const {
     return [](Location loc, Value lhs, Value rhs,
               ConversionPatternRewriter &rewriter) {
       return rewriter.create<T>(loc, lhs.getType(), lhs, rhs);
     };
   }
 
+  /// Returns an accumulator factory for comparison-based reductions such as
+  /// min and max. T is the type of the compare op.
+  template <typename T, typename PredicateEnum, PredicateEnum predicate>
+  AccumulatorFactory getCmpFactory() const {
+    return [](Location loc, Value lhs, Value rhs,
+              ConversionPatternRewriter &rewriter) {
+      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
+      return rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+    };
+  }
+
   /// Creates an all_reduce across the block.
   ///
   /// First reduce the elements within a warp. The first thread of each warp
@@ -705,9 +738,9 @@ void mlir::populateGpuToNVVMConversionPatterns(
                   GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
                   GPUReturnOpLowering>(converter);
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
-                                                 "__nv_fabs");
+                                                "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf",
-                                                  "__nv_ceil");
+                                                 "__nv_ceil");
   patterns.insert<OpToFuncCallLowering<CosOp>>(converter, "__nv_cosf",
                                                "__nv_cos");
   patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
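The `getCmpFactory` accumulator above lowers each `max`/`min` reduction step to a compare followed by a select. Roughly, for a floating-point `max` the emitted pair looks like the following (LLVM dialect type and assembly spellings are an approximation for this period, not verified output of the pass):

```
// Sketch of one accumulation step for op = "max" on f32 operands.
%cmp = llvm.fcmp "ugt" %lhs, %rhs : !llvm.float
%max = llvm.select %cmp, %lhs, %rhs : !llvm.i1, !llvm.float
```

Note that the unordered predicates (`ugt`/`ult`) compare true when an operand is NaN, so a NaN input causes the left-hand value to be selected rather than giving IEEE maxnum/minnum semantics.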

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

+8
@@ -148,6 +148,14 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
     }
     if (yieldCount == 0)
       return allReduce.emitError("expected gpu.yield op in region");
+  } else {
+    StringRef opName = *allReduce.op();
+    if ((opName == "and" || opName == "or" || opName == "xor") &&
+        !allReduce.getType().isa<IntegerType>()) {
+      return allReduce.emitError()
+             << '`' << opName << '`'
+             << " accumulator is only compatible with Integer type";
+    }
   }
   return success();
 }
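For example, under this verifier rule a hypothetical module containing the following ops would reject the first and accept the second, since the logical accumulators only make sense on integer types:

```
// Rejected: `xor` accumulator on a floating-point value.
%e = "gpu.all_reduce"(%f) ({}) { op = "xor" } : (f32) -> (f32)
// Accepted: `xor` accumulator on an integer value.
%x = "gpu.all_reduce"(%i) ({}) { op = "xor" } : (i32) -> (i32)
```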

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

+29
@@ -212,6 +212,25 @@ struct GpuAllReduceRewriter {
       return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
     if (opName == "mul")
       return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
+    if (opName == "and") {
+      return getFactory<AndOp>();
+    }
+    if (opName == "or") {
+      return getFactory<OrOp>();
+    }
+    if (opName == "xor") {
+      return getFactory<XOrOp>();
+    }
+    if (opName == "max") {
+      return isFloatingPoint
+                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
+                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
+    }
+    if (opName == "min") {
+      return isFloatingPoint
+                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
+                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
+    }
     return AccumulatorFactory();
   }
 
@@ -222,6 +241,16 @@ struct GpuAllReduceRewriter {
     };
   }
 
+  /// Returns an accumulator factory for comparison-based reductions such as
+  /// min and max. T is the type of the compare op.
+  template <typename T, typename PredicateEnum, PredicateEnum predicate>
+  AccumulatorFactory getCmpFactory() const {
+    return [&](Value lhs, Value rhs) {
+      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
+      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
+    };
+  }
+
   /// Creates an if-block skeleton and calls the two factories to generate the
   /// ops in the `then` and `else` block.
   ///

mlir/lib/ExecutionEngine/RunnerUtils.cpp

+24-5
@@ -27,7 +27,7 @@ extern "C" void _mlir_ciface_print_memref_vector_4x4xf32(
 
 extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
   printUnrankedMemRefMetaData(std::cout, *M);
-  int rank = M->rank;
+  int64_t rank = M->rank;
   void *ptr = M->descriptor;
 
   switch (rank) {
@@ -41,9 +41,25 @@ extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
   }
 }
 
+extern "C" void _mlir_ciface_print_memref_i32(UnrankedMemRefType<int32_t> *M) {
+  printUnrankedMemRefMetaData(std::cout, *M);
+  int64_t rank = M->rank;
+  void *ptr = M->descriptor;
+
+  switch (rank) {
+    MEMREF_CASE(int32_t, 0);
+    MEMREF_CASE(int32_t, 1);
+    MEMREF_CASE(int32_t, 2);
+    MEMREF_CASE(int32_t, 3);
+    MEMREF_CASE(int32_t, 4);
+  default:
+    assert(0 && "Unsupported rank to print");
+  }
+}
+
 extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
   printUnrankedMemRefMetaData(std::cout, *M);
-  int rank = M->rank;
+  int64_t rank = M->rank;
   void *ptr = M->descriptor;
 
   switch (rank) {
@@ -57,10 +73,13 @@ extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
   }
 }
 
+extern "C" void print_memref_i32(int64_t rank, void *ptr) {
+  UnrankedMemRefType<int32_t> descriptor = {rank, ptr};
+  _mlir_ciface_print_memref_i32(&descriptor);
+}
+
 extern "C" void print_memref_f32(int64_t rank, void *ptr) {
-  UnrankedMemRefType<float> descriptor;
-  descriptor.rank = rank;
-  descriptor.descriptor = ptr;
+  UnrankedMemRefType<float> descriptor = {rank, ptr};
   _mlir_ciface_print_memref_f32(&descriptor);
 }
 
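A sketch of how a test could reach the new `print_memref_i32` entry point from MLIR, assuming the same unranked-memref calling convention the other runner utilities use (the standard-dialect op names and the cast below are illustrative, not taken from this commit):

```
// External runner utility; the unranked memref lowers to (rank, pointer).
func @print_memref_i32(memref<*xi32>)

func @dump(%arg : memref<2x2xi32>) {
  // Erase the static shape before handing the buffer to the runtime printer.
  %u = memref_cast %arg : memref<2x2xi32> to memref<*xi32>
  call @print_memref_i32(%u) : (memref<*xi32>) -> ()
  return
}
```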
