Skip to content

Commit 9cf4293

Browse files
committed
Attempt to rewrite VectorCombine
1 parent ce0313c commit 9cf4293

16 files changed

+298
-251
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 118 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,8 @@ class VectorCombine {
107107
const Instruction &I,
108108
ExtractElementInst *&ConvertToShuffle,
109109
unsigned PreferredExtractIndex);
110-
void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
111-
Instruction &I);
112-
void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
113-
Instruction &I);
110+
Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
111+
Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
114112
bool foldExtractExtract(Instruction &I);
115113
bool foldInsExtFNeg(Instruction &I);
116114
bool foldInsExtBinop(Instruction &I);
@@ -138,7 +136,7 @@ class VectorCombine {
138136
bool foldInterleaveIntrinsics(Instruction &I);
139137
bool shrinkType(Instruction &I);
140138

141-
void replaceValue(Value &Old, Value &New) {
139+
void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
142140
LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
143141
LLVM_DEBUG(dbgs() << " With: " << New << '\n');
144142
Old.replaceAllUsesWith(&New);
@@ -147,7 +145,11 @@ class VectorCombine {
147145
Worklist.pushUsersToWorkList(*NewI);
148146
Worklist.pushValue(NewI);
149147
}
150-
Worklist.pushValue(&Old);
148+
if (Erase && isInstructionTriviallyDead(&Old)) {
149+
eraseInstruction(Old);
150+
} else {
151+
Worklist.push(&Old);
152+
}
151153
}
152154

153155
void eraseInstruction(Instruction &I) {
@@ -158,11 +160,23 @@ class VectorCombine {
158160

159161
// Push remaining users of the operands and then the operand itself - allows
160162
// further folds that were hindered by OneUse limits.
161-
for (Value *Op : Ops)
162-
if (auto *OpI = dyn_cast<Instruction>(Op)) {
163-
Worklist.pushUsersToWorkList(*OpI);
164-
Worklist.pushValue(OpI);
163+
SmallPtrSet<Value *, 4> Visited;
164+
for (Value *Op : Ops) {
165+
if (Visited.insert(Op).second) {
166+
if (auto *OpI = dyn_cast<Instruction>(Op)) {
167+
if (RecursivelyDeleteTriviallyDeadInstructions(
168+
OpI, nullptr, nullptr, [this](Value *V) {
169+
if (auto I = dyn_cast<Instruction>(V)) {
170+
LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
171+
Worklist.remove(I);
172+
}
173+
}))
174+
continue;
175+
Worklist.pushUsersToWorkList(*OpI);
176+
Worklist.pushValue(OpI);
177+
}
165178
}
179+
}
166180
}
167181
};
168182
} // namespace
@@ -546,9 +560,8 @@ static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
546560
/// the source vector (shift the scalar element) to a NewIndex for extraction.
547561
/// Return null if the input can be constant folded, so that we are not creating
548562
/// unnecessary instructions.
549-
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
550-
unsigned NewIndex,
551-
IRBuilderBase &Builder) {
563+
static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
564+
IRBuilderBase &Builder) {
552565
// Shufflevectors can only be created for fixed-width vectors.
553566
Value *X = ExtElt->getVectorOperand();
554567
if (!isa<FixedVectorType>(X->getType()))
@@ -563,52 +576,41 @@ static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
563576

564577
Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
565578
NewIndex, Builder);
566-
return dyn_cast<ExtractElementInst>(
567-
Builder.CreateExtractElement(Shuf, NewIndex));
579+
return Shuf;
568580
}
569581

570582
/// Try to reduce extract element costs by converting scalar compares to vector
571583
/// compares followed by extract.
572584
/// cmp (ext0 V0, C), (ext1 V1, C)
573-
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
574-
ExtractElementInst *Ext1, Instruction &I) {
585+
Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
586+
Instruction &I) {
575587
assert(isa<CmpInst>(&I) && "Expected a compare");
576-
assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
577-
cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
578-
"Expected matching constant extract indexes");
579588

