Skip to content

[LV] Use ExtractLane(LastActiveLane, V) live outs when tail-folding. (WIP) #149042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1929,24 +1929,6 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
for (const auto &Reduction : getReductionVars())
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());

// TODO: handle non-reduction outside users when tail is folded by masking.
for (auto *AE : AllowedExit) {
// Check that all users of allowed exit values are inside the loop or
// are the live-out of a reduction.
if (ReductionLiveOuts.count(AE))
continue;
for (User *U : AE->users()) {
Instruction *UI = cast<Instruction>(U);
if (TheLoop->contains(UI))
continue;
LLVM_DEBUG(
dbgs()
<< "LV: Cannot fold tail by masking, loop has an outside user for "
<< *UI << "\n");
return false;
}
}

for (const auto &Entry : getInductionVars()) {
PHINode *OrigPhi = Entry.first;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the PR handle the IV liveout user as well?

for (User *U : OrigPhi->users()) {
Expand Down
31 changes: 20 additions & 11 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8446,7 +8446,9 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan,
/// exit block. The penultimate value of recurrences is fed to their LCSSA phi
/// users in the original exit block using the VPIRInstruction wrapping to the
/// LCSSA phi.
static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range) {
static bool addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range) {
using namespace llvm::VPlanPatternMatch;

VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
auto *ScalarPHVPBB = Plan.getScalarPreheader();
auto *MiddleVPBB = Plan.getMiddleBlock();
Expand All @@ -8465,6 +8467,15 @@ static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range) {
assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
"Cannot handle loops with uncountable early exits");

// TODO: Support ExtractLane of last-active-lane with first-order
// recurrences.

if (any_of(FOR->users(), [FOR](VPUser *U) {
return match(U, m_VPInstruction<VPInstruction::ExtractLane>(
m_VPValue(), m_Specific(FOR)));
}))
return false;

// This is the second phase of vectorizing first-order recurrences, creating
// extract for users outside the loop. An overview of the transformation is
// described below. Suppose we have the following loop with some use after
Expand Down Expand Up @@ -8536,24 +8547,25 @@ static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range) {
// Extract the penultimate value of the recurrence and use it as operand for
// the VPIRInstruction modeling the phi.
for (VPUser *U : FOR->users()) {
using namespace llvm::VPlanPatternMatch;
if (!match(U, m_VPInstruction<VPInstruction::ExtractLastElement>(
m_Specific(FOR))))
continue;

// For VF vscale x 1, if vscale = 1, we are unable to extract the
// penultimate value of the recurrence. Instead we rely on the existing
// extract of the last element from the result of
// VPInstruction::FirstOrderRecurrenceSplice.
// TODO: Consider vscale_range info and UF.
if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne,
Range))
return;
return true;
VPValue *PenultimateElement = MiddleBuilder.createNaryOp(
VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()},
{}, "vector.recur.extract.for.phi");
cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement);
}
}
return true;
}

VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
Expand Down Expand Up @@ -8758,7 +8770,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
R->setOperand(1, WideIV->getStepValue());
}

addExitUsersForFirstOrderRecurrences(*Plan, Range);
if (!addExitUsersForFirstOrderRecurrences(*Plan, Range))
return nullptr;
DenseMap<VPValue *, VPValue *> IVEndValues;
addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues);

Expand Down Expand Up @@ -9170,7 +9183,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
continue;
U->replaceUsesOfWith(OrigExitingVPV, FinalReductionResult);
if (match(U, m_VPInstruction<VPInstruction::ExtractLastElement>(
m_VPValue())))
m_VPValue())) ||
match(U, m_VPInstruction<VPInstruction::ExtractLane>(m_VPValue(),
m_VPValue())))
cast<VPInstruction>(U)->replaceAllUsesWith(FinalReductionResult);
}

Expand Down Expand Up @@ -10022,12 +10037,6 @@ bool LoopVectorizePass::processLoop(Loop *L) {
// Get user vectorization factor and interleave count.
ElementCount UserVF = Hints.getWidth();
unsigned UserIC = Hints.getInterleave();
if (LVL.hasUncountableEarlyExit() && UserIC != 1) {
UserIC = 1;
reportVectorizationInfo("Interleaving not supported for loops "
"with uncountable early exits",
"InterleaveEarlyExitDisabled", ORE, L);
}

// Plan how to best vectorize.
LVP.plan(UserVF, UserIC);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,10 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
ReductionStartVector,
// Creates a step vector starting from 0 to VF with a step of 1.
StepVector,
/// Extracts a single lane (first operand) from a set of vector operands.
/// The lane specifies an index into a vector formed by combining all vector
/// operands (all operands after the first one).
ExtractLane,

};

Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
case VPInstruction::BuildStructVector:
case VPInstruction::BuildVector:
return SetResultTyFromOp();
case VPInstruction::ExtractLane:
return inferScalarType(R->getOperand(1));
case VPInstruction::FirstActiveLane:
return Type::getIntNTy(Ctx, 64);
case VPInstruction::ExtractLastElement:
Expand Down
54 changes: 49 additions & 5 deletions llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
#include "VPRecipeBuilder.h"
#include "VPlan.h"
#include "VPlanCFG.h"
#include "VPlanPatternMatch.h"
#include "VPlanTransforms.h"
#include "VPlanUtils.h"
#include "llvm/ADT/PostOrderIterator.h"

