@@ -118,6 +118,8 @@ class Vectorizer {
118
118
bool run ();
119
119
120
120
private:
121
+ GetElementPtrInst *getSourceGEP (Value *Src) const ;
122
+
121
123
unsigned getPointerAddressSpace (Value *I);
122
124
123
125
unsigned getAlignment (LoadInst *LI) const {
@@ -137,8 +139,6 @@ class Vectorizer {
137
139
}
138
140
139
141
bool isConsecutiveAccess (Value *A, Value *B);
140
- bool areConsecutivePointers (Value *PtrA, Value *PtrB, APInt Size);
141
- bool lookThroughComplexAddresses (Value *PtrA, Value *PtrB, APInt PtrDelta);
142
142
143
143
// / After vectorization, reorder the instructions that I depends on
144
144
// / (the instructions defining its operands), to ensure they dominate I.
@@ -277,6 +277,21 @@ unsigned Vectorizer::getPointerAddressSpace(Value *I) {
277
277
return -1 ;
278
278
}
279
279
280
+ GetElementPtrInst *Vectorizer::getSourceGEP (Value *Src) const {
281
+ // First strip pointer bitcasts. Make sure pointee size is the same with
282
+ // and without casts.
283
+ // TODO: a stride set by the add instruction below can match the difference
284
+ // in pointee type size here. Currently it will not be vectorized.
285
+ Value *SrcPtr = getLoadStorePointerOperand (Src);
286
+ Value *SrcBase = SrcPtr->stripPointerCasts ();
287
+ Type *SrcPtrType = SrcPtr->getType ()->getPointerElementType ();
288
+ Type *SrcBaseType = SrcBase->getType ()->getPointerElementType ();
289
+ if (SrcPtrType->isSized () && SrcBaseType->isSized () &&
290
+ DL.getTypeStoreSize (SrcPtrType) == DL.getTypeStoreSize (SrcBaseType))
291
+ SrcPtr = SrcBase;
292
+ return dyn_cast<GetElementPtrInst>(SrcPtr);
293
+ }
294
+
280
295
// FIXME: Merge with llvm::isConsecutiveAccess
281
296
bool Vectorizer::isConsecutiveAccess (Value *A, Value *B) {
282
297
Value *PtrA = getLoadStorePointerOperand (A);
@@ -289,6 +304,7 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
289
304
return false ;
290
305
291
306
// Make sure that A and B are different pointers of the same size type.
307
+ unsigned PtrBitWidth = DL.getPointerSizeInBits (ASA);
292
308
Type *PtrATy = PtrA->getType ()->getPointerElementType ();
293
309
Type *PtrBTy = PtrB->getType ()->getPointerElementType ();
294
310
if (PtrA == PtrB ||
@@ -298,16 +314,10 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
298
314
DL.getTypeStoreSize (PtrBTy->getScalarType ()))
299
315
return false ;
300
316
301
- unsigned PtrBitWidth = DL.getPointerSizeInBits (ASA);
302
317
APInt Size (PtrBitWidth, DL.getTypeStoreSize (PtrATy));
303
318
304
- return areConsecutivePointers (PtrA, PtrB, Size);
305
- }
306
-
307
- bool Vectorizer::areConsecutivePointers (Value *PtrA, Value *PtrB, APInt Size) {
308
- unsigned PtrBitWidth = DL.getPointerTypeSizeInBits (PtrA->getType ());
309
- APInt OffsetA (PtrBitWidth, 0 );
310
- APInt OffsetB (PtrBitWidth, 0 );
319
+ unsigned IdxWidth = DL.getIndexSizeInBits (ASA);
320
+ APInt OffsetA (IdxWidth, 0 ), OffsetB (IdxWidth, 0 );
311
321
PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets (DL, OffsetA);
312
322
PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets (DL, OffsetB);
313
323
@@ -341,94 +351,68 @@ bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB, APInt Size) {
341
351
// Sometimes even this doesn't work, because SCEV can't always see through
342
352
// patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
343
353
// things the hard way.
344
- return lookThroughComplexAddresses (PtrA, PtrB, BaseDelta);
345
- }
346
-
347
- bool Vectorizer::lookThroughComplexAddresses (Value *PtrA, Value *PtrB,
348
- APInt PtrDelta) {
349
- auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
350
- auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
351
- if (!GEPA || !GEPB)
352
- return false ;
353
354
354
355
// Look through GEPs after checking they're the same except for the last
355
356
// index.
356
- if (GEPA->getNumOperands () != GEPB->getNumOperands () ||
357
- GEPA->getPointerOperand () != GEPB->getPointerOperand ())
357
+ GetElementPtrInst *GEPA = getSourceGEP (A);
358
+ GetElementPtrInst *GEPB = getSourceGEP (B);
359
+ if (!GEPA || !GEPB || GEPA->getNumOperands () != GEPB->getNumOperands ())
358
360
return false ;
359
- gep_type_iterator GTIA = gep_type_begin (GEPA);
360
- gep_type_iterator GTIB = gep_type_begin (GEPB);
361
- for (unsigned I = 0 , E = GEPA->getNumIndices () - 1 ; I < E; ++I) {
362
- if (GTIA.getOperand () != GTIB.getOperand ())
361
+ unsigned FinalIndex = GEPA->getNumOperands () - 1 ;
362
+ for (unsigned i = 0 ; i < FinalIndex; i++)
363
+ if (GEPA->getOperand (i) != GEPB->getOperand (i))
363
364
return false ;
364
- ++GTIA;
365
- ++GTIB;
366
- }
367
365
368
- Instruction *OpA = dyn_cast<Instruction>(GTIA. getOperand ());
369
- Instruction *OpB = dyn_cast<Instruction>(GTIB. getOperand ());
366
+ Instruction *OpA = dyn_cast<Instruction>(GEPA-> getOperand (FinalIndex ));
367
+ Instruction *OpB = dyn_cast<Instruction>(GEPB-> getOperand (FinalIndex ));
370
368
if (!OpA || !OpB || OpA->getOpcode () != OpB->getOpcode () ||
371
369
OpA->getType () != OpB->getType ())
372
370
return false ;
373
371
374
- if (PtrDelta.isNegative ()) {
375
- if (PtrDelta.isMinSignedValue ())
376
- return false ;
377
- PtrDelta.negate ();
378
- std::swap (OpA, OpB);
379
- }
380
- uint64_t Stride = DL.getTypeAllocSize (GTIA.getIndexedType ());
381
- if (PtrDelta.urem (Stride) != 0 )
382
- return false ;
383
- unsigned IdxBitWidth = OpA->getType ()->getScalarSizeInBits ();
384
- APInt IdxDiff = PtrDelta.udiv (Stride).zextOrSelf (IdxBitWidth);
385
-
386
372
// Only look through a ZExt/SExt.
387
373
if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
388
374
return false ;
389
375
390
376
bool Signed = isa<SExtInst>(OpA);
391
377
392
- // At this point A could be a function parameter, i.e. not an instruction
393
- Value *ValA = OpA->getOperand (0 );
378
+ OpA = dyn_cast<Instruction>(OpA->getOperand (0 ));
394
379
OpB = dyn_cast<Instruction>(OpB->getOperand (0 ));
395
- if (!OpB || ValA ->getType () != OpB->getType ())
380
+ if (!OpA || ! OpB || OpA ->getType () != OpB->getType ())
396
381
return false ;
397
382
398
- // Now we need to prove that adding IdxDiff to ValA won't overflow.
383
+ // Now we need to prove that adding 1 to OpA won't overflow.
399
384
bool Safe = false ;
400
- // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
401
- // ValA, we're okay.
385
+ // First attempt: if OpB is an add with NSW/NUW, and OpB is 1 added to OpA,
386
+ // we're okay.
402
387
if (OpB->getOpcode () == Instruction::Add &&
403
388
isa<ConstantInt>(OpB->getOperand (1 )) &&
404
- IdxDiff. sle ( cast<ConstantInt>(OpB->getOperand (1 ))->getSExtValue ()) ) {
389
+ cast<ConstantInt>(OpB->getOperand (1 ))->getSExtValue () > 0 ) {
405
390
if (Signed)
406
391
Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap ();
407
392
else
408
393
Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap ();
409
394
}
410
395
411
- unsigned BitWidth = ValA ->getType ()->getScalarSizeInBits ();
396
+ unsigned BitWidth = OpA ->getType ()->getScalarSizeInBits ();
412
397
413
398
// Second attempt:
414
- // If all set bits of IdxDiff or any higher order bit other than the sign bit
415
- // are known to be zero in ValA, we can add Diff to it while guaranteeing no
416
- // overflow of any sort.
399
+ // If any bits are known to be zero other than the sign bit in OpA, we can
400
+ // add 1 to it while guaranteeing no overflow of any sort.
417
401
if (!Safe) {
418
- OpA = dyn_cast<Instruction>(ValA);
419
- if (!OpA)
420
- return false ;
421
402
KnownBits Known (BitWidth);
422
403
computeKnownBits (OpA, Known, DL, 0 , nullptr , OpA, &DT);
423
- if (Known.Zero . trunc ( BitWidth - 1 ). zext (IdxBitWidth). ult (IdxDiff ))
424
- return false ;
404
+ if (Known.countMaxTrailingOnes () < ( BitWidth - 1 ))
405
+ Safe = true ;
425
406
}
426
407
427
- const SCEV *OffsetSCEVA = SE.getSCEV (ValA);
408
+ if (!Safe)
409
+ return false ;
410
+
411
+ const SCEV *OffsetSCEVA = SE.getSCEV (OpA);
428
412
const SCEV *OffsetSCEVB = SE.getSCEV (OpB);
429
- const SCEV *C = SE.getConstant (IdxDiff. trunc (BitWidth));
430
- const SCEV *X = SE.getAddExpr (OffsetSCEVA, C );
431
- return X == OffsetSCEVB;
413
+ const SCEV *One = SE.getConstant (APInt (BitWidth, 1 ));
414
+ const SCEV *X2 = SE.getAddExpr (OffsetSCEVA, One );
415
+ return X2 == OffsetSCEVB;
432
416
}
433
417
434
418
void Vectorizer::reorder (Instruction *I) {
0 commit comments