580589
// cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
581590
++NumVecCmp;
582591
CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
583-
Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
584592
Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
585-
Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
586-
replaceValue(I, *NewExt);
593+
return Builder.CreateExtractElement(VecCmp, ExtIndex, "foldExtExtCmp");
587594
}
588595

589596
/// Try to reduce extract element costs by converting scalar binops to vector
590597
/// binops followed by extract.
591598
/// bo (ext0 V0, C), (ext1 V1, C)
592-
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
593-
ExtractElementInst *Ext1, Instruction &I) {
599+
Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
600+
Instruction &I) {
594601
assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
595-
assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
596-
cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
597-
"Expected matching constant extract indexes");
598602

599603
// bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
600604
++NumVecBO;
601-
Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
602-
Value *VecBO =
603-
Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
605+
Value *VecBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0,
606+
V1, "foldExtExtBinop");
604607

605608
// All IR flags are safe to back-propagate because any potential poison
606609
// created in unused vector elements is discarded by the extract.
607610
if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
608611
VecBOInst->copyIRFlags(&I);
609612

610-
Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
611-
replaceValue(I, *NewExt);
613+
return Builder.CreateExtractElement(VecBO, ExtIndex, "foldExtExtBinop");
612614
}
613615

614616
/// Match an instruction with extracted vector operands.
@@ -647,25 +649,29 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
647649
if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
648650
return false;
649651

652+
Value *ExtOp0 = Ext0->getVectorOperand();
653+
Value *ExtOp1 = Ext1->getVectorOperand();
654+
650655
if (ExtractToChange) {
651656
unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
652-
ExtractElementInst *NewExtract =
657+
Value *NewExtOp =
653658
translateExtract(ExtractToChange, CheapExtractIdx, Builder);
654-
if (!NewExtract)
659+
if (!NewExtOp)
655660
return false;
656661
if (ExtractToChange == Ext0)
657-
Ext0 = NewExtract;
662+
ExtOp0 = NewExtOp;
658663
else
659-
Ext1 = NewExtract;
664+
ExtOp1 = NewExtOp;
660665
}
661666

662-
if (Pred != CmpInst::BAD_ICMP_PREDICATE)
663-
foldExtExtCmp(Ext0, Ext1, I);
664-
else
665-
foldExtExtBinop(Ext0, Ext1, I);
666-
667+
Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
668+
: Ext0->getIndexOperand();
669+
Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
670+
? foldExtExtCmp(ExtOp0, ExtOp1, ExtIndex, I)
671+
: foldExtExtBinop(ExtOp0, ExtOp1, ExtIndex, I);
667672
Worklist.push(Ext0);
668673
Worklist.push(Ext1);
674+
replaceValue(I, *NewExt);
669675
return true;
670676
}
671677

@@ -1824,7 +1830,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
18241830
LI->getAlign(), VecTy->getElementType(), Idx, *DL);
18251831
NewLoad->setAlignment(ScalarOpAlignment);
18261832

1827-
replaceValue(*EI, *NewLoad);
1833+
replaceValue(*EI, *NewLoad, false);
18281834
}
18291835

18301836
FailureGuard.release();
@@ -2910,7 +2916,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
29102916
if (!IL.first)
29112917
return true;
29122918
Value *V = IL.first->get();
2913-
if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
2919+
if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUser())
29142920
return false;
29152921
if (V->getValueID() != FrontV->getValueID())
29162922
return false;
@@ -3112,7 +3118,7 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
31123118
Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
31133119
LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
31143120
replaceValue(*Shuffle, *NewShuffle);
3115-
MadeChanges = true;
3121+
return true;
31163122
}
31173123

31183124
// See if we can re-use foldSelectShuffle, getting it to reduce the size of
@@ -3608,7 +3614,7 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
36083614
for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
36093615
Builder.SetInsertPoint(Shuffles[S]);
36103616
Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
3611-
replaceValue(*Shuffles[S], *NSV);
3617+
replaceValue(*Shuffles[S], *NSV, false);
36123618
}
36133619

