Skip to content

Commit 1f49b71

Browse files
[SVE][CodeGen] Enable reciprocal estimates for scalable fdiv/fsqrt
This patch enables the use of reciprocal estimates for SVE when both the -Ofast and -mrecip flags are used.

Reviewed By: david-arm, paulwalker-arm

Differential Revision: https://reviews.llvm.org/D111657
1 parent 5fd55b1 commit 1f49b71

File tree

3 files changed

+201
-8
lines changed

3 files changed

+201
-8
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4130,6 +4130,18 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
41304130
case Intrinsic::aarch64_sve_frecpx:
41314131
return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(),
41324132
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
4133+
case Intrinsic::aarch64_sve_frecpe_x:
4134+
return DAG.getNode(AArch64ISD::FRECPE, dl, Op.getValueType(),
4135+
Op.getOperand(1));
4136+
case Intrinsic::aarch64_sve_frecps_x:
4137+
return DAG.getNode(AArch64ISD::FRECPS, dl, Op.getValueType(),
4138+
Op.getOperand(1), Op.getOperand(2));
4139+
case Intrinsic::aarch64_sve_frsqrte_x:
4140+
return DAG.getNode(AArch64ISD::FRSQRTE, dl, Op.getValueType(),
4141+
Op.getOperand(1));
4142+
case Intrinsic::aarch64_sve_frsqrts_x:
4143+
return DAG.getNode(AArch64ISD::FRSQRTS, dl, Op.getValueType(),
4144+
Op.getOperand(1), Op.getOperand(2));
41334145
case Intrinsic::aarch64_sve_fabs:
41344146
return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(),
41354147
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -8235,10 +8247,12 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
82358247
SDValue Operand, SelectionDAG &DAG,
82368248
int &ExtraSteps) {
82378249
EVT VT = Operand.getValueType();
8238-
if (ST->hasNEON() &&
8239-
(VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
8240-
VT == MVT::f32 || VT == MVT::v1f32 ||
8241-
VT == MVT::v2f32 || VT == MVT::v4f32)) {
8250+
if ((ST->hasNEON() &&
8251+
(VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
8252+
VT == MVT::f32 || VT == MVT::v1f32 || VT == MVT::v2f32 ||
8253+
VT == MVT::v4f32)) ||
8254+
(ST->hasSVE() &&
8255+
(VT == MVT::nxv8f16 || VT == MVT::nxv4f32 || VT == MVT::nxv2f64))) {
82428256
if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified)
82438257
// For the reciprocal estimates, convergence is quadratic, so the number
82448258
// of digits is doubled after each iteration. In ARMv8, the accuracy of

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,8 @@ let Predicates = [HasSVEorStreamingSVE] in {
402402
defm SMIN_ZPZZ : sve_int_bin_pred_bhsd<AArch64smin_p>;
403403
defm UMIN_ZPZZ : sve_int_bin_pred_bhsd<AArch64umin_p>;
404404

405-
defm FRECPE_ZZ : sve_fp_2op_u_zd<0b110, "frecpe", int_aarch64_sve_frecpe_x>;
406-
defm FRSQRTE_ZZ : sve_fp_2op_u_zd<0b111, "frsqrte", int_aarch64_sve_frsqrte_x>;
405+
defm FRECPE_ZZ : sve_fp_2op_u_zd<0b110, "frecpe", AArch64frecpe>;
406+
defm FRSQRTE_ZZ : sve_fp_2op_u_zd<0b111, "frsqrte", AArch64frsqrte>;
407407

408408
defm FADD_ZPmI : sve_fp_2op_i_p_zds<0b000, "fadd", "FADD_ZPZI", sve_fpimm_half_one, fpimm_half, fpimm_one, int_aarch64_sve_fadd>;
409409
defm FSUB_ZPmI : sve_fp_2op_i_p_zds<0b001, "fsub", "FSUB_ZPZI", sve_fpimm_half_one, fpimm_half, fpimm_one, int_aarch64_sve_fsub>;
@@ -484,8 +484,8 @@ let Predicates = [HasSVE] in {
484484
} // End HasSVE
485485

486486
let Predicates = [HasSVEorStreamingSVE] in {
487-
defm FRECPS_ZZZ : sve_fp_3op_u_zd<0b110, "frecps", int_aarch64_sve_frecps_x>;
488-
defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts", int_aarch64_sve_frsqrts_x>;
487+
defm FRECPS_ZZZ : sve_fp_3op_u_zd<0b110, "frecps", AArch64frecps>;
488+
defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts", AArch64frsqrts>;
489489
} // End HasSVEorStreamingSVE
490490

491491
let Predicates = [HasSVE] in {
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
3+
4+
; FDIV
5+
6+
define <vscale x 8 x half> @fdiv_8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
7+
; CHECK-LABEL: fdiv_8f16:
8+
; CHECK: // %bb.0:
9+
; CHECK-NEXT: ptrue p0.h
10+
; CHECK-NEXT: fdiv z0.h, p0/m, z0.h, z1.h
11+
; CHECK-NEXT: ret
12+
%fdiv = fdiv fast <vscale x 8 x half> %a, %b
13+
ret <vscale x 8 x half> %fdiv
14+
}
15+
16+
define <vscale x 8 x half> @fdiv_recip_8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) #0 {
17+
; CHECK-LABEL: fdiv_recip_8f16:
18+
; CHECK: // %bb.0:
19+
; CHECK-NEXT: frecpe z2.h, z1.h
20+
; CHECK-NEXT: frecps z3.h, z1.h, z2.h
21+
; CHECK-NEXT: fmul z2.h, z2.h, z3.h
22+
; CHECK-NEXT: frecps z1.h, z1.h, z2.h
23+
; CHECK-NEXT: fmul z1.h, z2.h, z1.h
24+
; CHECK-NEXT: fmul z0.h, z1.h, z0.h
25+
; CHECK-NEXT: ret
26+
%fdiv = fdiv fast <vscale x 8 x half> %a, %b
27+
ret <vscale x 8 x half> %fdiv
28+
}
29+
30+
define <vscale x 4 x float> @fdiv_4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
31+
; CHECK-LABEL: fdiv_4f32:
32+
; CHECK: // %bb.0:
33+
; CHECK-NEXT: ptrue p0.s
34+
; CHECK-NEXT: fdiv z0.s, p0/m, z0.s, z1.s
35+
; CHECK-NEXT: ret
36+
%fdiv = fdiv fast <vscale x 4 x float> %a, %b
37+
ret <vscale x 4 x float> %fdiv
38+
}
39+
40+
define <vscale x 4 x float> @fdiv_recip_4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
41+
; CHECK-LABEL: fdiv_recip_4f32:
42+
; CHECK: // %bb.0:
43+
; CHECK-NEXT: frecpe z2.s, z1.s
44+
; CHECK-NEXT: frecps z3.s, z1.s, z2.s
45+
; CHECK-NEXT: fmul z2.s, z2.s, z3.s
46+
; CHECK-NEXT: frecps z1.s, z1.s, z2.s
47+
; CHECK-NEXT: fmul z1.s, z2.s, z1.s
48+
; CHECK-NEXT: fmul z0.s, z1.s, z0.s
49+
; CHECK-NEXT: ret
50+
%fdiv = fdiv fast <vscale x 4 x float> %a, %b
51+
ret <vscale x 4 x float> %fdiv
52+
}
53+
54+
define <vscale x 2 x double> @fdiv_2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
55+
; CHECK-LABEL: fdiv_2f64:
56+
; CHECK: // %bb.0:
57+
; CHECK-NEXT: ptrue p0.d
58+
; CHECK-NEXT: fdiv z0.d, p0/m, z0.d, z1.d
59+
; CHECK-NEXT: ret
60+
%fdiv = fdiv fast <vscale x 2 x double> %a, %b
61+
ret <vscale x 2 x double> %fdiv
62+
}
63+
64+
define <vscale x 2 x double> @fdiv_recip_2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) #0 {
65+
; CHECK-LABEL: fdiv_recip_2f64:
66+
; CHECK: // %bb.0:
67+
; CHECK-NEXT: frecpe z2.d, z1.d
68+
; CHECK-NEXT: frecps z3.d, z1.d, z2.d
69+
; CHECK-NEXT: fmul z2.d, z2.d, z3.d
70+
; CHECK-NEXT: frecps z3.d, z1.d, z2.d
71+
; CHECK-NEXT: fmul z2.d, z2.d, z3.d
72+
; CHECK-NEXT: frecps z1.d, z1.d, z2.d
73+
; CHECK-NEXT: fmul z1.d, z2.d, z1.d
74+
; CHECK-NEXT: fmul z0.d, z1.d, z0.d
75+
; CHECK-NEXT: ret
76+
%fdiv = fdiv fast <vscale x 2 x double> %a, %b
77+
ret <vscale x 2 x double> %fdiv
78+
}
79+
80+
; FSQRT
81+
82+
define <vscale x 8 x half> @fsqrt_8f16(<vscale x 8 x half> %a) {
83+
; CHECK-LABEL: fsqrt_8f16:
84+
; CHECK: // %bb.0:
85+
; CHECK-NEXT: ptrue p0.h
86+
; CHECK-NEXT: fsqrt z0.h, p0/m, z0.h
87+
; CHECK-NEXT: ret
88+
%fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
89+
ret <vscale x 8 x half> %fsqrt
90+
}
91+
92+
define <vscale x 8 x half> @fsqrt_recip_8f16(<vscale x 8 x half> %a) #0 {
93+
; CHECK-LABEL: fsqrt_recip_8f16:
94+
; CHECK: // %bb.0:
95+
; CHECK-NEXT: frsqrte z1.h, z0.h
96+
; CHECK-NEXT: ptrue p0.h
97+
; CHECK-NEXT: fmul z2.h, z1.h, z1.h
98+
; CHECK-NEXT: fcmeq p0.h, p0/z, z0.h, #0.0
99+
; CHECK-NEXT: frsqrts z2.h, z0.h, z2.h
100+
; CHECK-NEXT: fmul z1.h, z1.h, z2.h
101+
; CHECK-NEXT: fmul z2.h, z1.h, z1.h
102+
; CHECK-NEXT: frsqrts z2.h, z0.h, z2.h
103+
; CHECK-NEXT: fmul z1.h, z1.h, z2.h
104+
; CHECK-NEXT: fmul z1.h, z0.h, z1.h
105+
; CHECK-NEXT: sel z0.h, p0, z0.h, z1.h
106+
; CHECK-NEXT: ret
107+
%fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
108+
ret <vscale x 8 x half> %fsqrt
109+
}
110+
111+
define <vscale x 4 x float> @fsqrt_4f32(<vscale x 4 x float> %a) {
112+
; CHECK-LABEL: fsqrt_4f32:
113+
; CHECK: // %bb.0:
114+
; CHECK-NEXT: ptrue p0.s
115+
; CHECK-NEXT: fsqrt z0.s, p0/m, z0.s
116+
; CHECK-NEXT: ret
117+
%fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
118+
ret <vscale x 4 x float> %fsqrt
119+
}
120+
121+
define <vscale x 4 x float> @fsqrt_recip_4f32(<vscale x 4 x float> %a) #0 {
122+
; CHECK-LABEL: fsqrt_recip_4f32:
123+
; CHECK: // %bb.0:
124+
; CHECK-NEXT: frsqrte z1.s, z0.s
125+
; CHECK-NEXT: ptrue p0.s
126+
; CHECK-NEXT: fmul z2.s, z1.s, z1.s
127+
; CHECK-NEXT: fcmeq p0.s, p0/z, z0.s, #0.0
128+
; CHECK-NEXT: frsqrts z2.s, z0.s, z2.s
129+
; CHECK-NEXT: fmul z1.s, z1.s, z2.s
130+
; CHECK-NEXT: fmul z2.s, z1.s, z1.s
131+
; CHECK-NEXT: frsqrts z2.s, z0.s, z2.s
132+
; CHECK-NEXT: fmul z1.s, z1.s, z2.s
133+
; CHECK-NEXT: fmul z1.s, z0.s, z1.s
134+
; CHECK-NEXT: sel z0.s, p0, z0.s, z1.s
135+
; CHECK-NEXT: ret
136+
%fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
137+
ret <vscale x 4 x float> %fsqrt
138+
}
139+
140+
define <vscale x 2 x double> @fsqrt_2f64(<vscale x 2 x double> %a) {
141+
; CHECK-LABEL: fsqrt_2f64:
142+
; CHECK: // %bb.0:
143+
; CHECK-NEXT: ptrue p0.d
144+
; CHECK-NEXT: fsqrt z0.d, p0/m, z0.d
145+
; CHECK-NEXT: ret
146+
%fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
147+
ret <vscale x 2 x double> %fsqrt
148+
}
149+
150+
define <vscale x 2 x double> @fsqrt_recip_2f64(<vscale x 2 x double> %a) #0 {
151+
; CHECK-LABEL: fsqrt_recip_2f64:
152+
; CHECK: // %bb.0:
153+
; CHECK-NEXT: frsqrte z1.d, z0.d
154+
; CHECK-NEXT: ptrue p0.d
155+
; CHECK-NEXT: fmul z2.d, z1.d, z1.d
156+
; CHECK-NEXT: fcmeq p0.d, p0/z, z0.d, #0.0
157+
; CHECK-NEXT: frsqrts z2.d, z0.d, z2.d
158+
; CHECK-NEXT: fmul z1.d, z1.d, z2.d
159+
; CHECK-NEXT: fmul z2.d, z1.d, z1.d
160+
; CHECK-NEXT: frsqrts z2.d, z0.d, z2.d
161+
; CHECK-NEXT: fmul z1.d, z1.d, z2.d
162+
; CHECK-NEXT: fmul z2.d, z1.d, z1.d
163+
; CHECK-NEXT: frsqrts z2.d, z0.d, z2.d
164+
; CHECK-NEXT: fmul z1.d, z1.d, z2.d
165+
; CHECK-NEXT: fmul z1.d, z0.d, z1.d
166+
; CHECK-NEXT: sel z0.d, p0, z0.d, z1.d
167+
; CHECK-NEXT: ret
168+
%fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
169+
ret <vscale x 2 x double> %fsqrt
170+
}
171+
172+
declare <vscale x 2 x half> @llvm.sqrt.nxv2f16(<vscale x 2 x half>)
173+
declare <vscale x 4 x half> @llvm.sqrt.nxv4f16(<vscale x 4 x half>)
174+
declare <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half>)
175+
declare <vscale x 2 x float> @llvm.sqrt.nxv2f32(<vscale x 2 x float>)
176+
declare <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float>)
177+
declare <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double>)
178+
179+
attributes #0 = { "reciprocal-estimates"="all" }

0 commit comments

Comments
 (0)