@@ -123,18 +123,51 @@ struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
123
123
return isFloatingPoint ? getFactory<LLVM::FMulOp>()
124
124
: getFactory<LLVM::MulOp>();
125
125
}
126
+ if (opName == " and" ) {
127
+ return getFactory<LLVM::AndOp>();
128
+ }
129
+ if (opName == " or" ) {
130
+ return getFactory<LLVM::OrOp>();
131
+ }
132
+ if (opName == " xor" ) {
133
+ return getFactory<LLVM::XOrOp>();
134
+ }
135
+ if (opName == " max" ) {
136
+ return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
137
+ LLVM::FCmpPredicate::ugt>()
138
+ : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
139
+ LLVM::ICmpPredicate::ugt>();
140
+ }
141
+ if (opName == " min" ) {
142
+ return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
143
+ LLVM::FCmpPredicate::ult>()
144
+ : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
145
+ LLVM::ICmpPredicate::ult>();
146
+ }
126
147
127
148
return AccumulatorFactory ();
128
149
}
129
150
130
151
// / Returns an accumulator factory that creates an op of type T.
131
- template <typename T> AccumulatorFactory getFactory () const {
152
+ template <typename T>
153
+ AccumulatorFactory getFactory () const {
132
154
return [](Location loc, Value lhs, Value rhs,
133
155
ConversionPatternRewriter &rewriter) {
134
156
return rewriter.create <T>(loc, lhs.getType (), lhs, rhs);
135
157
};
136
158
}
137
159
160
+ // / Returns an accumulator for comparaison such as min, max. T is the type
161
+ // / of the compare op.
162
+ template <typename T, typename PredicateEnum, PredicateEnum predicate>
163
+ AccumulatorFactory getCmpFactory () const {
164
+ return [](Location loc, Value lhs, Value rhs,
165
+ ConversionPatternRewriter &rewriter) {
166
+ Value cmp = rewriter.create <T>(loc, predicate, lhs, rhs);
167
+ return rewriter.create <LLVM::SelectOp>(loc, cmp, lhs, rhs);
168
+ };
169
+ }
170
+
138
171
// / Creates an all_reduce across the block.
139
172
// /
140
173
// / First reduce the elements within a warp. The first thread of each warp
@@ -705,9 +738,9 @@ void mlir::populateGpuToNVVMConversionPatterns(
705
738
GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
706
739
GPUReturnOpLowering>(converter);
707
740
patterns.insert <OpToFuncCallLowering<AbsFOp>>(converter, " __nv_fabsf" ,
708
- " __nv_fabs" );
741
+ " __nv_fabs" );
709
742
patterns.insert <OpToFuncCallLowering<CeilFOp>>(converter, " __nv_ceilf" ,
710
- " __nv_ceil" );
743
+ " __nv_ceil" );
711
744
patterns.insert <OpToFuncCallLowering<CosOp>>(converter, " __nv_cosf" ,
712
745
" __nv_cos" );
713
746
patterns.insert <OpToFuncCallLowering<ExpOp>>(converter, " __nv_expf" ,
0 commit comments