Skip to content

Commit cdb468c

Browse files
committed
AMDGPU: Basic folds for fmed3 intrinsic
Constant fold, canonicalize constants to RHS, reduce to minnum/maxnum when inputs are nan/undef. llvm-svn: 296409
1 parent 65da457 commit cdb468c

File tree

5 files changed

+295
-0
lines changed

5 files changed

+295
-0
lines changed

llvm/include/llvm/IR/IRBuilder.h

+16
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,22 @@ class IRBuilderBase {
560560
Type *ResultType,
561561
const Twine &Name = "");
562562

563+
/// Create a call to intrinsic \p ID with 2 operands which is mangled on the
564+
/// first type.
565+
CallInst *CreateBinaryIntrinsic(Intrinsic::ID ID,
566+
Value *LHS, Value *RHS,
567+
const Twine &Name = "");
568+
569+
/// Create call to the minnum intrinsic.
570+
CallInst *CreateMinNum(Value *LHS, Value *RHS, const Twine &Name = "") {
571+
return CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS, Name);
572+
}
573+
574+
/// Create call to the maxnum intrinsic.
575+
CallInst *CreateMaxNum(Value *LHS, Value *RHS, const Twine &Name = "") {
576+
return CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS, Name);
577+
}
578+
563579
private:
564580
/// \brief Create a call to a masked intrinsic with given Id.
565581
CallInst *CreateMaskedIntrinsic(Intrinsic::ID Id, ArrayRef<Value *> Ops,

llvm/include/llvm/IR/PatternMatch.h

+13
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,19 @@ inline match_combine_or<match_zero, match_neg_zero> m_AnyZero() {
157157
return m_CombineOr(m_Zero(), m_NegZero());
158158
}
159159

160+
struct match_nan {
161+
template <typename ITy> bool match(ITy *V) {
162+
if (const auto *C = dyn_cast<ConstantFP>(V)) {
163+
const APFloat &APF = C->getValueAPF();
164+
return APF.isNaN();
165+
}
166+
return false;
167+
}
168+
};
169+
170+
/// Match an arbitrary NaN constant. This includes quiet and signalling nans.
171+
inline match_nan m_NaN() { return match_nan(); }
172+
160173
struct apint_match {
161174
const APInt *&Res;
162175
apint_match(const APInt *&R) : Res(R) {}

llvm/lib/IR/IRBuilder.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -482,3 +482,11 @@ CallInst *IRBuilderBase::CreateGCRelocate(Instruction *Statepoint,
482482
getInt32(DerivedOffset)};
483483
return createCallHelper(FnGCRelocate, Args, this, Name);
484484
}
485+
486+
/// Emit a call to the two-operand intrinsic \p ID, overloaded on the type of
/// \p LHS, passing \p LHS and \p RHS as the call arguments.
CallInst *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS,
                                               Value *RHS, const Twine &Name) {
  // The declaration lives in the module containing BB's parent function.
  Module *ParentModule = BB->getParent()->getParent();
  Type *OverloadTys[] = {LHS->getType()};
  Value *CallArgs[] = {LHS, RHS};
  Function *Callee = Intrinsic::getDeclaration(ParentModule, ID, OverloadTys);
  return createCallHelper(Callee, CallArgs, this, Name);
}

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

+76
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,27 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
15331533
return true;
15341534
}
15351535

1536+
// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs.
1537+
//
1538+
// A single NaN input is folded to minnum, so we rely on that folding for
1539+
// handling NaNs.
1540+
static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1,
1541+
const APFloat &Src2) {
1542+
APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2);
1543+
1544+
APFloat::cmpResult Cmp0 = Max3.compare(Src0);
1545+
assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately");
1546+
if (Cmp0 == APFloat::cmpEqual)
1547+
return maxnum(Src1, Src2);
1548+
1549+
APFloat::cmpResult Cmp1 = Max3.compare(Src1);
1550+
assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately");
1551+
if (Cmp1 == APFloat::cmpEqual)
1552+
return maxnum(Src0, Src2);
1553+
1554+
return maxnum(Src0, Src1);
1555+
}
1556+
15361557
// Returns true iff the 2 intrinsics have the same operands, limiting the
15371558
// comparison to the first NumOperands.
15381559
static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E,
@@ -3331,6 +3352,61 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
33313352
return II;
33323353

