Skip to content

Commit 29f8628

Browse files
committed
[Constant] Add containsPoisonElement
This patch - Adds containsPoisonElement that checks existence of poison in constant vector elements, - Renames containsUndefElement to containsUndefOrPoisonElement to clarify its behavior & updates its uses properly With this patch, isGuaranteedNotToBeUndefOrPoison's tests w.r.t constant vectors are added because its analysis is improved. Thanks! Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D94053
1 parent 8444a24 commit 29f8628

File tree

8 files changed

+98
-20
lines changed

8 files changed

+98
-20
lines changed

llvm/include/llvm/IR/Constant.h

+9-5
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,15 @@ class Constant : public User {
101101
/// lane, the constants still match.
102102
bool isElementWiseEqual(Value *Y) const;
103103

104-
/// Return true if this is a vector constant that includes any undefined
105-
/// elements. Since it is impossible to inspect a scalable vector element-
106-
/// wise at compile time, this function returns true only if the entire
107-
/// vector is undef
108-
bool containsUndefElement() const;
104+
/// Return true if this is a vector constant that includes any undef or
105+
/// poison elements. Since it is impossible to inspect a scalable vector
106+
/// element- wise at compile time, this function returns true only if the
107+
/// entire vector is undef or poison.
108+
bool containsUndefOrPoisonElement() const;
109+
110+
/// Return true if this is a vector constant that includes any poison
111+
/// elements.
112+
bool containsPoisonElement() const;
109113

110114
/// Return true if this is a fixed width vector constant that includes
111115
/// any constant expressions.

llvm/lib/Analysis/ValueTracking.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -4895,7 +4895,8 @@ static bool isGuaranteedNotToBeUndefOrPoison(const Value *V,
48954895
return true;
48964896

48974897
if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C))
4898-
return (PoisonOnly || !C->containsUndefElement()) &&
4898+
return (PoisonOnly ? !C->containsPoisonElement()
4899+
: !C->containsUndefOrPoisonElement()) &&
48994900
!C->containsConstantExpression();
49004901
}
49014902

@@ -5636,10 +5637,10 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
56365637
// elements because those can not be back-propagated for analysis.
56375638
Value *OutputZeroVal = nullptr;
56385639
if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) &&
5639-
!cast<Constant>(TrueVal)->containsUndefElement())
5640+
!cast<Constant>(TrueVal)->containsUndefOrPoisonElement())
56405641
OutputZeroVal = TrueVal;
56415642
else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) &&
5642-
!cast<Constant>(FalseVal)->containsUndefElement())
5643+
!cast<Constant>(FalseVal)->containsUndefOrPoisonElement())
56435644
OutputZeroVal = FalseVal;
56445645

