Skip to content

[MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations #148198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
"apply_patterns.arm_neon.vector_contract_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector.contract operations should be lowered to
finer-grained vector primitives from the ArmNeon dialect.
Indicates that vector contract operations should be lowered to
to ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
}];

let assemblyFormat = "attr-dict";
}

def ApplyArmNeonContractionToBFMMLAPatternsOp
: Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector contract operations should be lowered to
to ArmNeon dialect operations mapping to instructions from FEAT_BF16.
}];

let assemblyFormat = "attr-dict";
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/ArmNeon/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace mlir {
class RewritePatternSet;

namespace arm_neon {
void populateLowerContractionToNeonI8MMPatternPatterns(
RewritePatternSet &patterns);
void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
} // namespace arm_neon

} // namespace mlir
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ class RewritePatternSet;
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);

void populateLowerContractionToSVEI8MMPatternPatterns(
RewritePatternSet &patterns);
void populateLowerContractionToSVEI8MMPatterns(RewritePatternSet &patterns);

void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);

Expand Down
13 changes: 8 additions & 5 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,16 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorGatherLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
if (armSVE)
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
populateLowerContractionToSVEI8MMPatterns(patterns);
}
if (armBF16) {
if (armNeon)
arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
if (armSVE)
populateLowerContractionToSVEBFMMLAPatterns(patterns);
}
if (armBF16)
populateLowerContractionToSVEBFMMLAPatterns(patterns);

(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ using namespace mlir;

void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
}

void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIRArmNeonTransforms
LowerContractionToNeonI8MMPattern.cpp
LowerContractToNeonPatterns.cpp

DEPENDS
MLIRArmNeonIncGen
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -93,15 +93,20 @@ class VectorContractRewriter {
// multiplications.
enum class MMLA {
Nop,
Signed, // smmla
Unsigned, // ummla
Mixed, // usmmla
MixedSwapped // usmmla with LHS and RHS swapped
SignedInt, // smmla
UnsignedInt, // ummla
MixedInt, // usmmla
Bfloat // bfmmla
};

// Lower-level operation to be emitted.
MMLA mmlaOp = MMLA::Nop;

// Indicate if the operands for the ArmNeon dialect operation need to be
// swapped. Currently this is needed in order to emulate an "summla"
// operation.
bool swapOperands = false;

// The operand tiles. These are not necessarily the operands of
// `vector.contract`, for example they could be operands to `arith.extsi`
// that is in turn fed into `vector.contract`.
Expand All @@ -126,21 +131,22 @@ class VectorContractRewriter {
// Create the matrix multiply and accumulate operation according to `mmlaOp`.
Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
Value lhs, Value rhs) {

if (swapOperands)
std::swap(lhs, rhs);
switch (mmlaOp) {
case MMLA::Signed:
case MMLA::SignedInt:
return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
case MMLA::Unsigned:
case MMLA::UnsignedInt:
return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
case MMLA::Mixed:
case MMLA::MixedInt:
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
case MMLA::MixedSwapped:
// The accumulator comes transposed and the result will be transposed
// later, so all we have to do here is swap the operands.
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
rhs, lhs);
case MMLA::Bfloat:
return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
rhs);
case MMLA::Nop:
llvm_unreachable("Uninitialized operation type");
}
Expand Down Expand Up @@ -273,7 +279,7 @@ class VectorContractRewriter {
// Transpose ACC if doing signed by unsigned multiplication, because we're
// using the instruction for unsigned by signed multiplication with
// reversed operands.
if (mmlaOp == MMLA::MixedSwapped)
if (swapOperands)
tiledAcc = rewriter.create<vector::TransposeOp>(
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));