36143620
Worklist.pushValue(NSV0A);
@@ -3873,8 +3879,7 @@ bool VectorCombine::run() {
38733879

38743880
LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
38753881

3876-
bool MadeChange = false;
3877-
auto FoldInst = [this, &MadeChange](Instruction &I) {
3882+
auto FoldInst = [this](Instruction &I) {
38783883
Builder.SetInsertPoint(&I);
38793884
bool IsVectorType = isa<VectorType>(I.getType());
38803885
bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
@@ -3889,10 +3894,12 @@ bool VectorCombine::run() {
38893894
if (IsFixedVectorType) {
38903895
switch (Opcode) {
38913896
case Instruction::InsertElement:
3892-
MadeChange |= vectorizeLoadInsert(I);
3897+
if (vectorizeLoadInsert(I))
3898+
return true;
38933899
break;
38943900
case Instruction::ShuffleVector:
3895-
MadeChange |= widenSubvectorLoad(I);
3901+
if (widenSubvectorLoad(I))
3902+
return true;
38963903
break;
38973904
default:
38983905
break;
@@ -3902,19 +3909,25 @@ bool VectorCombine::run() {
39023909
// This transform works with scalable and fixed vectors
39033910
// TODO: Identify and allow other scalable transforms
39043911
if (IsVectorType) {
3905-
MadeChange |= scalarizeOpOrCmp(I);
3906-
MadeChange |= scalarizeLoadExtract(I);
3907-
MadeChange |= scalarizeExtExtract(I);
3908-
MadeChange |= scalarizeVPIntrinsic(I);
3909-
MadeChange |= foldInterleaveIntrinsics(I);
3912+
if (scalarizeOpOrCmp(I))
3913+
return true;
3914+
if (scalarizeLoadExtract(I))
3915+
return true;
3916+
if (scalarizeExtExtract(I))
3917+
return true;
3918+
if (scalarizeVPIntrinsic(I))
3919+
return true;
3920+
if (foldInterleaveIntrinsics(I))
3921+
return true;
39103922
}
39113923

39123924
if (Opcode == Instruction::Store)
3913-
MadeChange |= foldSingleElementStore(I);
3925+
if (foldSingleElementStore(I))
3926+
return true;
39143927

39153928
// If this is an early pipeline invocation of this pass, we are done.
39163929
if (TryEarlyFoldsOnly)
3917-
return;
3930+
return false;
39183931

39193932
// Otherwise, try folds that improve codegen but may interfere with
39203933
// early IR canonicalizations.
@@ -3923,72 +3936,91 @@ bool VectorCombine::run() {
39233936
if (IsFixedVectorType) {
39243937
switch (Opcode) {
39253938
case Instruction::InsertElement:
3926-
MadeChange |= foldInsExtFNeg(I);
3927-
MadeChange |= foldInsExtBinop(I);
3928-
MadeChange |= foldInsExtVectorToShuffle(I);
3939+
if (foldInsExtFNeg(I))
3940+
return true;
3941+
if (foldInsExtBinop(I))
3942+
return true;
3943+
if (foldInsExtVectorToShuffle(I))
3944+
return true;
39293945
break;
39303946
case Instruction::ShuffleVector:
3931-
MadeChange |= foldPermuteOfBinops(I);
3932-
MadeChange |= foldShuffleOfBinops(I);
3933-
MadeChange |= foldShuffleOfSelects(I);
3934-
MadeChange |= foldShuffleOfCastops(I);
3935-
MadeChange |= foldShuffleOfShuffles(I);
3936-
MadeChange |= foldShuffleOfIntrinsics(I);
3937-
MadeChange |= foldSelectShuffle(I);
3938-
MadeChange |= foldShuffleToIdentity(I);
3947+
if (foldPermuteOfBinops(I))
3948+
return true;
3949+
if (foldShuffleOfBinops(I))
3950+
return true;
3951+
if (foldShuffleOfSelects(I))
3952+
return true;
3953+
if (foldShuffleOfCastops(I))
3954+
return true;
3955+
if (foldShuffleOfShuffles(I))
3956+
return true;
3957+
if (foldShuffleOfIntrinsics(I))
3958+
return true;
3959+
if (foldSelectShuffle(I))
3960+
return true;
3961+
if (foldShuffleToIdentity(I))
3962+
return true;
39393963
break;
39403964
case Instruction::BitCast:
3941-
MadeChange |= foldBitcastShuffle(I);
3965+
if (foldBitcastShuffle(I))
3966+
return true;
39423967
break;
39433968
case Instruction::And:
39443969
case Instruction::Or:
39453970
case Instruction::Xor:
3946-
MadeChange |= foldBitOpOfCastops(I);
3971+
if (foldBitOpOfCastops(I))
3972+
return true;
39473973
break;
39483974
default:
3949-
MadeChange |= shrinkType(I);
3975+
if (shrinkType(I))
3976+
return true;
39503977
break;
39513978
}
39523979
} else {
39533980
switch (Opcode) {
39543981
case Instruction::Call:
3955-
MadeChange |= foldShuffleFromReductions(I);
3956-
MadeChange |= foldCastFromReductions(I);
3982+
if (foldShuffleFromReductions(I))
3983+
return true;
3984+
if (foldCastFromReductions(I))
3985+
return true;
39573986
break;
39583987
case Instruction::ICmp:
39593988
case Instruction::FCmp:
3960-
MadeChange |= foldExtractExtract(I);
3989+
if (foldExtractExtract(I))
3990+
return true;
39613991
break;
39623992
case Instruction::Or:
3963-
MadeChange |= foldConcatOfBoolMasks(I);
3993+
if (foldConcatOfBoolMasks(I))
3994+
return true;
39643995
[[fallthrough]];
39653996
default:
39663997
if (Instruction::isBinaryOp(Opcode)) {
3967-
MadeChange |= foldExtractExtract(I);
3968-
MadeChange |= foldExtractedCmps(I);
3969-
MadeChange |= foldBinopOfReductions(I);
3998+
if (foldExtractExtract(I))
3999+
return true;
4000+
if (foldExtractedCmps(I))
4001+
return true;
4002+
if (foldBinopOfReductions(I))
4003+
return true;
39704004
}
39714005
break;
39724006
}
39734007
}
4008+
return false;
39744009
};
39754010

3976-
SmallVector<Instruction*, 128> InstrsForInstructionWorklist;
4011+
bool MadeChange = false;
39774012
for (BasicBlock &BB : F) {
39784013
// Ignore unreachable basic blocks.
39794014
if (!DT.isReachableFromEntry(&BB))
39804015
continue;
3981-
for (Instruction &I : BB) {
4016+
// Use an early-increment range so that instructions can be erased while
// iterating over the block.
4017+
for (Instruction &I : make_early_inc_range(BB)) {
39824018
if (I.isDebugOrPseudoInst())
39834019
continue;
3984-
InstrsForInstructionWorklist.push_back(&I);
4020+
MadeChange |= FoldInst(I);
39854021
}
39864022
}
39874023

3988-
Worklist.reserve(InstrsForInstructionWorklist.size());
3989-
for (auto I : reverse(InstrsForInstructionWorklist))
3990-
Worklist.push(I);
3991-
39924024
while (!Worklist.isEmpty()) {
39934025
Instruction *I = Worklist.removeOne();
39944026
if (!I)
@@ -3999,7 +4031,7 @@ bool VectorCombine::run() {
39994031
continue;
40004032
}
40014033

4002-
FoldInst(*I);
4034+
MadeChange |= FoldInst(*I);
40034035
}
40044036

40054037
return MadeChange;

0 commit comments

Comments
 (0)