56455646
if (OutputZeroVal) {

llvm/lib/IR/ConstantFold.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,7 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond,
811811
return true;
812812

813813
if (C->getType()->isVectorTy())
814-
return !C->containsUndefElement() && !C->containsConstantExpression();
814+
return !C->containsPoisonElement() && !C->containsConstantExpression();
815815

816816
// TODO: Recursively analyze aggregates or other constants.
817817
return false;

llvm/lib/IR/Constants.cpp

+18-7
Original file line numberDiff line numberDiff line change
@@ -304,31 +304,42 @@ bool Constant::isElementWiseEqual(Value *Y) const {
304304
return isa<UndefValue>(CmpEq) || match(CmpEq, m_One());
305305
}
306306

307-
bool Constant::containsUndefElement() const {
308-
if (auto *VTy = dyn_cast<VectorType>(getType())) {
309-
if (isa<UndefValue>(this))
307+
static bool
308+
containsUndefinedElement(const Constant *C,
309+
function_ref<bool(const Constant *)> HasFn) {
310+
if (auto *VTy = dyn_cast<VectorType>(C->getType())) {
311+
if (HasFn(C))
310312
return true;
311-
if (isa<ConstantAggregateZero>(this))
313+
if (isa<ConstantAggregateZero>(C))
312314
return false;
313-
if (isa<ScalableVectorType>(getType()))
315+
if (isa<ScalableVectorType>(C->getType()))
314316
return false;
315317

316318
for (unsigned i = 0, e = cast<FixedVectorType>(VTy)->getNumElements();
317319
i != e; ++i)
318-
if (isa<UndefValue>(getAggregateElement(i)))
320+
if (HasFn(C->getAggregateElement(i)))
319321
return true;
320322
}
321323

322324
return false;
323325
}
324326

327+
bool Constant::containsUndefOrPoisonElement() const {
328+
return containsUndefinedElement(
329+
this, [&](const auto *C) { return isa<UndefValue>(C); });
330+
}
331+
332+
bool Constant::containsPoisonElement() const {
333+
return containsUndefinedElement(
334+
this, [&](const auto *C) { return isa<PoisonValue>(C); });
335+
}
336+
325337
bool Constant::containsConstantExpression() const {
326338
if (auto *VTy = dyn_cast<FixedVectorType>(getType())) {
327339
for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i)
328340
if (isa<ConstantExpr>(getAggregateElement(i)))
329341
return true;
330342
}
331-
332343
return false;
333344
}
334345

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -3370,7 +3370,7 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
33703370
Type *OpTy = M->getType();
33713371
auto *VecC = dyn_cast<Constant>(M);
33723372
auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
3373-
if (OpVTy && VecC && VecC->containsUndefElement()) {
3373+
if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) {
33743374
Constant *SafeReplacementConstant = nullptr;
33753375
for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
33763376
if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
@@ -5259,7 +5259,8 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
52595259
// It may not be safe to change a compare predicate in the presence of
52605260
// undefined elements, so replace those elements with the first safe constant
52615261
// that we found.
5262-
if (C->containsUndefElement()) {
5262+
// TODO: in case of poison, it is safe; let's replace undefs only.
5263+
if (C->containsUndefOrPoisonElement()) {
52635264
assert(SafeReplacementConstant && "Replacement constant not set");
52645265
C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
52655266
}

llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
239239
// While this is normally not behind a use-check,
240240
// let's consider division to be special since it's costly.
241241
if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) {
242-
if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() &&
243-
Op1C->isNotOneValue()) {
242+
if (!Op1C->containsUndefOrPoisonElement() &&
243+
Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) {
244244
Value *BO =
245245
Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C),
246246
I->getName() + ".neg");

llvm/unittests/Analysis/ValueTrackingTest.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,30 @@ TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison) {
888888
EXPECT_EQ(isGuaranteedNotToBeUndefOrPoison(PoisonValue::get(IntegerType::get(Context, 8))), false);
889889
EXPECT_EQ(isGuaranteedNotToBePoison(UndefValue::get(IntegerType::get(Context, 8))), true);
890890
EXPECT_EQ(isGuaranteedNotToBePoison(PoisonValue::get(IntegerType::get(Context, 8))), false);
891+
892+
Type *Int32Ty = Type::getInt32Ty(Context);
893+
Constant *CU = UndefValue::get(Int32Ty);
894+
Constant *CP = PoisonValue::get(Int32Ty);
895+
Constant *C1 = ConstantInt::get(Int32Ty, 1);
896+
Constant *C2 = ConstantInt::get(Int32Ty, 2);
897+
898+
{
899+
Constant *V1 = ConstantVector::get({C1, C2});
900+
EXPECT_TRUE(isGuaranteedNotToBeUndefOrPoison(V1));
901+
EXPECT_TRUE(isGuaranteedNotToBePoison(V1));
902+
}
903+
904+
{
905+
Constant *V2 = ConstantVector::get({C1, CU});
906+
EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V2));
907+
EXPECT_TRUE(isGuaranteedNotToBePoison(V2));
908+
}
909+
910+
{
911+
Constant *V3 = ConstantVector::get({C1, CP});
912+
EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V3));
913+
EXPECT_FALSE(isGuaranteedNotToBePoison(V3));
914+
}
891915
}
892916

893917
TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison_assume) {

llvm/unittests/IR/ConstantsTest.cpp

+37
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,43 @@ TEST(ConstantsTest, FoldGlobalVariablePtr) {
585585
Instruction::And, TheConstantExpr, TheConstant)->isNullValue());
586586
}
587587

588+
// Check that containsUndefOrPoisonElement and containsPoisonElement is working
589+
// great
590+
591+
TEST(ConstantsTest, containsUndefElemTest) {
592+
LLVMContext Context;
593+
594+
Type *Int32Ty = Type::getInt32Ty(Context);
595+
Constant *CU = UndefValue::get(Int32Ty);
596+
Constant *CP = PoisonValue::get(Int32Ty);
597+
Constant *C1 = ConstantInt::get(Int32Ty, 1);
598+
Constant *C2 = ConstantInt::get(Int32Ty, 2);
599+
600+
{
601+
Constant *V1 = ConstantVector::get({C1, C2});
602+
EXPECT_FALSE(V1->containsUndefOrPoisonElement());
603+
EXPECT_FALSE(V1->containsPoisonElement());
604+
}
605+
606+
{
607+
Constant *V2 = ConstantVector::get({C1, CU});
608+
EXPECT_TRUE(V2->containsUndefOrPoisonElement());
609+
EXPECT_FALSE(V2->containsPoisonElement());
610+
}
611+
612+
{
613+
Constant *V3 = ConstantVector::get({C1, CP});
614+
EXPECT_TRUE(V3->containsUndefOrPoisonElement());
615+
EXPECT_TRUE(V3->containsPoisonElement());
616+
}
617+
618+
{
619+
Constant *V4 = ConstantVector::get({CU, CP});
620+
EXPECT_TRUE(V4->containsUndefOrPoisonElement());
621+
EXPECT_TRUE(V4->containsPoisonElement());
622+
}
623+
}
624+
588625
// Check that undefined elements in vector constants are matched
589626
// correctly for both integer and floating-point types. Just don't
590627
// crash on vectors of pointers (could be handled?).

0 commit comments

Comments
 (0)