33333354
break;
3355+
3356+
}
3357+
case Intrinsic::amdgcn_fmed3: {
3358+
// Note this does not preserve proper sNaN behavior if IEEE-mode is enabled
3359+
// for the shader.
3360+
3361+
Value *Src0 = II->getArgOperand(0);
3362+
Value *Src1 = II->getArgOperand(1);
3363+
Value *Src2 = II->getArgOperand(2);
3364+
3365+
bool Swap = false;
3366+
// Canonicalize constants to RHS operands.
3367+
//
3368+
// fmed3(c0, x, c1) -> fmed3(x, c0, c1)
3369+
if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
3370+
std::swap(Src0, Src1);
3371+
Swap = true;
3372+
}
3373+
3374+
if (isa<Constant>(Src1) && !isa<Constant>(Src2)) {
3375+
std::swap(Src1, Src2);
3376+
Swap = true;
3377+
}
3378+
3379+
if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
3380+
std::swap(Src0, Src1);
3381+
Swap = true;
3382+
}
3383+
3384+
if (Swap) {
3385+
II->setArgOperand(0, Src0);
3386+
II->setArgOperand(1, Src1);
3387+
II->setArgOperand(2, Src2);
3388+
return II;
3389+
}
3390+
3391+
if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) {
3392+
CallInst *NewCall = Builder->CreateMinNum(Src0, Src1);
3393+
NewCall->copyFastMathFlags(II);
3394+
NewCall->takeName(II);
3395+
return replaceInstUsesWith(*II, NewCall);
3396+
}
3397+
3398+
if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
3399+
if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
3400+
if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) {
3401+
APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(),
3402+
C2->getValueAPF());
3403+
return replaceInstUsesWith(*II,
3404+
ConstantFP::get(Builder->getContext(), Result));
3405+
}
3406+
}
3407+
}
3408+
3409+
break;
33343410
}
33353411
case Intrinsic::stackrestore: {
33363412
// If the save is right next to the restore, remove the restore. This can

llvm/test/Transforms/InstCombine/amdgcn-intrinsics.ll

+182
Original file line numberDiff line numberDiff line change
@@ -1025,3 +1025,185 @@ define void @exp_compr_disabled_inputs_to_undef(<2 x half> %xy, <2 x half> %zw)
10251025
call void @llvm.amdgcn.exp.compr.v2f16(i32 0, i32 15, <2 x half> %xy, <2 x half> %zw, i1 true, i1 false)
10261026
ret void
10271027
}
1028+
1029+
; --------------------------------------------------------------------
1030+
; llvm.amdgcn.fmed3
1031+
; --------------------------------------------------------------------
1032+
1033+
declare float @llvm.amdgcn.fmed3.f32(float, float, float) nounwind readnone
1034+
1035+
; CHECK-LABEL: @fmed3_f32(
1036+
; CHECK: %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float %z)
1037+
define float @fmed3_f32(float %x, float %y, float %z) {
1038+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float %z)
1039+
ret float %med3
1040+
}
1041+
1042+
; CHECK-LABEL: @fmed3_canonicalize_x_c0_c1_f32(
1043+
; CHECK: call float @llvm.amdgcn.fmed3.f32(float %x, float 0.000000e+00, float 1.000000e+00)
1044+
define float @fmed3_canonicalize_x_c0_c1_f32(float %x) {
1045+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0.0, float 1.0)
1046+
ret float %med3
1047+
}
1048+
1049+
; CHECK-LABEL: @fmed3_canonicalize_c0_x_c1_f32(
1050+
; CHECK: call float @llvm.amdgcn.fmed3.f32(float %x, float 0.000000e+00, float 1.000000e+00)
1051+
define float @fmed3_canonicalize_c0_x_c1_f32(float %x) {
1052+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0.0, float %x, float 1.0)
1053+
ret float %med3
1054+
}
1055+
1056+
; CHECK-LABEL: @fmed3_canonicalize_c0_c1_x_f32(
1057+
; CHECK: call float @llvm.amdgcn.fmed3.f32(float %x, float 0.000000e+00, float 1.000000e+00)
1058+
define float @fmed3_canonicalize_c0_c1_x_f32(float %x) {
1059+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0.0, float 1.0, float %x)
1060+
ret float %med3
1061+
}
1062+
1063+
; CHECK-LABEL: @fmed3_canonicalize_x_y_c_f32(
1064+
; CHECK: call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 1.000000e+00)
1065+
define float @fmed3_canonicalize_x_y_c_f32(float %x, float %y) {
1066+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 1.0)
1067+
ret float %med3
1068+
}
1069+
1070+
; CHECK-LABEL: @fmed3_canonicalize_x_c_y_f32(
1071+
; CHECK: %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 1.000000e+00)
1072+
define float @fmed3_canonicalize_x_c_y_f32(float %x, float %y) {
1073+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 1.0, float %y)
1074+
ret float %med3
1075+
}
1076+
1077+
; CHECK-LABEL: @fmed3_canonicalize_c_x_y_f32(
1078+
; CHECK: call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 1.000000e+00)
1079+
define float @fmed3_canonicalize_c_x_y_f32(float %x, float %y) {
1080+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 1.0, float %x, float %y)
1081+
ret float %med3
1082+
}
1083+
1084+
; CHECK-LABEL: @fmed3_undef_x_y_f32(
1085+
; CHECK: call float @llvm.minnum.f32(float %x, float %y)
1086+
define float @fmed3_undef_x_y_f32(float %x, float %y) {
1087+
%med3 = call float @llvm.amdgcn.fmed3.f32(float undef, float %x, float %y)
1088+
ret float %med3
1089+
}
1090+
1091+
; CHECK-LABEL: @fmed3_fmf_undef_x_y_f32(
1092+
; CHECK: call nnan float @llvm.minnum.f32(float %x, float %y)
1093+
define float @fmed3_fmf_undef_x_y_f32(float %x, float %y) {
1094+
%med3 = call nnan float @llvm.amdgcn.fmed3.f32(float undef, float %x, float %y)
1095+
ret float %med3
1096+
}
1097+
1098+
; CHECK-LABEL: @fmed3_x_undef_y_f32(
1099+
; CHECK: call float @llvm.minnum.f32(float %x, float %y)
1100+
define float @fmed3_x_undef_y_f32(float %x, float %y) {
1101+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float undef, float %y)
1102+
ret float %med3
1103+
}
1104+
1105+
; CHECK-LABEL: @fmed3_x_y_undef_f32(
1106+
; CHECK: call float @llvm.minnum.f32(float %x, float %y)
1107+
define float @fmed3_x_y_undef_f32(float %x, float %y) {
1108+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float undef)
1109+
ret float %med3
1110+
}
1111+
1112+
; CHECK-LABEL: @fmed3_qnan0_x_y_f32(
1113+
; CHECK: call float @llvm.minnum.f32(float %x, float %y)
1114+
define float @fmed3_qnan0_x_y_f32(float %x, float %y) {
1115+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF8000000000000, float %x, float %y)
1116+
ret float %med3
1117+
}
1118+
1119+
; CHECK-LABEL: @fmed3_x_qnan0_y_f32(
1120+
; CHECK: call float @llvm.minnum.f32(float %x, float %y)
1121+
define float @fmed3_x_qnan0_y_f32(float %x, float %y) {
1122+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0x7FF8000000000000, float %y)
1123+
ret float %med3
1124+
}
1125+
1126+
; CHECK-LABEL: @fmed3_x_y_qnan0_f32(
1127+
; CHECK: call float @llvm.minnum.f32(float %x, float %y)
1128+
define float @fmed3_x_y_qnan0_f32(float %x, float %y) {
1129+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 0x7FF8000000000000)
1130+
ret float %med3
1131+
}
1132+
1133+
; CHECK-LABEL: @fmed3_qnan1_x_y_f32(
1134+
; CHECK: call float @llvm.minnum.f32(float %x, float %y)
1135+
define float @fmed3_qnan1_x_y_f32(float %x, float %y) {
1136+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF8000100000000, float %x, float %y)
1137+
ret float %med3
1138+
}
1139+
1140+
; This can return any of the qnans.
1141+
; CHECK-LABEL: @fmed3_qnan0_qnan1_qnan2_f32(
1142+
; CHECK: ret float 0x7FF8002000000000
1143+
define float @fmed3_qnan0_qnan1_qnan2_f32(float %x, float %y) {
1144+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF8000100000000, float 0x7FF8002000000000, float 0x7FF8030000000000)
1145+
ret float %med3
1146+
}
1147+
1148+
; CHECK-LABEL: @fmed3_constant_src0_0_f32(
1149+
; CHECK: ret float 5.000000e-01
1150+
define float @fmed3_constant_src0_0_f32(float %x, float %y) {
1151+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0.5, float -1.0, float 4.0)
1152+
ret float %med3
1153+
}
1154+
1155+
; CHECK-LABEL: @fmed3_constant_src0_1_f32(
1156+
; CHECK: ret float 5.000000e-01
1157+
define float @fmed3_constant_src0_1_f32(float %x, float %y) {
1158+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0.5, float 4.0, float -1.0)
1159+
ret float %med3
1160+
}
1161+
1162+
; CHECK-LABEL: @fmed3_constant_src1_0_f32(
1163+
; CHECK: ret float 5.000000e-01
1164+
define float @fmed3_constant_src1_0_f32(float %x, float %y) {
1165+
%med3 = call float @llvm.amdgcn.fmed3.f32(float -1.0, float 0.5, float 4.0)
1166+
ret float %med3
1167+
}
1168+
1169+
; CHECK-LABEL: @fmed3_constant_src1_1_f32(
1170+
; CHECK: ret float 5.000000e-01
1171+
define float @fmed3_constant_src1_1_f32(float %x, float %y) {
1172+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 4.0, float 0.5, float -1.0)
1173+
ret float %med3
1174+
}
1175+
1176+
; CHECK-LABEL: @fmed3_constant_src2_0_f32(
1177+
; CHECK: ret float 5.000000e-01
1178+
define float @fmed3_constant_src2_0_f32(float %x, float %y) {
1179+
%med3 = call float @llvm.amdgcn.fmed3.f32(float -1.0, float 4.0, float 0.5)
1180+
ret float %med3
1181+
}
1182+
1183+
; CHECK-LABEL: @fmed3_constant_src2_1_f32(
1184+
; CHECK: ret float 5.000000e-01
1185+
define float @fmed3_constant_src2_1_f32(float %x, float %y) {
1186+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 4.0, float -1.0, float 0.5)
1187+
ret float %med3
1188+
}
1189+
1190+
; CHECK-LABEL: @fmed3_x_qnan0_qnan1_f32(
1191+
; CHECK: ret float %x
1192+
define float @fmed3_x_qnan0_qnan1_f32(float %x) {
1193+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0x7FF8001000000000, float 0x7FF8002000000000)
1194+
ret float %med3
1195+
}
1196+
1197+
; CHECK-LABEL: @fmed3_qnan0_x_qnan1_f32(
1198+
; CHECK: ret float %x
1199+
define float @fmed3_qnan0_x_qnan1_f32(float %x) {
1200+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF8001000000000, float %x, float 0x7FF8002000000000)
1201+
ret float %med3
1202+
}
1203+
1204+
; CHECK-LABEL: @fmed3_qnan0_qnan1_x_f32(
1205+
; CHECK: ret float %x
1206+
define float @fmed3_qnan0_qnan1_x_f32(float %x) {
1207+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF8001000000000, float 0x7FF8002000000000, float %x)
1208+
ret float %med3
1209+
}

0 commit comments

Comments
 (0)