Expand Down Expand Up @@ -302,7 +308,7 @@ class VectorContractRewriter {

// Because of the reversed operands the result is obtained transposed.
// Transpose it back,
if (mmlaOp == MMLA::MixedSwapped)
if (swapOperands)
tiledRes = rewriter.create<vector::TransposeOp>(
loc, tiledRes, ArrayRef<int64_t>({1, 0}));

Expand Down Expand Up @@ -339,10 +345,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
// values before the extension. All four signed/unsigned combinations for
// input operands are supported, but they are lowered to different
// operations. Determine which is the appropriate operation to lower to.
mmlaOp = MMLA::Signed;
mmlaOp = MMLA::SignedInt;
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
if (!maybeLhs) {
mmlaOp = MMLA::Unsigned;
mmlaOp = MMLA::UnsignedInt;
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
}
if (!maybeLhs)
Expand All @@ -351,11 +357,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {

auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
if (maybeRhs) {
if (mmlaOp == MMLA::Unsigned)
mmlaOp = MMLA::Mixed;
if (mmlaOp == MMLA::UnsignedInt)
mmlaOp = MMLA::MixedInt;
} else {
if (mmlaOp == MMLA::Signed)
mmlaOp = MMLA::MixedSwapped;
if (mmlaOp == MMLA::SignedInt) {
mmlaOp = MMLA::MixedInt;
swapOperands = true;
}
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
}

Expand All @@ -372,16 +380,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
auto lhsExtInType = cast<VectorType>(lhs.getType());
if (lhsExtInType.getElementTypeBitWidth() < 8)
lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
/* signExt */ mmlaOp == MMLA::Signed ||
mmlaOp == MMLA::Mixed,
/* signExt */
(mmlaOp == MMLA::SignedInt ||
(mmlaOp == MMLA::MixedInt && !swapOperands)),
rewriter);

auto rhsExtInType = cast<VectorType>(rhs.getType());
if (rhsExtInType.getElementTypeBitWidth() < 8)

rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
/* signExt */ mmlaOp != MMLA::Unsigned &&
mmlaOp != MMLA::Mixed,
/* signExt */
(mmlaOp == MMLA::SignedInt ||
(mmlaOp == MMLA::MixedInt && swapOperands)),
rewriter);

// Initialize parameters for unrolling.
Expand All @@ -395,6 +404,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
}
};

class VectorContractRewriterBFMMLA : public VectorContractRewriter {
public:
LogicalResult matchAndInit(vector::ContractionOp op,
PatternRewriter &rewriter) {

if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
return failure();

// Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for
// tiling.
if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");

// Check the output is a vector of Float32 elements.
auto outTy = dyn_cast<VectorType>(op.getResultType());
if (!outTy || outTy.getElementType() != rewriter.getF32Type())
return rewriter.notifyMatchFailure(op,
"output type is not a vector of f32");

// Check the inputs are vectors of BFloat16 elements.
if (op.getLhsType().getElementType() != rewriter.getBF16Type())
return rewriter.notifyMatchFailure(op,
"input type is not a vector of bf16");

mmlaOp = MMLA::Bfloat;
swapOperands = false;
lhs = op.getLhs();
rhs = op.getRhs();
acc = op.getAcc();

// Initialize parameters for unrolling.
iterationBounds = *op.getShapeForUnroll();
if (iterationBounds.size() == 3)
subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
else
subTileShape = SmallVector<int64_t>({2, 4});

return success();
}
};

/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
Expand All @@ -416,10 +466,32 @@ class LowerContractionToNeonI8MMPattern
}
};

class LowerContractionToNeonBFMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {

VectorContractRewriterBFMMLA vcr;
if (failed(vcr.matchAndInit(op, rewriter)))
return failure();
vcr.lower(op, rewriter);

return success();
}
};

} // namespace

void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
}

void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using namespace mlir;

void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
mlir::populateLowerContractionToSVEI8MMPatterns(patterns);
}

void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// TODO: There may be opportunities to unify this with a similar pattern
// for Neon. See:
// https://github.com/llvm/llvm-project/issues/145559
// LowerContractionToNeonI8MMPattern.cpp
// LowerContractToNeonPatterns.cpp
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -580,7 +580,7 @@ class LowerContractionToSVEBFMMLAPattern

} // namespace

void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
void mlir::populateLowerContractionToSVEI8MMPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
Expand Down
Loading