Skip to content

Commit 6e63dfd

Browse files
committed
[RISCV] Custom lowering of FLT_ROUNDS_
Differential Revision: https://reviews.llvm.org/D90854
1 parent d40a19c commit 6e63dfd

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
372372
setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
373373
}
374374

375+
if (Subtarget.hasStdExtF()) {
376+
setOperationAction(ISD::FLT_ROUNDS_, XLenVT, Custom);
377+
}
378+
375379
setOperationAction(ISD::GlobalAddress, XLenVT, Custom);
376380
setOperationAction(ISD::BlockAddress, XLenVT, Custom);
377381
setOperationAction(ISD::ConstantPool, XLenVT, Custom);
@@ -2161,6 +2165,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
21612165
return lowerMGATHER(Op, DAG);
21622166
case ISD::MSCATTER:
21632167
return lowerMSCATTER(Op, DAG);
2168+
case ISD::FLT_ROUNDS_:
2169+
return lowerGET_ROUNDING(Op, DAG);
21642170
}
21652171
}
21662172

@@ -4107,6 +4113,37 @@ SDValue RISCVTargetLowering::lowerMSCATTER(SDValue Op,
41074113
MSN->getMemoryVT(), MSN->getMemOperand());
41084114
}
41094115

4116+
SDValue RISCVTargetLowering::lowerGET_ROUNDING(SDValue Op,
4117+
SelectionDAG &DAG) const {
4118+
const MVT XLenVT = Subtarget.getXLenVT();
4119+
SDLoc DL(Op);
4120+
SDValue Chain = Op->getOperand(0);
4121+
SDValue SysRegNo = DAG.getConstant(
4122+
RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT);
4123+
SDVTList VTs = DAG.getVTList(XLenVT, MVT::Other);
4124+
SDValue RM = DAG.getNode(RISCVISD::READ_CSR, DL, VTs, Chain, SysRegNo);
4125+
4126+
// Encoding used for rounding mode in RISCV differs from that used in
4127+
// FLT_ROUNDS. To convert it the RISCV rounding mode is used as an index in a
4128+
// table, which consists of a sequence of 4-bit fields, each representing
4129+
// corresponding FLT_ROUNDS mode.
4130+
static const int Table =
4131+
(int(RoundingMode::NearestTiesToEven) << 4 * RISCVFPRndMode::RNE) |
4132+
(int(RoundingMode::TowardZero) << 4 * RISCVFPRndMode::RTZ) |
4133+
(int(RoundingMode::TowardNegative) << 4 * RISCVFPRndMode::RDN) |
4134+
(int(RoundingMode::TowardPositive) << 4 * RISCVFPRndMode::RUP) |
4135+
(int(RoundingMode::NearestTiesToAway) << 4 * RISCVFPRndMode::RMM);
4136+
4137+
SDValue Shift =
4138+
DAG.getNode(ISD::SHL, DL, XLenVT, RM, DAG.getConstant(2, DL, XLenVT));
4139+
SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT,
4140+
DAG.getConstant(Table, DL, XLenVT), Shift);
4141+
SDValue Masked = DAG.getNode(ISD::AND, DL, XLenVT, Shifted,
4142+
DAG.getConstant(7, DL, XLenVT));
4143+
4144+
return DAG.getMergeValues({Masked, Chain}, DL);
4145+
}
4146+
41104147
// Returns the opcode of the target-specific SDNode that implements the 32-bit
41114148
// form of the given Opcode.
41124149
static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
@@ -4584,6 +4621,13 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
45844621
if (SDValue V = lowerVECREDUCE(SDValue(N, 0), DAG))
45854622
Results.push_back(V);
45864623
break;
4624+
case ISD::FLT_ROUNDS_: {
4625+
SDVTList VTs = DAG.getVTList(Subtarget.getXLenVT(), MVT::Other);
4626+
SDValue Res = DAG.getNode(ISD::FLT_ROUNDS_, DL, VTs, N->getOperand(0));
4627+
Results.push_back(Res.getValue(0));
4628+
Results.push_back(Res.getValue(1));
4629+
break;
4630+
}
45874631
}
45884632
}
45894633

llvm/lib/Target/RISCV/RISCVISelLowering.h

+1
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ class RISCVTargetLowering : public TargetLowering {
533533
bool HasMask = true) const;
534534
SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
535535
unsigned ExtendOpc) const;
536+
SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
536537

537538
bool isEligibleForTailCallOptimization(
538539
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,

llvm/test/CodeGen/RISCV/fpenv.ll

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=riscv32 -mattr=+f -verify-machineinstrs < %s | FileCheck -check-prefix=RV32IF %s
3+
; RUN: llc -mtriple=riscv64 -mattr=+f -verify-machineinstrs < %s | FileCheck -check-prefix=RV64IF %s
4+
5+
define i32 @func_01() {
6+
; RV32IF-LABEL: func_01:
7+
; RV32IF: # %bb.0:
8+
; RV32IF-NEXT: frrm a0
9+
; RV32IF-NEXT: slli a0, a0, 2
10+
; RV32IF-NEXT: lui a1, 66
11+
; RV32IF-NEXT: addi a1, a1, 769
12+
; RV32IF-NEXT: srl a0, a1, a0
13+
; RV32IF-NEXT: andi a0, a0, 7
14+
; RV32IF-NEXT: ret
15+
;
16+
; RV64IF-LABEL: func_01:
17+
; RV64IF: # %bb.0:
18+
; RV64IF-NEXT: frrm a0
19+
; RV64IF-NEXT: slli a0, a0, 2
20+
; RV64IF-NEXT: lui a1, 66
21+
; RV64IF-NEXT: addiw a1, a1, 769
22+
; RV64IF-NEXT: srl a0, a1, a0
23+
; RV64IF-NEXT: andi a0, a0, 7
24+
; RV64IF-NEXT: ret
25+
%rm = call i32 @llvm.flt.rounds()
26+
ret i32 %rm
27+
}
28+
29+
declare i32 @llvm.flt.rounds()

0 commit comments

Comments
 (0)