using namespace llvm;
using namespace VPlanPatternMatch;

namespace {
class VPPredicator {
Expand All @@ -42,11 +44,6 @@ class VPPredicator {
/// possibly inserting new recipes at \p Dst (using Builder's insertion point)
VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst);

/// Returns the *entry* mask for \p VPBB.
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
return BlockMaskCache.lookup(VPBB);
}

/// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not
/// already have a mask.
void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) {
Expand All @@ -66,6 +63,11 @@ class VPPredicator {
}

public:
/// Returns the *entry* mask for \p VPBB.
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
return BlockMaskCache.lookup(VPBB);
}

/// Returns the precomputed predicate of the edge from \p Src to \p Dst.
VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
return EdgeMaskCache.lookup({Src, Dst});
Expand Down Expand Up @@ -300,5 +302,47 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {

PrevVPBB = VPBB;
}

// If we folded the tail and introduced a header mask, any extract of the last
// element must be updated to only extract the last-active-lane of the header
// mask.
if (FoldTail) {
assert(Plan.getExitBlocks().size() == 1 &&
"only a single-exit block is supported currently");
VPBasicBlock *EB = Plan.getExitBlocks().front();
assert(EB->getSinglePredecessor() == Plan.getMiddleBlock() &&
"the exit block must have middle block as single predecessor");

VPValue *LastActiveLane = nullptr;
VPBuilder B(Plan.getMiddleBlock()->getTerminator());
for (auto &P : EB->phis()) {
auto *ExitIRI = cast<VPIRPhi>(&P);
VPValue *Inc = ExitIRI->getIncomingValue(0);
VPValue *Op;
if (!match(Inc, m_VPInstruction<VPInstruction::ExtractLastElement>(
m_VPValue(Op))))
continue;

if (!LastActiveLane) {
// Compute the index of the last active lane, by getting the
// first-active-lane of the negated header mask (which is the first lane
// the original header mask was false) and subtract 1.
VPValue *HeaderMask = Predicator.getBlockInMask(
Plan.getVectorLoopRegion()->getEntryBasicBlock());
LastActiveLane = B.createNaryOp(
Instruction::Sub,
{B.createNaryOp(VPInstruction::FirstActiveLane,
{B.createNot(HeaderMask)}),
Plan.getOrAddLiveIn(ConstantInt::get(
IntegerType::get(
Plan.getScalarHeader()->getIRBasicBlock()->getContext(),
64),
1))});
}
auto *Ext =
B.createNaryOp(VPInstruction::ExtractLane, {LastActiveLane, Op});
Inc->replaceAllUsesWith(Ext);
}
}
return Predicator.getBlockMaskCache();
}
48 changes: 45 additions & 3 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,31 @@ Value *VPInstruction::generate(VPTransformState &State) {
Res = Builder.CreateOr(Res, State.get(Op));
return Builder.CreateOrReduce(Res);
}
case VPInstruction::ExtractLane: {
Value *LaneToExtract = State.get(getOperand(0), true);
Type *IdxTy = State.TypeAnalysis.inferScalarType(getOperand(0));
Value *Res = nullptr;
Value *RuntimeVF = getRuntimeVF(State.Builder, IdxTy, State.VF);

for (unsigned Idx = 1; Idx != getNumOperands(); ++Idx) {
Value *VectorStart =
Builder.CreateMul(RuntimeVF, ConstantInt::get(IdxTy, Idx - 1));
Value *VectorIdx = Idx == 1
? LaneToExtract
: Builder.CreateSub(LaneToExtract, VectorStart);
Value *Ext = State.VF.isScalar()
? State.get(getOperand(Idx))
: Builder.CreateExtractElement(
State.get(getOperand(Idx)), VectorIdx);
if (Res) {
Value *Cmp = Builder.CreateICmpUGE(LaneToExtract, VectorStart);
Res = Builder.CreateSelect(Cmp, Ext, Res);
} else {
Res = Ext;
}
}
return Res;
}
case VPInstruction::FirstActiveLane: {
if (getNumOperands() == 1) {
Value *Mask = State.get(getOperand(0));
Expand All @@ -876,8 +901,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
unsigned LastOpIdx = getNumOperands() - 1;
Value *Res = nullptr;
for (int Idx = LastOpIdx; Idx >= 0; --Idx) {
Value *TrailingZeros = Builder.CreateCountTrailingZeroElems(
Builder.getInt64Ty(), State.get(getOperand(Idx)), true, Name);
Value *TrailingZeros =
State.VF.isScalar()
? Builder.CreateZExt(
Builder.CreateICmpEQ(State.get(getOperand(Idx)),
Builder.getInt1(0)),
Builder.getInt64Ty())
: Builder.CreateCountTrailingZeroElems(
// Value *TrailingZeros =
// Builder.CreateCountTrailingZeroElems(
Builder.getInt64Ty(), State.get(getOperand(Idx)), true,
Name);
Value *Current = Builder.CreateAdd(
Builder.CreateMul(RuntimeVF, Builder.getInt64(Idx)), TrailingZeros);
if (Res) {
Expand Down Expand Up @@ -920,7 +954,8 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
}

switch (getOpcode()) {
case Instruction::ExtractElement: {
case Instruction::ExtractElement:
case VPInstruction::ExtractLane: {
// Add on the cost of extracting the element.
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
return Ctx.TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy,
Expand Down Expand Up @@ -982,6 +1017,7 @@ bool VPInstruction::isVectorToScalar() const {
return getOpcode() == VPInstruction::ExtractLastElement ||
getOpcode() == VPInstruction::ExtractPenultimateElement ||
getOpcode() == Instruction::ExtractElement ||
getOpcode() == VPInstruction::ExtractLane ||
getOpcode() == VPInstruction::FirstActiveLane ||
getOpcode() == VPInstruction::ComputeAnyOfResult ||
getOpcode() == VPInstruction::ComputeFindIVResult ||
Expand Down Expand Up @@ -1040,6 +1076,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
case VPInstruction::BuildVector:
case VPInstruction::CalculateTripCountMinusVF:
case VPInstruction::CanonicalIVIncrementForPart:
case VPInstruction::ExtractLane:
case VPInstruction::ExtractLastElement:
case VPInstruction::ExtractPenultimateElement:
case VPInstruction::FirstActiveLane:
Expand Down Expand Up @@ -1088,6 +1125,8 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
case VPInstruction::ComputeAnyOfResult:
case VPInstruction::ComputeFindIVResult:
return Op == getOperand(1);
case VPInstruction::ExtractLane:
return Op == getOperand(0);
};
llvm_unreachable("switch should return");
}
Expand Down Expand Up @@ -1166,6 +1205,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
case VPInstruction::BuildVector:
O << "buildvector";
break;
case VPInstruction::ExtractLane:
O << "extract-lane";
break;
case VPInstruction::ExtractLastElement:
O << "extract-last-element";
break;
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,10 +774,10 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan &Plan,
using namespace VPlanPatternMatch;

VPValue *Incoming, *Mask;
if (!match(Op, m_VPInstruction<Instruction::ExtractElement>(
m_VPValue(Incoming),
if (!match(Op, m_VPInstruction<VPInstruction::ExtractLane>(
m_VPInstruction<VPInstruction::FirstActiveLane>(
m_VPValue(Mask)))))
m_VPValue(Mask)),
m_VPValue(Incoming))))
return nullptr;

auto *WideIV = getOptimizableIVOf(Incoming);
Expand Down Expand Up @@ -2831,7 +2831,7 @@ void VPlanTransforms::handleUncountableEarlyExit(
VPInstruction::FirstActiveLane, {CondToEarlyExit}, nullptr,
"first.active.lane");
IncomingFromEarlyExit = EarlyExitB.createNaryOp(
Instruction::ExtractElement, {IncomingFromEarlyExit, FirstActiveLane},
VPInstruction::ExtractLane, {FirstActiveLane, IncomingFromEarlyExit},
nullptr, "early.exit.value");
ExitIRI->setOperand(EarlyExitIdx, IncomingFromEarlyExit);
}
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,13 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
continue;
}
VPValue *Op0;
if (match(&R, m_VPInstruction<VPInstruction::ExtractLane>(
m_VPValue(Op0), m_VPValue(Op1)))) {
addUniformForAllParts(cast<VPInstruction>(&R));
for (unsigned Part = 1; Part != UF; ++Part)
R.addOperand(getValueForPart(Op1, Part));
continue;
}
if (match(&R, m_VPInstruction<VPInstruction::ExtractLastElement>(
m_VPValue(Op0))) ||
match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
Expand Down
Loading
Loading