From 950a1853e7bd5076b0c561ea74ae365f7460c442 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 21 Jul 2025 18:22:07 -0400 Subject: [PATCH] [mlir][NFC] update `mlir/Dialect` create APIs (24/n) See https://github.com/llvm/llvm-project/pull/147168 for more info. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 111 +++---- .../BufferizableOpInterfaceImpl.cpp | 9 +- .../Vector/Transforms/LowerVectorBitCast.cpp | 8 +- .../Transforms/LowerVectorBroadcast.cpp | 22 +- .../Vector/Transforms/LowerVectorContract.cpp | 130 ++++---- .../Vector/Transforms/LowerVectorGather.cpp | 50 +-- .../Transforms/LowerVectorInterleave.cpp | 30 +- .../Vector/Transforms/LowerVectorMask.cpp | 37 +-- .../Transforms/LowerVectorMultiReduction.cpp | 70 +++-- .../Vector/Transforms/LowerVectorScan.cpp | 30 +- .../Transforms/LowerVectorShapeCast.cpp | 64 ++-- ...LowerVectorToFromElementsToShuffleTree.cpp | 4 +- .../Vector/Transforms/LowerVectorTransfer.cpp | 52 ++-- .../Transforms/LowerVectorTranspose.cpp | 32 +- .../Vector/Transforms/VectorDistribute.cpp | 134 ++++---- .../Transforms/VectorDropLeadUnitDim.cpp | 71 +++-- .../VectorEmulateMaskedLoadStore.cpp | 36 ++- .../Transforms/VectorEmulateNarrowType.cpp | 290 +++++++++--------- ...sertExtractStridedSliceRewritePatterns.cpp | 48 +-- .../Vector/Transforms/VectorLinearize.cpp | 11 +- .../Transforms/VectorMaskElimination.cpp | 4 +- .../Transforms/VectorTransferOpTransforms.cpp | 48 +-- .../VectorTransferSplitRewritePatterns.cpp | 102 +++--- .../Vector/Transforms/VectorTransforms.cpp | 136 ++++---- .../Vector/Transforms/VectorUnroll.cpp | 64 ++-- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 8 +- 26 files changed, 825 insertions(+), 776 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 56f748fbbe1d6..4c00fb58e4d30 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -124,7 +124,7 @@ static MaskFormat getMaskFormat(Value mask) { /// Default 
callback to build a region with a 'vector.yield' terminator with no /// arguments. void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) { - builder.create(loc); + vector::YieldOp::create(builder, loc); } // Helper for verifying combining kinds in contractions and reductions. @@ -596,16 +596,16 @@ struct ElideUnitDimsInMultiDimReduction VectorType newMaskType = VectorType::get(dstVecType.getShape(), rewriter.getI1Type(), dstVecType.getScalableDims()); - mask = rewriter.create(loc, newMaskType, mask); + mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask); } - cast = rewriter.create( - loc, reductionOp.getDestType(), reductionOp.getSource()); + cast = vector::ShapeCastOp::create( + rewriter, loc, reductionOp.getDestType(), reductionOp.getSource()); } else { // This means we are reducing all the dimensions, and all reduction // dimensions are of size 1. So a simple extraction would do. if (mask) - mask = rewriter.create(loc, mask); - cast = rewriter.create(loc, reductionOp.getSource()); + mask = vector::ExtractOp::create(rewriter, loc, mask); + cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource()); } Value result = @@ -672,36 +672,36 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, switch (op) { case arith::AtomicRMWKind::addf: case arith::AtomicRMWKind::addi: - return builder.create(vector.getLoc(), - CombiningKind::ADD, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::ADD, vector); case arith::AtomicRMWKind::mulf: case arith::AtomicRMWKind::muli: - return builder.create(vector.getLoc(), - CombiningKind::MUL, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MUL, vector); case arith::AtomicRMWKind::minimumf: - return builder.create(vector.getLoc(), - CombiningKind::MINIMUMF, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MINIMUMF, vector); case arith::AtomicRMWKind::mins: - 
return builder.create(vector.getLoc(), - CombiningKind::MINSI, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MINSI, vector); case arith::AtomicRMWKind::minu: - return builder.create(vector.getLoc(), - CombiningKind::MINUI, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MINUI, vector); case arith::AtomicRMWKind::maximumf: - return builder.create(vector.getLoc(), - CombiningKind::MAXIMUMF, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MAXIMUMF, vector); case arith::AtomicRMWKind::maxs: - return builder.create(vector.getLoc(), - CombiningKind::MAXSI, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MAXSI, vector); case arith::AtomicRMWKind::maxu: - return builder.create(vector.getLoc(), - CombiningKind::MAXUI, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MAXUI, vector); case arith::AtomicRMWKind::andi: - return builder.create(vector.getLoc(), - CombiningKind::AND, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::AND, vector); case arith::AtomicRMWKind::ori: - return builder.create(vector.getLoc(), - CombiningKind::OR, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::OR, vector); // TODO: Add remaining reduction operations. 
default: (void)emitOptionalError(loc, "Reduction operation type not supported"); @@ -740,8 +740,8 @@ struct ElideSingleElementReduction : public OpRewritePattern { Location loc = reductionOp.getLoc(); if (mask) - mask = rewriter.create(loc, mask); - Value result = rewriter.create(loc, reductionOp.getVector()); + mask = ExtractOp::create(rewriter, loc, mask); + Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector()); if (Value acc = reductionOp.getAcc()) result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), @@ -4172,9 +4172,9 @@ class StridedSliceCreateMaskFolder final // greater than the vector dim size. IntegerAttr offsetAttr = rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset); - Value offset = rewriter.create(loc, offsetAttr); + Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr); Value sliceMaskDimSize = - rewriter.create(loc, maskDimSize, offset); + arith::SubIOp::create(rewriter, loc, maskDimSize, offset); sliceMaskDimSizes.push_back(sliceMaskDimSize); } // Add unchanged dimensions. 
@@ -4289,8 +4289,8 @@ class StridedSliceBroadcast final sizes[i] = 1; } } - source = rewriter.create( - op->getLoc(), source, offsets, sizes, + source = ExtractStridedSliceOp::create( + rewriter, op->getLoc(), source, offsets, sizes, getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff)); } rewriter.replaceOpWithNewOp(op, op.getType(), source); @@ -4382,8 +4382,8 @@ class ContiguousExtractStridedSliceToExtract final SmallVector offsets = getI64SubArray(op.getOffsets()); auto extractOffsets = ArrayRef(offsets).take_front(numOffsets); - Value extract = rewriter.create(op->getLoc(), source, - extractOffsets); + Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source, + extractOffsets); rewriter.replaceOpWithNewOp(op, op.getType(), extract); return success(); } @@ -4413,7 +4413,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result, Type elemType = llvm::cast(source.getType()).getElementType(); if (!padding) - padding = builder.create(result.location, elemType); + padding = ub::PoisonOp::create(builder, result.location, elemType); build(builder, result, vectorType, source, indices, permutationMapAttr, *padding, /*mask=*/Value(), inBoundsAttr); } @@ -4431,7 +4431,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result, SmallVector(vectorType.getRank(), false)); Type elemType = llvm::cast(source.getType()).getElementType(); if (!padding) - padding = builder.create(result.location, elemType); + padding = ub::PoisonOp::create(builder, result.location, elemType); build(builder, result, vectorType, source, indices, *padding, permutationMapAttr, inBoundsAttr); } @@ -4450,7 +4450,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result, SmallVector(vectorType.getRank(), false)); Type elemType = llvm::cast(source.getType()).getElementType(); if (!padding) - padding = builder.create(result.location, elemType); + padding = ub::PoisonOp::create(builder, result.location, elemType); build(builder, result, 
vectorType, source, indices, permutationMapAttr, *padding, /*mask=*/Value(), inBoundsAttr); @@ -4975,7 +4975,7 @@ struct TransferReadAfterWriteToBroadcast VectorType broadcastedType = VectorType::get( broadcastShape, defWrite.getVectorType().getElementType(), broadcastScalableFlags); - vec = rewriter.create(loc, broadcastedType, vec); + vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec); SmallVector transposePerm(permutation.begin(), permutation.end()); rewriter.replaceOpWithNewOp(readOp, vec, transposePerm); @@ -5453,13 +5453,14 @@ struct SwapExtractSliceOfTransferWrite // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp. // Set all in_bounds to false and let the folder infer them. SmallVector newInBounds(vectorShape.size(), false); - auto newExtractOp = rewriter.create( - extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(), - insertOp.getMixedOffsets(), insertOp.getMixedSizes(), - insertOp.getMixedStrides()); - auto newTransferWriteOp = rewriter.create( - transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(), - transferOp.getIndices(), transferOp.getPermutationMapAttr(), + auto newExtractOp = tensor::ExtractSliceOp::create( + rewriter, extractOp.getLoc(), insertOp.getSourceType(), + insertOp.getDest(), insertOp.getMixedOffsets(), + insertOp.getMixedSizes(), insertOp.getMixedStrides()); + auto newTransferWriteOp = TransferWriteOp::create( + rewriter, transferOp.getLoc(), transferOp.getVector(), + newExtractOp.getResult(), transferOp.getIndices(), + transferOp.getPermutationMapAttr(), rewriter.getBoolArrayAttr(newInBounds)); rewriter.modifyOpInPlace(insertOp, [&]() { insertOp.getSourceMutable().assign(newTransferWriteOp.getResult()); @@ -6983,7 +6984,7 @@ void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) { OpBuilder opBuilder(builder.getContext()); Operation *maskedOp = &block.front(); opBuilder.setInsertionPointToEnd(&block); - opBuilder.create(loc, 
maskedOp->getResults()); + vector::YieldOp::create(opBuilder, loc, maskedOp->getResults()); } LogicalResult MaskOp::verify() { @@ -7318,7 +7319,7 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder, // Create a block and move the op to that block. insBlock->getOperations().splice( insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp); - builder.create(maskableOp->getLoc(), maskableOp->getResults()); + YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults()); } /// Creates a vector.mask operation around a maskable operation. Returns the @@ -7330,12 +7331,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder, if (!mask) return maskableOp; if (passthru) - return builder.create(maskableOp->getLoc(), - maskableOp->getResultTypes(), mask, passthru, - maskableOp, createMaskOpRegion); - return builder.create(maskableOp->getLoc(), - maskableOp->getResultTypes(), mask, maskableOp, - createMaskOpRegion); + return MaskOp::create(builder, maskableOp->getLoc(), + maskableOp->getResultTypes(), mask, passthru, + maskableOp, createMaskOpRegion); + return MaskOp::create(builder, maskableOp->getLoc(), + maskableOp->getResultTypes(), mask, maskableOp, + createMaskOpRegion); } /// Creates a vector select operation that picks values from `newValue` or @@ -7350,8 +7351,8 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask, if (!mask) return newValue; - return builder.create(newValue.getLoc(), newValue.getType(), - mask, newValue, passthru); + return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(), + mask, newValue, passthru); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 9da051150e409..66196194b0585 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -116,8 +116,8 @@ struct TransferWriteOpInterface getBuffer(rewriter, writeOp.getBase(), options, state); if (failed(resultBuffer)) return failure(); - rewriter.create( - writeOp.getLoc(), writeOp.getVector(), *resultBuffer, + vector::TransferWriteOp::create( + rewriter, writeOp.getLoc(), writeOp.getVector(), *resultBuffer, writeOp.getIndices(), writeOp.getPermutationMapAttr(), writeOp.getMask(), writeOp.getInBoundsAttr()); replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); @@ -241,8 +241,9 @@ struct MaskOpInterface // Create a new vector.mask op. ValueRange newYieldedValuesRange(newYieldedValues); TypeRange newResultTypes(newYieldedValuesRange); - auto newOp = rewriter.create( - op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(), + auto newOp = vector::MaskOp::create( + rewriter, op->getLoc(), newResultTypes, maskOp.getMask(), + maskOp.getPassthru(), /*maskableOp=*/nullptr, /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {}); newOp.getRegion().takeBody(maskOp.getMaskRegion()); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp index 89930a6bd35fa..4c3a04cfb5bfa 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp @@ -64,14 +64,14 @@ class UnrollBitCastOp final : public OpRewritePattern { VectorType::get(shape, resultType.getElementType(), scalableDims); Location loc = op.getLoc(); - Value result = rewriter.create(loc, resultType); + Value result = ub::PoisonOp::create(rewriter, loc, resultType); for (auto position : *unrollIterator) { Value extract = - rewriter.create(loc, op.getSource(), position); + vector::ExtractOp::create(rewriter, loc, op.getSource(), position); Value bitcast = - rewriter.create(loc, bitcastResType, extract); + vector::BitCastOp::create(rewriter, loc, bitcastResType, extract); 
result = - rewriter.create(loc, bitcast, result, position); + vector::InsertOp::create(rewriter, loc, bitcast, result, position); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index 11dcfe421e0c4..cb8e566869cfd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -52,7 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern { // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. if (srcRank <= 1 && dstRank == 1) { - Value ext = rewriter.create(loc, op.getSource()); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource()); rewriter.replaceOpWithNewOp(op, dstType, ext); return success(); } @@ -70,10 +70,10 @@ class BroadcastOpLowering : public OpRewritePattern { // Duplication. VectorType resType = VectorType::Builder(dstType).dropDim(0); Value bcst = - rewriter.create(loc, resType, op.getSource()); - Value result = rewriter.create(loc, dstType); + vector::BroadcastOp::create(rewriter, loc, resType, op.getSource()); + Value result = ub::PoisonOp::create(rewriter, loc, dstType); for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) - result = rewriter.create(loc, bcst, result, d); + result = vector::InsertOp::create(rewriter, loc, bcst, result, d); rewriter.replaceOp(op, result); return success(); } @@ -111,13 +111,13 @@ class BroadcastOpLowering : public OpRewritePattern { VectorType resType = VectorType::get(dstType.getShape().drop_front(), eltType, dstType.getScalableDims().drop_front()); - Value result = rewriter.create(loc, dstType); + Value result = ub::PoisonOp::create(rewriter, loc, dstType); if (m == 0) { // Stetch at start. 
- Value ext = rewriter.create(loc, op.getSource(), 0); - Value bcst = rewriter.create(loc, resType, ext); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0); + Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext); for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) - result = rewriter.create(loc, bcst, result, d); + result = vector::InsertOp::create(rewriter, loc, bcst, result, d); } else { // Stetch not at start. if (dstType.getScalableDims()[0]) { @@ -125,9 +125,9 @@ class BroadcastOpLowering : public OpRewritePattern { return failure(); } for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) { - Value ext = rewriter.create(loc, op.getSource(), d); - Value bcst = rewriter.create(loc, resType, ext); - result = rewriter.create(loc, bcst, result, d); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d); + Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext); + result = vector::InsertOp::create(rewriter, loc, bcst, result, d); } } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index fc6c90f5132c7..65702ffa152d9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -81,17 +81,17 @@ static Value reshapeLoad(Location loc, Value val, VectorType type, // At extraction dimension? if (index == 0) - return rewriter.create(loc, val, pos); + return vector::ExtractOp::create(rewriter, loc, val, pos); // Unroll leading dimensions. 
VectorType vType = VectorType::Builder(type).dropDim(0); VectorType resType = VectorType::Builder(type).dropDim(index); - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); + Value result = arith::ConstantOp::create(rewriter, loc, resType, + rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { - Value ext = rewriter.create(loc, val, d); + Value ext = vector::ExtractOp::create(rewriter, loc, val, d); Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, load, result, d); + result = vector::InsertOp::create(rewriter, loc, load, result, d); } return result; } @@ -106,15 +106,15 @@ static Value reshapeStore(Location loc, Value val, Value result, return val; // At insertion dimension? if (index == 0) - return rewriter.create(loc, val, result, pos); + return vector::InsertOp::create(rewriter, loc, val, result, pos); // Unroll leading dimensions. VectorType vType = VectorType::Builder(type).dropDim(0); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - Value ext = rewriter.create(loc, result, d); - Value ins = rewriter.create(loc, val, d); + Value ext = vector::ExtractOp::create(rewriter, loc, result, d); + Value ins = vector::ExtractOp::create(rewriter, loc, val, d); Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, sto, result, d); + result = vector::InsertOp::create(rewriter, loc, sto, result, d); } return result; } @@ -132,7 +132,7 @@ createContractArithOp(Location loc, Value x, Value y, Value acc, kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF) // Only valid for floating point types. return std::nullopt; - mul = rewriter.create(loc, x, y); + mul = arith::MulIOp::create(rewriter, loc, x, y); } else { // Float case. 
if (kind == CombiningKind::AND || kind == CombiningKind::MINUI || @@ -143,14 +143,14 @@ createContractArithOp(Location loc, Value x, Value y, Value acc, return std::nullopt; // Special case for fused multiply-add. if (acc && isa(acc.getType()) && kind == CombiningKind::ADD) { - Value fma = rewriter.create(loc, x, y, acc); + Value fma = vector::FMAOp::create(rewriter, loc, x, y, acc); if (mask) // The fma op doesn't need explicit masking. However, fma ops used in // reductions must preserve previous 'acc' values for masked-out lanes. fma = selectPassthru(rewriter, mask, fma, acc); return fma; } - mul = rewriter.create(loc, x, y); + mul = arith::MulFOp::create(rewriter, loc, x, y); } if (!acc) @@ -186,8 +186,8 @@ static std::optional getDimPosition(AffineMap map, unsigned dim) { static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter) { if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); + return arith::AddIOp::create(rewriter, loc, x, y); + return arith::AddFOp::create(rewriter, loc, x, y); } /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using @@ -195,8 +195,8 @@ static Value createAdd(Location loc, Value x, Value y, bool isInt, static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter) { if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); + return arith::MulIOp::create(rewriter, loc, x, y); + return arith::MulFOp::create(rewriter, loc, x, y); } namespace { @@ -359,7 +359,7 @@ struct UnrolledOuterProductGenerator Value t(Value v, ArrayRef perm = {1, 0}) { if (!v) return v; - return rewriter.create(loc, v, perm); + return vector::TransposeOp::create(rewriter, loc, v, perm); } Value promote(Value v, Type dstElementType) { @@ -373,8 +373,8 @@ struct UnrolledOuterProductGenerator if (vecType) promotedType = vecType.clone(promotedType); if (isa(dstElementType)) - return rewriter.create(loc, promotedType, v); - 
return rewriter.create(loc, promotedType, v); + return arith::ExtFOp::create(rewriter, loc, promotedType, v); + return arith::ExtSIOp::create(rewriter, loc, promotedType, v); } FailureOr outerProd(Value lhs, Value rhs, Value res, @@ -386,17 +386,17 @@ struct UnrolledOuterProductGenerator Type resElementType = cast(res.getType()).getElementType(); for (int64_t k = 0; k < reductionSize; ++k) { - Value extractA = rewriter.create(loc, lhs, k); - Value extractB = rewriter.create(loc, rhs, k); + Value extractA = vector::ExtractOp::create(rewriter, loc, lhs, k); + Value extractB = vector::ExtractOp::create(rewriter, loc, rhs, k); extractA = promote(extractA, resElementType); extractB = promote(extractB, resElementType); Value extractMask; if (maybeMask.has_value() && maybeMask.value()) extractMask = - rewriter.create(loc, maybeMask.value(), k); + vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k); - Operation *outerProdOp = rewriter.create( - loc, res.getType(), extractA, extractB, res, kind); + Operation *outerProdOp = vector::OuterProductOp::create( + rewriter, loc, res.getType(), extractA, extractB, res, kind); res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0); } return res; @@ -646,28 +646,28 @@ FailureOr ContractionOpToDotLowering::matchAndRewriteMaskableOp( // Two outer parallel, one inner reduction (matmat flavor). // if (maps == infer({{m, k}, {k, n}, {m, n}})) { - rhs = rewriter.create(loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{m, k}, {n, k}, {m, n}})) { // No need to permute anything. 
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) { - lhs = rewriter.create(loc, lhs, perm); - rhs = rewriter.create(loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { // This is the classical row-major matmul. Just permute the lhs. Value tmp = lhs; - lhs = rewriter.create(loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); rhs = tmp; } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { std::swap(lhs, rhs); } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { Value tmp = lhs; - lhs = rewriter.create(loc, rhs, perm); - rhs = rewriter.create(loc, tmp, perm); + lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm); } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { Value tmp = rhs; - rhs = rewriter.create(loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); lhs = tmp; } else { return failure(); @@ -680,12 +680,12 @@ FailureOr ContractionOpToDotLowering::matchAndRewriteMaskableOp( if (maps == infer({{m, n}, {n}, {m}})) { // No need to permute anything. } else if (maps == infer({{n, m}, {n}, {m}})) { - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{n}, {m, n}, {m}})) { std::swap(lhs, rhs); } else if (maps == infer({{n}, {n, m}, {m}})) { std::swap(lhs, rhs); - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else { return failure(); } @@ -702,31 +702,32 @@ FailureOr ContractionOpToDotLowering::matchAndRewriteMaskableOp( unsigned dstColumns = rank == 1 ? 
1 : dstType.getShape()[1]; // ExtractOp does not allow dynamic indexing, we must unroll explicitly. - Value res = rewriter.create(loc, dstType, - rewriter.getZeroAttr(dstType)); + Value res = arith::ConstantOp::create(rewriter, loc, dstType, + rewriter.getZeroAttr(dstType)); bool isInt = isa(dstType.getElementType()); llvm::SmallVector extractedCols; extractedCols.reserve(dstColumns); for (unsigned r = 0; r < dstRows; ++r) { - Value rowLhs = rewriter.create(op.getLoc(), lhs, r); + Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(), lhs, r); for (unsigned c = 0; c < dstColumns; ++c) { // Extract each respective row and column of the LHS and RHS once to // avoid having duplicate SSA values pointing to the same rows/columns. if (r == 0) { Value colRhs = - rank == 1 ? rhs - : rewriter.create(op.getLoc(), rhs, c); + rank == 1 + ? rhs + : vector::ExtractOp::create(rewriter, op.getLoc(), rhs, c); extractedCols.push_back(colRhs); } Value extractedColRhs = extractedCols[c]; Value product = createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter); - Value sum = rewriter.create( - op.getLoc(), vector::CombiningKind::ADD, product); + Value sum = vector::ReductionOp::create( + rewriter, op.getLoc(), vector::CombiningKind::ADD, product); SmallVector pos = rank == 1 ? 
SmallVector{r} : SmallVector{r, c}; - res = rewriter.create(op.getLoc(), sum, res, pos); + res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos); } } if (auto acc = op.getAcc()) @@ -827,21 +828,21 @@ struct ContractOpToElementwise lhsDims.append(lhsShape.begin(), lhsShape.end()); auto expandedType = VectorType::get(lhsDims, contractOp.getLhsType().getElementType()); - newLhs = rewriter.create(loc, expandedType, newLhs); + newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs); } if (!rhsDims.empty()) { rhsDims.append(rhsShape.begin(), rhsShape.end()); auto expandedType = VectorType::get(rhsDims, contractOp.getRhsType().getElementType()); - newRhs = rewriter.create(loc, expandedType, newRhs); + newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs); } bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); - newLhs = rewriter.create(loc, newLhs, lhsTranspose); - newRhs = rewriter.create(loc, newRhs, rhsTranspose); + newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose); + newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose); SmallVector lhsOffsets(lhsReductionDims.size(), 0); SmallVector rhsOffsets(rhsReductionDims.size(), 0); - newLhs = rewriter.create(loc, newLhs, lhsOffsets); - newRhs = rewriter.create(loc, newRhs, rhsOffsets); + newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets); + newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets); std::optional result = createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), contractOp.getKind(), rewriter, isInt); @@ -1039,8 +1040,8 @@ FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); // Unroll into a series of lower dimensional vector.contract ops. 
Location loc = op.getLoc(); - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); + Value result = arith::ConstantOp::create(rewriter, loc, resType, + rewriter.getZeroAttr(resType)); for (int64_t d = 0; d < dimSize; ++d) { auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); @@ -1052,8 +1053,8 @@ FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, lowMask = reshapeLoad(loc, mask, cast(mask.getType()), iterIndex, d, rewriter); - Operation *lowContract = rewriter.create( - loc, lhs, rhs, acc, lowAffine, lowIter); + Operation *lowContract = vector::ContractionOp::create( + rewriter, loc, lhs, rhs, acc, lowAffine, lowIter); lowContract = maskOperation(rewriter, lowContract, lowMask); result = reshapeStore(loc, lowContract->getResult(0), result, resType, resIndex, d, rewriter); @@ -1103,8 +1104,8 @@ FailureOr ContractionOpLowering::lowerReduction( Value acc = op.getAcc(); Operation *reductionOp = - acc ? rewriter.create(loc, kind, m, acc) - : rewriter.create(loc, kind, m); + acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc) + : vector::ReductionOp::create(rewriter, loc, kind, m); return maskOperation(rewriter, reductionOp, mask)->getResult(0); } // Construct new iterator types and affine map array attribute. @@ -1128,8 +1129,8 @@ FailureOr ContractionOpLowering::lowerReduction( newMask = reshapeLoad(loc, mask, cast(mask.getType()), iterIndex, d, rewriter); - Operation *newContract = rewriter.create( - loc, lhs, rhs, result, lowAffine, lowIter); + Operation *newContract = vector::ContractionOp::create( + rewriter, loc, lhs, rhs, result, lowAffine, lowIter); result = maskOperation(rewriter, newContract, newMask)->getResult(0); } return result; @@ -1182,7 +1183,8 @@ class OuterProductOpLowering : public OpRewritePattern { if (!rhsType) { // Special case: AXPY operation. 
- Value b = rewriter.create(loc, lhsType, op.getRhs()); + Value b = + vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs()); std::optional mult = createContractArithOp( loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask); if (!mult.has_value()) @@ -1191,23 +1193,23 @@ class OuterProductOpLowering : public OpRewritePattern { return success(); } - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); + Value result = arith::ConstantOp::create(rewriter, loc, resType, + rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { - Value x = rewriter.create(loc, op.getLhs(), d); - Value a = rewriter.create(loc, rhsType, x); + Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d); + Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x); Value r = nullptr; if (acc) - r = rewriter.create(loc, acc, d); + r = vector::ExtractOp::create(rewriter, loc, acc, d); Value extrMask; if (mask) - extrMask = rewriter.create(loc, mask, d); + extrMask = vector::ExtractOp::create(rewriter, loc, mask, d); std::optional m = createContractArithOp( loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); if (!m.has_value()) return failure(); - result = rewriter.create(loc, *m, result, d); + result = vector::InsertOp::create(rewriter, loc, *m, result, d); } rewriter.replaceOp(rootOp, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index f4ad56b4178db..2484670c39caa 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -68,8 +68,8 @@ struct UnrollGather : OpRewritePattern { Value maskVec = op.getMask(); Value passThruVec = op.getPassThru(); - Value result = rewriter.create( - loc, resultTy, rewriter.getZeroAttr(resultTy)); + Value result = arith::ConstantOp::create(rewriter, loc, resultTy, + rewriter.getZeroAttr(resultTy)); 
VectorType subTy = VectorType::Builder(resultTy).dropDim(0); @@ -77,16 +77,16 @@ struct UnrollGather : OpRewritePattern { int64_t thisIdx[1] = {i}; Value indexSubVec = - rewriter.create(loc, indexVec, thisIdx); + vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); Value maskSubVec = - rewriter.create(loc, maskVec, thisIdx); + vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); Value passThruSubVec = - rewriter.create(loc, passThruVec, thisIdx); - Value subGather = rewriter.create( - loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec, - passThruSubVec); + vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx); + Value subGather = vector::GatherOp::create( + rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec, + maskSubVec, passThruSubVec); result = - rewriter.create(loc, subGather, result, thisIdx); + vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx); } rewriter.replaceOp(op, result); @@ -152,24 +152,24 @@ struct RemoveStrideFromGatherSource : OpRewritePattern { // 1. Collapse the input memref so that it's "flat". SmallVector reassoc = {{0, 1}}; - Value collapsed = rewriter.create( - op.getLoc(), subview.getSource(), reassoc); + Value collapsed = memref::CollapseShapeOp::create( + rewriter, op.getLoc(), subview.getSource(), reassoc); // 2. Generate new gather indices that will model the // strided access. IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim); VectorType vType = op.getIndexVec().getType(); - Value mulCst = rewriter.create( - op.getLoc(), vType, DenseElementsAttr::get(vType, stride)); + Value mulCst = arith::ConstantOp::create( + rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride)); Value newIdxs = - rewriter.create(op.getLoc(), op.getIndexVec(), mulCst); + arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst); // 3. Create an updated gather op with the collapsed input memref and the // updated indices. 
- Value newGather = rewriter.create( - op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(), - newIdxs, op.getMask(), op.getPassThru()); + Value newGather = vector::GatherOp::create( + rewriter, op.getLoc(), op.getResult().getType(), collapsed, + op.getIndices(), newIdxs, op.getMask(), op.getPassThru()); rewriter.replaceOp(op, newGather); return success(); @@ -222,8 +222,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern { for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) { int64_t thisIdx[1] = {i}; Value condition = - rewriter.create(loc, condMask, thisIdx); - Value index = rewriter.create(loc, indexVec, thisIdx); + vector::ExtractOp::create(rewriter, loc, condMask, thisIdx); + Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); baseOffsets.back() = rewriter.createOrFold(loc, lastBaseOffset, index); @@ -233,19 +233,19 @@ struct Gather1DToConditionalLoads : OpRewritePattern { // `vector.load` does not support scalar result; emit a vector load // and extract the single result instead. 
Value load = - b.create(loc, elemVecTy, base, baseOffsets); + vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets); int64_t zeroIdx[1] = {0}; - extracted = b.create(loc, load, zeroIdx); + extracted = vector::ExtractOp::create(b, loc, load, zeroIdx); } else { - extracted = b.create(loc, base, baseOffsets); + extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets); } Value newResult = - b.create(loc, extracted, result, thisIdx); - b.create(loc, newResult); + vector::InsertOp::create(b, loc, extracted, result, thisIdx); + scf::YieldOp::create(b, loc, newResult); }; auto passThruBuilder = [result](OpBuilder &b, Location loc) { - b.create(loc, result); + scf::YieldOp::create(b, loc, result); }; result = diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index cab0f213b14a9..9d6a865a9301f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -60,14 +60,16 @@ class UnrollInterleaveOp final : public OpRewritePattern { return failure(); auto loc = op.getLoc(); - Value result = rewriter.create( - loc, resultType, rewriter.getZeroAttr(resultType)); + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); for (auto position : *unrollIterator) { - Value extractLhs = rewriter.create(loc, op.getLhs(), position); - Value extractRhs = rewriter.create(loc, op.getRhs(), position); + Value extractLhs = + ExtractOp::create(rewriter, loc, op.getLhs(), position); + Value extractRhs = + ExtractOp::create(rewriter, loc, op.getRhs(), position); Value interleave = - rewriter.create(loc, extractLhs, extractRhs); - result = rewriter.create(loc, interleave, result, position); + InterleaveOp::create(rewriter, loc, extractLhs, extractRhs); + result = InsertOp::create(rewriter, loc, interleave, result, position); } rewriter.replaceOp(op, result); @@ -123,20 
+125,20 @@ class UnrollDeinterleaveOp final return failure(); auto loc = op.getLoc(); - Value emptyResult = rewriter.create( - loc, resultType, rewriter.getZeroAttr(resultType)); + Value emptyResult = arith::ConstantOp::create( + rewriter, loc, resultType, rewriter.getZeroAttr(resultType)); Value evenResult = emptyResult; Value oddResult = emptyResult; for (auto position : *unrollIterator) { auto extractSrc = - rewriter.create(loc, op.getSource(), position); + vector::ExtractOp::create(rewriter, loc, op.getSource(), position); auto deinterleave = - rewriter.create(loc, extractSrc); - evenResult = rewriter.create( - loc, deinterleave.getRes1(), evenResult, position); - oddResult = rewriter.create(loc, deinterleave.getRes2(), - oddResult, position); + vector::DeinterleaveOp::create(rewriter, loc, extractSrc); + evenResult = vector::InsertOp::create( + rewriter, loc, deinterleave.getRes1(), evenResult, position); + oddResult = vector::InsertOp::create( + rewriter, loc, deinterleave.getRes2(), oddResult, position); } rewriter.replaceOp(op, ValueRange{evenResult, oddResult}); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index ba21092d2af3c..45ef7f01a85f1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -67,19 +67,20 @@ class CreateMaskOpLowering : public OpRewritePattern { Value idx = op.getOperand(0); VectorType lowType = VectorType::Builder(dstType).dropDim(0); - Value trueVal = rewriter.create( - loc, lowType, op.getOperands().drop_front()); - Value falseVal = rewriter.create( - loc, lowType, rewriter.getZeroAttr(lowType)); - Value result = rewriter.create( - loc, dstType, rewriter.getZeroAttr(dstType)); + Value trueVal = vector::CreateMaskOp::create(rewriter, loc, lowType, + op.getOperands().drop_front()); + Value falseVal = arith::ConstantOp::create(rewriter, loc, lowType, + 
rewriter.getZeroAttr(lowType)); + Value result = arith::ConstantOp::create(rewriter, loc, dstType, + rewriter.getZeroAttr(dstType)); for (int64_t d = 0; d < dim; d++) { Value bnd = - rewriter.create(loc, rewriter.getIndexAttr(d)); - Value val = rewriter.create(loc, arith::CmpIPredicate::slt, - bnd, idx); - Value sel = rewriter.create(loc, val, trueVal, falseVal); - result = rewriter.create(loc, sel, result, d); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(d)); + Value val = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::slt, bnd, idx); + Value sel = + arith::SelectOp::create(rewriter, loc, val, trueVal, falseVal); + result = vector::InsertOp::create(rewriter, loc, sel, result, d); } rewriter.replaceOp(op, result); return success(); @@ -146,12 +147,12 @@ class ConstantMaskOpLowering : public OpRewritePattern { op, "Cannot unroll leading scalable dim in dstType"); VectorType lowType = VectorType::Builder(dstType).dropDim(0); - Value trueVal = rewriter.create( - loc, lowType, dimSizes.drop_front()); - Value result = rewriter.create( - loc, dstType, rewriter.getZeroAttr(dstType)); + Value trueVal = vector::ConstantMaskOp::create(rewriter, loc, lowType, + dimSizes.drop_front()); + Value result = arith::ConstantOp::create(rewriter, loc, dstType, + rewriter.getZeroAttr(dstType)); for (int64_t d = 0; d < trueDimSize; d++) - result = rewriter.create(loc, trueVal, result, d); + result = vector::InsertOp::create(rewriter, loc, trueVal, result, d); rewriter.replaceOp(op, result); return success(); @@ -261,8 +262,8 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern { PatternRewriter &rewriter) const override { Value passthru = maskingOp.hasPassthru() ? maskingOp.getPassthru() - : rewriter.create( - gatherOp.getLoc(), + : arith::ConstantOp::create( + rewriter, gatherOp.getLoc(), rewriter.getZeroAttr(gatherOp.getVectorType())); // Replace the `vector.mask` operation. 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index ce524b259d8d4..4773732d8d9a6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -103,12 +103,12 @@ class InnerOuterDimReductionConversion // If masked, transpose the original mask. Value transposedMask; if (maskableOp.isMasked()) { - transposedMask = rewriter.create( - loc, maskableOp.getMaskingOp().getMask(), indices); + transposedMask = vector::TransposeOp::create( + rewriter, loc, maskableOp.getMaskingOp().getMask(), indices); } // Transpose reduction source. - auto transposeOp = rewriter.create(loc, src, indices); + auto transposeOp = vector::TransposeOp::create(rewriter, loc, src, indices); SmallVector reductionMask(srcRank, false); for (int i = 0; i < reductionSize; ++i) { if (useInnerDimsForReduction) @@ -117,8 +117,8 @@ class InnerOuterDimReductionConversion reductionMask[i] = true; } - Operation *newMultiRedOp = rewriter.create( - multiReductionOp.getLoc(), transposeOp.getResult(), + Operation *newMultiRedOp = vector::MultiDimReductionOp::create( + rewriter, multiReductionOp.getLoc(), transposeOp.getResult(), multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind()); newMultiRedOp = mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask); @@ -255,15 +255,15 @@ class ReduceMultiDimReductionRank auto maskCastedType = VectorType::get( vectorShape, llvm::cast(vectorMask.getType()).getElementType()); - newVectorMask = - rewriter.create(loc, maskCastedType, vectorMask); + newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType, + vectorMask); } auto castedType = VectorType::get( vectorShape, multiReductionOp.getSourceVectorType().getElementType(), scalableDims); - Value cast = rewriter.create( - loc, castedType, multiReductionOp.getSource()); + Value cast = 
vector::ShapeCastOp::create(rewriter, loc, castedType, + multiReductionOp.getSource()); Value acc = multiReductionOp.getAcc(); if (flattenedParallelDim) { @@ -271,12 +271,12 @@ class ReduceMultiDimReductionRank {flattenedParallelDim}, multiReductionOp.getSourceVectorType().getElementType(), /*scalableDims=*/{isParallelDimScalable}); - acc = rewriter.create(loc, accType, acc); + acc = vector::ShapeCastOp::create(rewriter, loc, accType, acc); } // 6. Creates the flattened form of vector.multi_reduction with inner/outer // most dim as reduction. - Operation *newMultiDimRedOp = rewriter.create( - loc, cast, acc, mask, multiReductionOp.getKind()); + Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create( + rewriter, loc, cast, acc, mask, multiReductionOp.getKind()); newMultiDimRedOp = mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask); @@ -339,11 +339,11 @@ struct TwoDimMultiReductionToElementWise Value result = multiReductionOp.getAcc(); for (int64_t i = 0; i < srcShape[0]; i++) { - auto operand = rewriter.create( - loc, multiReductionOp.getSource(), i); + auto operand = vector::ExtractOp::create(rewriter, loc, + multiReductionOp.getSource(), i); Value extractMask = nullptr; if (mask) { - extractMask = rewriter.create(loc, mask, i); + extractMask = vector::ExtractOp::create(rewriter, loc, mask, i); } result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand, @@ -383,28 +383,29 @@ struct TwoDimMultiReductionToReduction } auto loc = multiReductionOp.getLoc(); - Value result = rewriter.create( - loc, multiReductionOp.getDestType(), + Value result = arith::ConstantOp::create( + rewriter, loc, multiReductionOp.getDestType(), rewriter.getZeroAttr(multiReductionOp.getDestType())); int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; for (int i = 0; i < outerDim; ++i) { - auto v = rewriter.create( - loc, multiReductionOp.getSource(), ArrayRef{i}); - auto acc = rewriter.create( - loc, multiReductionOp.getAcc(), 
ArrayRef{i}); - Operation *reductionOp = rewriter.create( - loc, multiReductionOp.getKind(), v, acc); + auto v = vector::ExtractOp::create( + rewriter, loc, multiReductionOp.getSource(), ArrayRef{i}); + auto acc = vector::ExtractOp::create( + rewriter, loc, multiReductionOp.getAcc(), ArrayRef{i}); + Operation *reductionOp = vector::ReductionOp::create( + rewriter, loc, multiReductionOp.getKind(), v, acc); // If masked, slice the mask and mask the new reduction operation. if (maskableOp.isMasked()) { - Value mask = rewriter.create( - loc, maskableOp.getMaskingOp().getMask(), ArrayRef{i}); + Value mask = vector::ExtractOp::create( + rewriter, loc, maskableOp.getMaskingOp().getMask(), + ArrayRef{i}); reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask); } - result = rewriter.create(loc, reductionOp->getResult(0), - result, i); + result = vector::InsertOp::create(rewriter, loc, + reductionOp->getResult(0), result, i); } rewriter.replaceOp(rootOp, result); @@ -459,10 +460,10 @@ struct OneDimMultiReductionToTwoDim SmallVector reductionMask{false, true}; /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) - Value cast = rewriter.create( - loc, castedType, multiReductionOp.getSource()); - Value castAcc = rewriter.create( - loc, accType, multiReductionOp.getAcc()); + Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType, + multiReductionOp.getSource()); + Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType, + multiReductionOp.getAcc()); Value castMask; if (maskableOp.isMasked()) { auto maskType = llvm::cast(mask.getType()); @@ -470,11 +471,12 @@ struct OneDimMultiReductionToTwoDim ArrayRef{1, maskType.getShape().back()}, maskType.getElementType(), ArrayRef{false, maskType.getScalableDims().back()}); - castMask = rewriter.create(loc, castMaskType, mask); + castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask); } - Operation *newOp = rewriter.create( - loc, cast, castAcc, reductionMask, 
multiReductionOp.getKind()); + Operation *newOp = vector::MultiDimReductionOp::create( + rewriter, loc, cast, castAcc, reductionMask, + multiReductionOp.getKind()); newOp = vector::maskOperation(rewriter, newOp, castMask); rewriter.replaceOpWithNewOp(rootOp, newOp->getResult(0), diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index 6f3955f522775..af4851eb5f158 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -112,8 +112,8 @@ struct ScanToArithOps : public OpRewritePattern { return failure(); VectorType resType = VectorType::get(destShape, elType); - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); + Value result = arith::ConstantOp::create(rewriter, loc, resType, + rewriter.getZeroAttr(resType)); int64_t reductionDim = scanOp.getReductionDim(); bool inclusive = scanOp.getInclusive(); int64_t destRank = destType.getRank(); @@ -134,9 +134,9 @@ struct ScanToArithOps : public OpRewritePattern { for (int i = 0; i < destShape[reductionDim]; i++) { offsets[reductionDim] = i; ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); - Value input = rewriter.create( - loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes, - scanStrides); + Value input = vector::ExtractStridedSliceOp::create( + rewriter, loc, reductionType, scanOp.getSource(), scanOffsets, + scanSizes, scanStrides); Value output; if (i == 0) { if (inclusive) { @@ -144,11 +144,11 @@ struct ScanToArithOps : public OpRewritePattern { } else { if (initialValueRank == 0) { // ShapeCastOp cannot handle 0-D vectors - output = rewriter.create( - loc, input.getType(), scanOp.getInitialValue()); + output = vector::BroadcastOp::create(rewriter, loc, input.getType(), + scanOp.getInitialValue()); } else { - output = rewriter.create( - loc, input.getType(), scanOp.getInitialValue()); + output = 
vector::ShapeCastOp::create(rewriter, loc, input.getType(), + scanOp.getInitialValue()); } } } else { @@ -156,20 +156,20 @@ struct ScanToArithOps : public OpRewritePattern { output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(), lastOutput, y); } - result = rewriter.create( - loc, output, result, offsets, strides); + result = vector::InsertStridedSliceOp::create(rewriter, loc, output, + result, offsets, strides); lastOutput = output; lastInput = input; } Value reduction; if (initialValueRank == 0) { - Value v = rewriter.create(loc, lastOutput, 0); + Value v = vector::ExtractOp::create(rewriter, loc, lastOutput, 0); reduction = - rewriter.create(loc, initialValueType, v); + vector::BroadcastOp::create(rewriter, loc, initialValueType, v); } else { - reduction = rewriter.create(loc, initialValueType, - lastOutput); + reduction = vector::ShapeCastOp::create(rewriter, loc, initialValueType, + lastOutput); } rewriter.replaceOp(scanOp, {result, reduction}); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 39c16fab21c4e..603ea41d43360 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -137,11 +137,12 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { const int64_t resultLeading = delta > 0 ? 
0 : -delta; const Value source = shapeCast.getSource(); - const Value poison = rewriter.create(loc, resultType); - const Value extracted = rewriter.create( - loc, source, SmallVector(sourceLeading, 0)); - const Value result = rewriter.create( - loc, extracted, poison, SmallVector(resultLeading, 0)); + const Value poison = ub::PoisonOp::create(rewriter, loc, resultType); + const Value extracted = vector::ExtractOp::create( + rewriter, loc, source, SmallVector(sourceLeading, 0)); + const Value result = + vector::InsertOp::create(rewriter, loc, extracted, poison, + SmallVector(resultLeading, 0)); rewriter.replaceOp(shapeCast, result); return success(); @@ -171,14 +172,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { SmallVector extractIndex(sourceDim, 0); SmallVector insertIndex(resultDim, 0); - Value result = rewriter.create(loc, resultType); + Value result = ub::PoisonOp::create(rewriter, loc, resultType); for (int i = 0; i < nSlices; ++i) { Value extracted = - rewriter.create(loc, source, extractIndex); + vector::ExtractOp::create(rewriter, loc, source, extractIndex); - result = rewriter.create(loc, extracted, result, - insertIndex); + result = vector::InsertOp::create(rewriter, loc, extracted, result, + insertIndex); inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex); inplaceAdd(1, resultShape.take_front(resultDim), insertIndex); @@ -276,9 +277,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { Value extracted = {}; Value extractedStrided = {}; Value insertedSlice = {}; - Value result = rewriter.create(loc, resultType); + Value result = ub::PoisonOp::create(rewriter, loc, resultType); const Value partResult = - rewriter.create(loc, insertStridedType); + ub::PoisonOp::create(rewriter, loc, insertStridedType); for (size_t i = 0; i < nAtomicSlices; ++i) { @@ -288,28 +289,28 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { // vector.extract if (extractStridedPhase == 0) { extracted = - rewriter.create(loc, 
source, extractIndex); + vector::ExtractOp::create(rewriter, loc, source, extractIndex); inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim), extractIndex); } // vector.extract_strided_slice extractOffsets[0] = extractStridedPhase * greatestCommonDivisor; - extractedStrided = rewriter.create( - loc, extracted, extractOffsets, atomicShape, sizes); + extractedStrided = vector::ExtractStridedSliceOp::create( + rewriter, loc, extracted, extractOffsets, atomicShape, sizes); // vector.insert_strided_slice if (insertStridedPhase == 0) { insertedSlice = partResult; } insertOffsets[0] = insertStridedPhase * greatestCommonDivisor; - insertedSlice = rewriter.create( - loc, extractedStrided, insertedSlice, insertOffsets, sizes); + insertedSlice = vector::InsertStridedSliceOp::create( + rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes); // vector.insert if (insertStridedPhase + 1 == insertPeriod) { - result = rewriter.create(loc, insertedSlice, result, - insertIndex); + result = vector::InsertOp::create(rewriter, loc, insertedSlice, result, + insertIndex); inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim), insertIndex); } @@ -394,7 +395,7 @@ class ScalableShapeCastOpRewritePattern auto extractionVectorType = VectorType::get( {minExtractionSize}, sourceVectorType.getElementType(), {true}); - Value result = rewriter.create(loc, resultVectorType); + Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType); SmallVector srcIdx(srcRank, 0); SmallVector resIdx(resRank, 0); @@ -406,16 +407,18 @@ class ScalableShapeCastOpRewritePattern // 1. Extract a scalable subvector from the source vector. 
if (!currentSourceScalableVector) { if (srcRank != 1) { - currentSourceScalableVector = rewriter.create( - loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back()); + currentSourceScalableVector = + vector::ExtractOp::create(rewriter, loc, op.getSource(), + llvm::ArrayRef(srcIdx).drop_back()); } else { currentSourceScalableVector = op.getSource(); } } Value sourceSubVector = currentSourceScalableVector; if (minExtractionSize < minSourceTrailingSize) { - sourceSubVector = rewriter.create( - loc, extractionVectorType, sourceSubVector, srcIdx.back()); + sourceSubVector = vector::ScalableExtractOp::create( + rewriter, loc, extractionVectorType, sourceSubVector, + srcIdx.back()); } // 2. Insert the scalable subvector into the result vector. @@ -423,15 +426,16 @@ class ScalableShapeCastOpRewritePattern if (minExtractionSize == minResultTrailingSize) { currentResultScalableVector = sourceSubVector; } else if (resRank != 1) { - currentResultScalableVector = rewriter.create( - loc, result, llvm::ArrayRef(resIdx).drop_back()); + currentResultScalableVector = vector::ExtractOp::create( + rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back()); } else { currentResultScalableVector = result; } } if (minExtractionSize < minResultTrailingSize) { - currentResultScalableVector = rewriter.create( - loc, sourceSubVector, currentResultScalableVector, resIdx.back()); + currentResultScalableVector = vector::ScalableInsertOp::create( + rewriter, loc, sourceSubVector, currentResultScalableVector, + resIdx.back()); } // 3. Update the source and result scalable vectors if needed. @@ -439,9 +443,9 @@ class ScalableShapeCastOpRewritePattern currentResultScalableVector != result) { // Finished row of result. Insert complete scalable vector into result // (n-D) vector. 
- result = rewriter.create( - loc, currentResultScalableVector, result, - llvm::ArrayRef(resIdx).drop_back()); + result = vector::InsertOp::create(rewriter, loc, + currentResultScalableVector, result, + llvm::ArrayRef(resIdx).drop_back()); currentResultScalableVector = {}; } if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 475528289f01f..6407a868abd85 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -629,8 +629,8 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { nextLevelVectorSize); } - Value shuffleVal = rewriter.create( - loc, lhsVector, rhsVector, shuffleMask); + Value shuffleVal = vector::ShuffleOp::create(rewriter, loc, lhsVector, + rhsVector, shuffleMask); levelOutputs.push_back(shuffleVal); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index fb040bc51a993..e9109322ed3d8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -44,7 +44,7 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, originalVecType.getScalableDims().end()); VectorType newVecType = VectorType::get( newShape, originalVecType.getElementType(), newScalableDims); - return builder.create(loc, newVecType, vec); + return vector::BroadcastOp::create(builder, loc, newVecType, vec); } /// Extend the rank of a vector Value by `addedRanks` by adding inner unit @@ -59,7 +59,7 @@ static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, permutation.push_back(i); for (int64_t i = 0; i < addedRank; ++i) permutation.push_back(i); - return 
builder.create(loc, broadcasted, permutation); + return vector::TransposeOp::create(builder, loc, broadcasted, permutation); } //===----------------------------------------------------------------------===// @@ -135,8 +135,8 @@ struct TransferReadPermutationLowering // Generate new transfer_read operation. VectorType newReadType = VectorType::get( newVectorShape, op.getVectorType().getElementType(), newScalableDims); - Value newRead = rewriter.create( - op.getLoc(), newReadType, op.getBase(), op.getIndices(), + Value newRead = vector::TransferReadOp::create( + rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), newInBoundsAttr); @@ -206,12 +206,12 @@ struct TransferWritePermutationLowering inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation); // Generate new transfer_write operation. - Value newVec = rewriter.create( - op.getLoc(), op.getVector(), indices); + Value newVec = vector::TransposeOp::create(rewriter, op.getLoc(), + op.getVector(), indices); auto newMap = AffineMap::getMinorIdentityMap( map.getNumDims(), map.getNumResults(), rewriter.getContext()); - auto newWrite = rewriter.create( - op.getLoc(), newVec, op.getBase(), op.getIndices(), + auto newWrite = vector::TransferWriteOp::create( + rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr); if (newWrite.hasPureTensorSemantics()) return newWrite.getResult(); @@ -296,8 +296,8 @@ struct TransferWriteNonPermutationLowering newInBoundsValues.push_back(op.isDimInBounds(i)); } ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues); - auto newWrite = rewriter.create( - op.getLoc(), newVec, op.getBase(), op.getIndices(), + auto newWrite = vector::TransferWriteOp::create( + rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), newMask, newInBoundsAttr); if (newWrite.hasPureTensorSemantics()) return 
newWrite.getResult(); @@ -367,8 +367,8 @@ struct TransferOpReduceRank ? rewriter.getArrayAttr( op.getInBoundsAttr().getValue().take_back(reducedShapeRank)) : ArrayAttr(); - Value newRead = rewriter.create( - op.getLoc(), newReadType, op.getBase(), op.getIndices(), + Value newRead = vector::TransferReadOp::create( + rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), newInBoundsAttr); return rewriter @@ -468,21 +468,21 @@ struct TransferReadToVectorLoadLowering read, "vector type is not rank 1, can't create masked load, needs " "VectorToSCF"); - Value fill = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.getPadding()); - res = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.getBase(), + Value fill = vector::SplatOp::create( + rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding()); + res = vector::MaskedLoadOp::create( + rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(), read.getIndices(), read.getMask(), fill); } else { - res = rewriter.create(read.getLoc(), - unbroadcastedVectorType, - read.getBase(), read.getIndices()); + res = vector::LoadOp::create(rewriter, read.getLoc(), + unbroadcastedVectorType, read.getBase(), + read.getIndices()); } // Insert a broadcasting op if required. 
if (!broadcastedDims.empty()) - res = rewriter.create( - read.getLoc(), read.getVectorType(), res->getResult(0)); + res = vector::BroadcastOp::create( + rewriter, read.getLoc(), read.getVectorType(), res->getResult(0)); return res->getResult(0); } @@ -566,12 +566,12 @@ struct TransferWriteToVectorStoreLowering << write; }); - rewriter.create( - write.getLoc(), write.getBase(), write.getIndices(), write.getMask(), - write.getVector()); + vector::MaskedStoreOp::create(rewriter, write.getLoc(), write.getBase(), + write.getIndices(), write.getMask(), + write.getVector()); } else { - rewriter.create(write.getLoc(), write.getVector(), - write.getBase(), write.getIndices()); + vector::StoreOp::create(rewriter, write.getLoc(), write.getVector(), + write.getBase(), write.getIndices()); } // There's no return value for StoreOps. Use Value() to signal success to // matchAndRewrite. diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index bb9a6832146e8..e14f96e7eec59 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -79,8 +79,8 @@ getUnpackShufflePermFor128Lane(ArrayRef vals, int numBits) { static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits) { int numElem = numBits / 32; - return b.create( - v1, v2, + return vector::ShuffleOp::create( + b, v1, v2, getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits)); } @@ -93,8 +93,8 @@ static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits) { int numElem = numBits / 32; - return b.create( - v1, v2, + return vector::ShuffleOp::create( + b, v1, v2, getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3}, numBits)); } @@ -108,8 +108,8 @@ static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value 
v2, static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits) { int numElem = numBits / 32; - auto shuffle = b.create( - v1, v2, + auto shuffle = vector::ShuffleOp::create( + b, v1, v2, getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits)); return shuffle; } @@ -123,8 +123,8 @@ static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits) { int numElem = numBits / 32; - return b.create( - v1, v2, + return vector::ShuffleOp::create( + b, v1, v2, getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3}, numBits)); } @@ -180,7 +180,7 @@ static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, appendToMask(0, b23); appendToMask(16, b45); appendToMask(16, b67); - return b.create(v1, v2, shuffleMask); + return vector::ShuffleOp::create(b, v1, v2, shuffleMask); } /// Lowers the value to a vector.shuffle op. The `source` is expected to be a @@ -191,7 +191,7 @@ static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) { for (int64_t j = 0; j < n; ++j) for (int64_t i = 0; i < m; ++i) mask.push_back(i * n + j); - return b.create(source.getLoc(), source, source, mask); + return vector::ShuffleOp::create(b, source.getLoc(), source, source, mask); } /// Lowers the value to a sequence of vector.shuffle ops. The `source` is @@ -283,9 +283,9 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, auto reshInputType = VectorType::get( {m, n}, cast(source.getType()).getElementType()); - Value res = b.create(reshInputType); + Value res = ub::PoisonOp::create(b, reshInputType); for (int64_t i = 0; i < m; ++i) - res = b.create(vs[i], res, i); + res = vector::InsertOp::create(b, vs[i], res, i); return res; } @@ -343,7 +343,7 @@ class TransposeOpLowering : public OpRewritePattern { // of the leftmost transposed dimensions. 
We traverse every transpose // element using a linearized index that we delinearize to generate the // appropriate indices for the extract/insert operations. - Value result = rewriter.create(loc, resType); + Value result = ub::PoisonOp::create(rewriter, loc, resType); int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); for (int64_t linearIdx = 0; linearIdx < numTransposedElements; @@ -466,14 +466,14 @@ class TransposeOp2DToShuffleLowering Location loc = op.getLoc(); auto flattenedType = VectorType::get({n * m}, srcType.getElementType()); auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); - auto reshInput = rewriter.create(loc, flattenedType, - op.getVector()); + auto reshInput = vector::ShapeCastOp::create(rewriter, loc, flattenedType, + op.getVector()); Value res; if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 && m == 16 && n == 16) { reshInput = - rewriter.create(loc, reshInputType, reshInput); + vector::ShapeCastOp::create(rewriter, loc, reshInputType, reshInput); res = transposeToShuffle16x16(rewriter, reshInput, m, n); } else { // Fallback to shuffle on 1D approach. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 436029c31e7f8..58e94ea00189f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -114,7 +114,7 @@ struct DistributedLoadStoreHelper { "preregistered sequential value."); // Scalar case can directly use memref.store. if (!isa(val.getType())) - return b.create(loc, val, buffer, zero); + return memref::StoreOp::create(b, loc, val, buffer, zero); // Vector case must use vector::TransferWriteOp which will later lower to // vector.store of memref.store depending on further lowerings. 
@@ -127,8 +127,8 @@ struct DistributedLoadStoreHelper { } } SmallVector inBounds(indices.size(), true); - return b.create( - loc, val, buffer, indices, + return vector::TransferWriteOp::create( + b, loc, val, buffer, indices, ArrayRef(inBounds.begin(), inBounds.end())); } @@ -156,7 +156,7 @@ struct DistributedLoadStoreHelper { // Scalar case can directly use memref.store. if (!isa(type)) - return b.create(loc, buffer, zero); + return memref::LoadOp::create(b, loc, buffer, zero); // Other cases must be vector atm. // Vector case must use vector::TransferReadOp which will later lower to @@ -172,8 +172,9 @@ struct DistributedLoadStoreHelper { } } SmallVector inBounds(indices.size(), true); - return b.create( - loc, cast(type), buffer, indices, /*padding=*/std::nullopt, + return vector::TransferReadOp::create( + b, loc, cast(type), buffer, indices, + /*padding=*/std::nullopt, ArrayRef(inBounds.begin(), inBounds.end())); } @@ -243,11 +244,11 @@ struct WarpOpToScfIfPattern : public WarpDistributionPattern { rewriter.setInsertionPoint(warpOp); // Step 1: Create scf.if op. - Value c0 = rewriter.create(loc, 0); - Value isLane0 = rewriter.create( - loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0); - auto ifOp = rewriter.create(loc, isLane0, - /*withElseRegion=*/false); + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value isLane0 = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0); + auto ifOp = scf::IfOp::create(rewriter, loc, isLane0, + /*withElseRegion=*/false); rewriter.eraseOp(ifOp.thenBlock()->getTerminator()); // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and @@ -325,7 +326,7 @@ struct WarpOpToScfIfPattern : public WarpDistributionPattern { // Step 7. Delete terminator and add empty scf.yield. 
rewriter.eraseOp(yieldOp); rewriter.setInsertionPointToEnd(ifOp.thenBlock()); - rewriter.create(yieldLoc); + scf::YieldOp::create(rewriter, yieldLoc); // Compute replacements for WarpOp results. rewriter.replaceOp(warpOp, replacements); @@ -512,8 +513,9 @@ struct WarpOpTransferWrite : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); // Create a second warp op that contains only writeOp. - auto secondWarpOp = rewriter.create( - loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize()); + auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(), + newWarpOp.getLaneid(), + newWarpOp.getWarpSize()); Block &body = secondWarpOp.getBodyRegion().front(); rewriter.setInsertionPointToStart(&body); auto newWriteOp = @@ -521,7 +523,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern { newWriteOp.getValueToStoreMutable().assign( newWarpOp.getResult(newRetIndices[0])); rewriter.eraseOp(writeOp); - rewriter.create(newWarpOp.getLoc()); + gpu::YieldOp::create(rewriter, newWarpOp.getLoc()); return success(); } @@ -698,7 +700,7 @@ struct WarpOpConstant : public WarpDistributionPattern { cast(warpOp.getResult(operandIndex).getType()), scalarAttr); Location loc = warpOp.getLoc(); rewriter.setInsertionPointAfter(warpOp); - Value distConstant = rewriter.create(loc, newAttr); + Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr); rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant); rewriter.finalizeOpModification(warpOp); return success(); @@ -823,9 +825,9 @@ struct WarpOpTransferRead : public WarpDistributionPattern { Value newMask = hasMask ? 
newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1]) : Value(); - auto newRead = rewriter.create( - read.getLoc(), distributedVal.getType(), read.getBase(), newIndices, - read.getPermutationMapAttr(), newPadding, newMask, + auto newRead = vector::TransferReadOp::create( + rewriter, read.getLoc(), distributedVal.getType(), read.getBase(), + newIndices, read.getPermutationMapAttr(), newPadding, newMask, read.getInBoundsAttr()); rewriter.replaceAllUsesWith(distributedVal, newRead); @@ -965,8 +967,8 @@ struct WarpOpBroadcast : public WarpDistributionPattern { WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - Value broadcasted = rewriter.create( - loc, destVecType, newWarpOp->getResult(newRetIndices[0])); + Value broadcasted = vector::BroadcastOp::create( + rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0])); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), broadcasted); return success(); @@ -1008,8 +1010,8 @@ struct WarpOpShapeCast : public WarpDistributionPattern { rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - Value newCast = rewriter.create( - oldCastOp.getLoc(), castResultType, + Value newCast = vector::ShapeCastOp::create( + rewriter, oldCastOp.getLoc(), castResultType, newWarpOp->getResult(newRetIndices[0])); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); return success(); @@ -1091,7 +1093,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern { } auto newMask = - rewriter.create(loc, distType, newOperands); + vector::CreateMaskOp::create(rewriter, loc, distType, newOperands); rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask); rewriter.finalizeOpModification(warpOp); return success(); @@ -1182,9 +1184,10 @@ struct WarpOpInsertStridedSlice : public 
WarpDistributionPattern { Value distributedDest = newWarpOp->getResult(newRetIndices[1]); // Create a new insert strided slice op that inserts distributed source into // distributed dest. - Value newInsert = rewriter.create( - insertOp.getLoc(), distributedDest.getType(), distributedSource, - distributedDest, insertOp.getOffsets(), insertOp.getStrides()); + Value newInsert = vector::InsertStridedSliceOp::create( + rewriter, insertOp.getLoc(), distributedDest.getType(), + distributedSource, distributedDest, insertOp.getOffsets(), + insertOp.getStrides()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert); return success(); } @@ -1277,8 +1280,8 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern { // Create a new extract strided slice op that extracts from the // distributed vector. Value distributedVec = newWarpOp->getResult(newRetIndices[0]); - Value newExtract = rewriter.create( - extractOp.getLoc(), distributedType, distributedVec, + Value newExtract = vector::ExtractStridedSliceOp::create( + rewriter, extractOp.getLoc(), distributedType, distributedVec, extractOp.getOffsets(), ArrayAttr::get(rewriter.getContext(), distributedSizes), extractOp.getStrides()); @@ -1323,8 +1326,8 @@ struct WarpOpExtract : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. - Value newExtract = rewriter.create( - loc, distributedVec, extractOp.getMixedPosition()); + Value newExtract = vector::ExtractOp::create( + rewriter, loc, distributedVec, extractOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); @@ -1352,8 +1355,8 @@ struct WarpOpExtract : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. 
- Value newExtract = rewriter.create( - loc, distributedVec, extractOp.getMixedPosition()); + Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec, + extractOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); @@ -1422,7 +1425,7 @@ struct WarpOpExtractScalar : public WarpDistributionPattern { Value newExtract; SmallVector indices(extractSrcType.getRank(), 0); newExtract = - rewriter.create(loc, distributedVec, indices); + vector::ExtractOp::create(rewriter, loc, distributedVec, indices); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); @@ -1442,11 +1445,11 @@ struct WarpOpExtractScalar : public WarpDistributionPattern { // Extract at position: pos % elementsPerLane Value newPos = elementsPerLane == 1 - ? rewriter.create(loc, 0).getResult() + ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult() : affine::makeComposedAffineApply(rewriter, loc, sym0 % elementsPerLane, pos); Value extracted = - rewriter.create(loc, distributedVec, newPos); + vector::ExtractOp::create(rewriter, loc, distributedVec, newPos); // Shuffle the extracted value to all lanes. Value shuffled = warpShuffleFromIdxFn( @@ -1514,8 +1517,8 @@ struct WarpOpInsertScalar : public WarpDistributionPattern { if (pos) { indices.push_back(pos); } - newInsert = rewriter.create(loc, newSource, - distributedVec, indices); + newInsert = vector::InsertOp::create(rewriter, loc, newSource, + distributedVec, indices); // Broadcast: Simply move the vector.insert op out. 
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert); @@ -1531,21 +1534,22 @@ struct WarpOpInsertScalar : public WarpDistributionPattern { // Insert position: pos % elementsPerLane OpFoldResult newPos = affine::makeComposedFoldedAffineApply( rewriter, loc, sym0 % elementsPerLane, pos); - Value isInsertingLane = rewriter.create( - loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); + Value isInsertingLane = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + newWarpOp.getLaneid(), insertingLane); Value newResult = rewriter .create( loc, isInsertingLane, /*thenBuilder=*/ [&](OpBuilder &builder, Location loc) { - Value newInsert = builder.create( - loc, newSource, distributedVec, newPos); - builder.create(loc, newInsert); + Value newInsert = vector::InsertOp::create( + builder, loc, newSource, distributedVec, newPos); + scf::YieldOp::create(builder, loc, newInsert); }, /*elseBuilder=*/ [&](OpBuilder &builder, Location loc) { - builder.create(loc, distributedVec); + scf::YieldOp::create(builder, loc, distributedVec); }) .getResult(0); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); @@ -1582,8 +1586,9 @@ struct WarpOpInsert : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); Value distributedDest = newWarpOp->getResult(newRetIndices[1]); - Value newResult = rewriter.create( - loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); + Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc, + distributedDest, + insertOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); return success(); @@ -1632,27 +1637,29 @@ struct WarpOpInsert : public WarpDistributionPattern { Value newResult; if (distrSrcDim >= 0) { // Every lane inserts a small piece. 
- newResult = rewriter.create( - loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); + newResult = vector::InsertOp::create(rewriter, loc, distributedSrc, + distributedDest, + insertOp.getMixedPosition()); } else { // One lane inserts the entire source vector. int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); SmallVector pos = insertOp.getMixedPosition(); SmallVector newPos = getAsIntegers(pos); // tid of inserting lane: pos / elementsPerLane - Value insertingLane = rewriter.create( - loc, newPos[distrDestDim] / elementsPerLane); - Value isInsertingLane = rewriter.create( - loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); + Value insertingLane = arith::ConstantIndexOp::create( + rewriter, loc, newPos[distrDestDim] / elementsPerLane); + Value isInsertingLane = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + newWarpOp.getLaneid(), insertingLane); // Insert position: pos % elementsPerLane newPos[distrDestDim] %= elementsPerLane; auto insertingBuilder = [&](OpBuilder &builder, Location loc) { - Value newInsert = builder.create( - loc, distributedSrc, distributedDest, newPos); - builder.create(loc, newInsert); + Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc, + distributedDest, newPos); + scf::YieldOp::create(builder, loc, newInsert); }; auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) { - builder.create(loc, distributedDest); + scf::YieldOp::create(builder, loc, distributedDest); }; newResult = rewriter .create(loc, isInsertingLane, @@ -1820,8 +1827,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Create a new `ForOp` outside the new `WarpOp` region. 
OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); - auto newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + auto newForOp = scf::ForOp::create( + rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newForOpOperands); // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the // newly created `ForOp`. This `WarpOp` will contain all ops that were @@ -1845,9 +1852,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern { escapingValueInputTypes[i - escapingValuesStartIdx]); } // Create the inner `WarpOp` with the new input values and types. - auto innerWarp = rewriter.create( - newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), - newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType); + auto innerWarp = WarpExecuteOnLane0Op::create( + rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(), + newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput, + innerWarpInputType); // Inline the `ForOp` body into the inner `WarpOp` body. SmallVector argMapping; @@ -1866,12 +1874,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields // original `ForOp` results. rewriter.setInsertionPointToEnd(innerWarp.getBody()); - rewriter.create(innerWarp.getLoc(), yieldOperands); + gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands); rewriter.setInsertionPointAfter(innerWarp); // Insert a scf.yield op at the end of the new `ForOp` body that yields // the inner `WarpOp` results. if (!innerWarp.getResults().empty()) - rewriter.create(forOp.getLoc(), innerWarp.getResults()); + scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults()); // Update the users of original `WarpOp` results that were coming from the // original `ForOp` to the corresponding new `ForOp` result. 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 067d4e3491391..73388a5da3e4f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -77,8 +77,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim Location loc = extractOp.getLoc(); - Value newSrcVector = rewriter.create( - loc, extractOp.getVector(), splatZero(dropCount)); + Value newSrcVector = vector::ExtractOp::create( + rewriter, loc, extractOp.getVector(), splatZero(dropCount)); // The offsets/sizes/strides attribute can have a less number of elements // than the input vector's rank: it is meant for the leading dimensions. @@ -89,8 +89,9 @@ struct CastAwayExtractStridedSliceLeadingOneDim auto newStrides = rewriter.getArrayAttr( extractOp.getStrides().getValue().drop_front(dropCount)); - auto newExtractOp = rewriter.create( - loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); + auto newExtractOp = vector::ExtractStridedSliceOp::create( + rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes, + newStrides); rewriter.replaceOpWithNewOp(extractOp, oldDstType, newExtractOp); @@ -120,18 +121,19 @@ struct CastAwayInsertStridedSliceLeadingOneDim // Trim leading one dimensions from both operands. 
Location loc = insertOp.getLoc(); - Value newSrcVector = rewriter.create( - loc, insertOp.getValueToStore(), splatZero(srcDropCount)); - Value newDstVector = rewriter.create( - loc, insertOp.getDest(), splatZero(dstDropCount)); + Value newSrcVector = vector::ExtractOp::create( + rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount)); + Value newDstVector = vector::ExtractOp::create( + rewriter, loc, insertOp.getDest(), splatZero(dstDropCount)); auto newOffsets = rewriter.getArrayAttr( insertOp.getOffsets().getValue().take_back(newDstType.getRank())); auto newStrides = rewriter.getArrayAttr( insertOp.getStrides().getValue().take_back(newSrcType.getRank())); - auto newInsertOp = rewriter.create( - loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); + auto newInsertOp = vector::InsertStridedSliceOp::create( + rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets, + newStrides); rewriter.replaceOpWithNewOp(insertOp, oldDstType, newInsertOp); @@ -169,11 +171,11 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { Value newSrcVector = insertOp.getValueToStore(); if (oldSrcRank != 0) { - newSrcVector = rewriter.create( - loc, insertOp.getValueToStore(), splatZero(srcDropCount)); + newSrcVector = vector::ExtractOp::create( + rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount)); } - Value newDstVector = rewriter.create( - loc, insertOp.getDest(), splatZero(dstDropCount)); + Value newDstVector = vector::ExtractOp::create( + rewriter, loc, insertOp.getDest(), splatZero(dstDropCount)); // New position rank needs to be computed in two steps: (1) if destination // type has leading unit dims, we also trim the position array accordingly, @@ -187,8 +189,8 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { newPosition.resize(newDstType.getRank() - newSrcRank, rewriter.getI64IntegerAttr(0)); - auto newInsertOp = rewriter.create( - loc, newSrcVector, newDstVector, newPosition); + auto newInsertOp = 
vector::InsertOp::create(rewriter, loc, newSrcVector, + newDstVector, newPosition); rewriter.replaceOpWithNewOp(insertOp, oldDstType, newInsertOp); @@ -209,9 +211,9 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, if (vector::isBroadcastableTo(newMaskType, oldMaskType) == BroadcastableToResult::Success) { int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank(); - return b.create(loc, mask, splatZero(dropDim)); + return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim)); } - return b.create(loc, newMaskType, mask); + return vector::ShapeCastOp::create(b, loc, newMaskType, mask); } // Turns vector.transfer_read on vector with leading 1 dimensions into @@ -259,8 +261,8 @@ struct CastAwayTransferReadLeadingOneDim newType, newMap, maskType); } - auto newRead = rewriter.create( - read.getLoc(), newType, read.getBase(), read.getIndices(), + auto newRead = vector::TransferReadOp::create( + rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(), AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr); rewriter.replaceOpWithNewOp(read, oldType, newRead); @@ -306,8 +308,8 @@ struct CastAwayTransferWriteLeadingOneDim inBoundsAttr = rewriter.getArrayAttr( write.getInBoundsAttr().getValue().take_back(newType.getRank())); - auto newVector = rewriter.create( - write.getLoc(), write.getVector(), splatZero(dropDim)); + auto newVector = vector::ExtractOp::create( + rewriter, write.getLoc(), write.getVector(), splatZero(dropDim)); if (write.getMask()) { VectorType maskType = write.getMaskType(); @@ -443,22 +445,23 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, contractOp.getContext())); // Extract if its a valid extraction, otherwise use the operand // without extraction. - newOperands.push_back( - validExtract ? rewriter.create( - loc, operands[it.index()], splatZero(dropDim)) - : operands[it.index()]); + newOperands.push_back(validExtract + ? 
vector::ExtractOp::create(rewriter, loc, + operands[it.index()], + splatZero(dropDim)) + : operands[it.index()]); } // Depending on whether this vector.contract is masked, the replacing Op // should either be a new vector.contract Op or vector.mask Op. - Operation *newOp = rewriter.create( - loc, newOperands[0], newOperands[1], newOperands[2], + Operation *newOp = vector::ContractionOp::create( + rewriter, loc, newOperands[0], newOperands[1], newOperands[2], rewriter.getAffineMapArrayAttr(newIndexingMaps), rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); if (maskingOp) { - auto newMask = rewriter.create(loc, maskingOp.getMask(), - splatZero(dropDim)); + auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(), + splatZero(dropDim)); newOp = mlir::vector::maskOperation(rewriter, newOp, newMask); } @@ -519,8 +522,8 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern { SmallVector newOperands; for (Value operand : op->getOperands()) { if (auto opVecType = dyn_cast(operand.getType())) { - newOperands.push_back(rewriter.create( - op->getLoc(), operand, splatZero(dropDim))); + newOperands.push_back(vector::ExtractOp::create( + rewriter, op->getLoc(), operand, splatZero(dropDim))); } else { newOperands.push_back(operand); } @@ -559,8 +562,8 @@ struct CastAwayConstantMaskLeadingOneDim SmallVector newDimSizes = {flatLeadingSize}; newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); - auto newMask = rewriter.create( - mask.getLoc(), newType, newDimSizes); + auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(), + newType, newDimSizes); rewriter.replaceOpWithNewOp(mask, oldType, newMask); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp index 8cc7008d80b3e..cb3e8dc67a1ae 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp @@ -65,26 +65,27 @@ struct VectorMaskedLoadOpConverter final Value base = maskedLoadOp.getBase(); Value iValue = maskedLoadOp.getPassThru(); auto indices = llvm::to_vector_of(maskedLoadOp.getIndices()); - Value one = rewriter.create( - loc, indexType, IntegerAttr::get(indexType, 1)); + Value one = arith::ConstantOp::create(rewriter, loc, indexType, + IntegerAttr::get(indexType, 1)); for (int64_t i = 0; i < maskLength; ++i) { - auto maskBit = rewriter.create(loc, mask, i); + auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i); - auto ifOp = rewriter.create( - loc, maskBit, + auto ifOp = scf::IfOp::create( + rewriter, loc, maskBit, [&](OpBuilder &builder, Location loc) { auto loadedValue = - builder.create(loc, base, indices); + memref::LoadOp::create(builder, loc, base, indices); auto combinedValue = - builder.create(loc, loadedValue, iValue, i); - builder.create(loc, combinedValue.getResult()); + vector::InsertOp::create(builder, loc, loadedValue, iValue, i); + scf::YieldOp::create(builder, loc, combinedValue.getResult()); }, [&](OpBuilder &builder, Location loc) { - builder.create(loc, iValue); + scf::YieldOp::create(builder, loc, iValue); }); iValue = ifOp.getResult(0); - indices.back() = rewriter.create(loc, indices.back(), one); + indices.back() = + arith::AddIOp::create(rewriter, loc, indices.back(), one); } rewriter.replaceOp(maskedLoadOp, iValue); @@ -132,18 +133,19 @@ struct VectorMaskedStoreOpConverter final Value base = maskedStoreOp.getBase(); Value value = maskedStoreOp.getValueToStore(); auto indices = llvm::to_vector_of(maskedStoreOp.getIndices()); - Value one = rewriter.create( - loc, indexType, IntegerAttr::get(indexType, 1)); + Value one = arith::ConstantOp::create(rewriter, loc, indexType, + IntegerAttr::get(indexType, 1)); for (int64_t i = 0; i < maskLength; ++i) { - auto maskBit = rewriter.create(loc, mask, i); + auto maskBit = vector::ExtractOp::create(rewriter, 
loc, mask, i); - auto ifOp = rewriter.create(loc, maskBit, /*else=*/false); + auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - auto extractedValue = rewriter.create(loc, value, i); - rewriter.create(loc, extractedValue, base, indices); + auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i); + memref::StoreOp::create(rewriter, loc, extractedValue, base, indices); rewriter.setInsertionPointAfter(ifOp); - indices.back() = rewriter.create(loc, indices.back(), one); + indices.back() = + arith::AddIOp::create(rewriter, loc, indices.back(), one); } rewriter.eraseOp(maskedStoreOp); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 0fe08417f818f..e6bb96f453fbc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -132,8 +132,8 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, SmallVector newMaskOperands(maskOperands.drop_back()); newMaskOperands.push_back( getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex)); - return rewriter.create(loc, newMaskType, - newMaskOperands); + return vector::CreateMaskOp::create(rewriter, loc, newMaskType, + newMaskOperands); }) .Case( [&](auto constantMaskOp) -> std::optional { @@ -143,8 +143,8 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, int64_t &maskIndex = maskDimSizes.back(); maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex, numSrcElemsPerDest); - return rewriter.create(loc, newMaskType, - maskDimSizes); + return vector::ConstantMaskOp::create( + rewriter, loc, newMaskType, maskDimSizes); }) .Case([&](auto constantOp) -> std::optional { @@ -182,16 +182,18 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, } compressedMaskValues.push_back(combinedValue); } - return rewriter.create( - loc, 
DenseElementsAttr::get(newMaskType, compressedMaskValues)); + return arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(newMaskType, compressedMaskValues)); }); if (!newMask) return failure(); while (!extractOps.empty()) { - newMask = rewriter.create( - loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition()); + newMask = + vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0], + extractOps.back().getMixedPosition()); extractOps.pop_back(); } @@ -258,8 +260,8 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, auto offsets = rewriter.getI64ArrayAttr({offset}); auto strides = rewriter.getI64ArrayAttr({1}); - return rewriter.create(loc, destVecTy, src, - dest, offsets, strides); + return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src, + dest, offsets, strides); } /// Extracts 1-D subvector from a 1-D vector. @@ -301,11 +303,12 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, for (int i = 0; i < numElemsToExtract; ++i) { Value extractLoc = (i == 0) ? dyn_cast(offset) - : rewriter.create( - loc, rewriter.getIndexType(), dyn_cast(offset), - rewriter.create(loc, i)); - auto extractOp = rewriter.create(loc, src, extractLoc); - dest = rewriter.create(loc, extractOp, dest, i); + : arith::AddIOp::create( + rewriter, loc, rewriter.getIndexType(), + dyn_cast(offset), + arith::ConstantIndexOp::create(rewriter, loc, i)); + auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc); + dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i); } return dest; } @@ -344,13 +347,13 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset); for (int64_t i = 0; i < numElemsToInsert; ++i) { - auto insertLoc = i == 0 - ? 
destOffsetVal - : rewriter.create( - loc, rewriter.getIndexType(), destOffsetVal, - rewriter.create(loc, i)); - auto extractOp = rewriter.create(loc, src, i); - dest = rewriter.create(loc, extractOp, dest, insertLoc); + auto insertLoc = + i == 0 ? destOffsetVal + : arith::AddIOp::create( + rewriter, loc, rewriter.getIndexType(), destOffsetVal, + arith::ConstantIndexOp::create(rewriter, loc, i)); + auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i); + dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc); } return dest; } @@ -369,11 +372,11 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc, Type containerElemTy) { auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() / emulatedElemTy.getIntOrFloatBitWidth(); - auto newLoad = rewriter.create( - loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base, - getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); - return rewriter.create( - loc, + auto newLoad = vector::LoadOp::create( + rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy), + base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); + return vector::BitCastOp::create( + rewriter, loc, VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem, emulatedElemTy), newLoad); @@ -390,16 +393,17 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc, upcastType.getNumElements() * upcastType.getElementTypeBitWidth() && "expected input and output number of bits to match"); if (trueValue.getType() != downcastType) { - trueValue = builder.create(loc, downcastType, trueValue); + trueValue = + vector::BitCastOp::create(builder, loc, downcastType, trueValue); } if (falseValue.getType() != downcastType) { falseValue = - builder.create(loc, downcastType, falseValue); + vector::BitCastOp::create(builder, loc, downcastType, falseValue); } Value selectedType = - builder.create(loc, mask, trueValue, falseValue); 
+ arith::SelectOp::create(builder, loc, mask, trueValue, falseValue); // Upcast the selected value to the new type. - return builder.create(loc, upcastType, selectedType); + return vector::BitCastOp::create(builder, loc, upcastType, selectedType); } /// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a @@ -422,8 +426,8 @@ static void atomicRMW(OpBuilder &builder, Location loc, // Create an atomic load-modify-write region using // `memref.generic_atomic_rmw`. - auto atomicOp = builder.create( - loc, linearizedMemref, ValueRange{storeIdx}); + auto atomicOp = memref::GenericAtomicRMWOp::create( + builder, loc, linearizedMemref, ValueRange{storeIdx}); Value origValue = atomicOp.getCurrentValue(); OpBuilder::InsertionGuard guard(builder); @@ -432,16 +436,16 @@ static void atomicRMW(OpBuilder &builder, Location loc, // Load the original value from memory, and cast it to the original element // type. auto oneElemVecType = VectorType::get({1}, origValue.getType()); - Value origVecValue = builder.create( - loc, oneElemVecType, ValueRange{origValue}); + Value origVecValue = vector::FromElementsOp::create( + builder, loc, oneElemVecType, ValueRange{origValue}); // Construct the final masked value and yield it. 
Value maskedValue = downcastSelectAndUpcast(builder, loc, valueToStore.getType(), oneElemVecType, mask, valueToStore, origVecValue); auto scalarMaskedValue = - builder.create(loc, maskedValue, 0); - builder.create(loc, scalarMaskedValue); + vector::ExtractOp::create(builder, loc, maskedValue, 0); + memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue); } /// Generate a non-atomic read-modify-write sequence for storing to the emulated @@ -453,16 +457,17 @@ static void nonAtomicRMW(OpBuilder &builder, Location loc, auto oneElemVecType = VectorType::get({1}, linearizedMemref.getType().getElementType()); - Value origVecValue = builder.create( - loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex}); - origVecValue = builder.create(loc, valueToStore.getType(), - origVecValue); + Value origVecValue = + vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref, + ValueRange{linearizedIndex}); + origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(), + origVecValue); Value maskedValue = downcastSelectAndUpcast(builder, loc, valueToStore.getType(), oneElemVecType, mask, valueToStore, origVecValue); - builder.create(loc, maskedValue, linearizedMemref, - linearizedIndex); + vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref, + linearizedIndex); } /// Extract `sliceNumElements` from source `vector` at `extractOffset`, @@ -489,8 +494,9 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter, assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 && "vector element must be a valid sub-byte type"); auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth(); - auto emptyByteVector = rewriter.create( - loc, VectorType::get({emulatedPerContainerElem}, vectorElementType), + auto emptyByteVector = arith::ConstantOp::create( + rewriter, loc, + VectorType::get({emulatedPerContainerElem}, vectorElementType), rewriter.getZeroAttr( VectorType::get({emulatedPerContainerElem}, 
vectorElementType))); auto extracted = staticallyExtractSubvector(rewriter, loc, vector, @@ -602,7 +608,7 @@ struct ConvertVectorStore final : OpConversionPattern { ShapedType::isDynamic(trailingDim) || trailingDim == origElements; auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); // FIXME: ATM, we do not test cases where offsets, sizes, or strides are // non-zero. As such, this is not needed. @@ -664,8 +670,8 @@ struct ConvertVectorStore final : OpConversionPattern { if (!emulationRequiresPartialStores) { // Basic case: storing full bytes. auto numElements = origElements / emulatedPerContainerElem; - auto bitCast = rewriter.create( - loc, VectorType::get(numElements, containerElemTy), + auto bitCast = vector::BitCastOp::create( + rewriter, loc, VectorType::get(numElements, containerElemTy), op.getValueToStore()); rewriter.replaceOpWithNewOp( op, bitCast.getResult(), memrefBase, @@ -732,8 +738,9 @@ struct ConvertVectorStore final : OpConversionPattern { std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem, *foldedNumFrontPadElems, true); } - auto frontMask = rewriter.create( - loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues)); + auto frontMask = arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues)); currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems); auto value = @@ -751,9 +758,9 @@ struct ConvertVectorStore final : OpConversionPattern { // Increment the destination index by 1 to align to the emulated width // boundary. - auto constantOne = rewriter.create(loc, 1); - currentDestIndex = rewriter.create( - loc, rewriter.getIndexType(), currentDestIndex, constantOne); + auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1); + currentDestIndex = arith::AddIOp::create( + rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne); // 2. 
Full width store for the inner output bytes. // After the previous step, the store address is aligned to the emulated @@ -772,15 +779,15 @@ struct ConvertVectorStore final : OpConversionPattern { auto storeType = VectorType::get( {originType.getNumElements() / emulatedPerContainerElem}, memrefElemType); - auto bitCast = rewriter.create(loc, storeType, - fullWidthStorePart); - rewriter.create(loc, bitCast.getResult(), memrefBase, - currentDestIndex); + auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType, + fullWidthStorePart); + vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase, + currentDestIndex); currentSourceIndex += numNonFullWidthElements; - currentDestIndex = rewriter.create( - loc, rewriter.getIndexType(), currentDestIndex, - rewriter.create(loc, fullWidthStoreSize)); + currentDestIndex = arith::AddIOp::create( + rewriter, loc, rewriter.getIndexType(), currentDestIndex, + arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize)); } // 3. Partial width store for the trailing output byte. @@ -795,8 +802,9 @@ struct ConvertVectorStore final : OpConversionPattern { // Generate back mask. 
auto maskValues = SmallVector(emulatedPerContainerElem, 0); std::fill_n(maskValues.begin(), remainingElements, 1); - auto backMask = rewriter.create( - loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues)); + auto backMask = arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(subWidthStoreMaskType, maskValues)); storeFunc(rewriter, loc, memrefBase, currentDestIndex, cast(subWidthStorePart), backMask.getResult()); @@ -848,7 +856,7 @@ struct ConvertVectorMaskedStore final return failure(); auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); OpFoldResult linearizedIndicesOfr; memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndicesOfr) = @@ -901,21 +909,21 @@ struct ConvertVectorMaskedStore final auto numElements = (origElements + emulatedPerContainerElem - 1) / emulatedPerContainerElem; auto newType = VectorType::get(numElements, containerElemTy); - auto passThru = rewriter.create( - loc, newType, rewriter.getZeroAttr(newType)); + auto passThru = arith::ConstantOp::create(rewriter, loc, newType, + rewriter.getZeroAttr(newType)); - auto newLoad = rewriter.create( - loc, newType, adaptor.getBase(), linearizedIndices, + auto newLoad = vector::MaskedLoadOp::create( + rewriter, loc, newType, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0), passThru); auto newBitCastType = VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy); Value valueToStore = - rewriter.create(loc, newBitCastType, newLoad); - valueToStore = rewriter.create( - loc, op.getMask(), op.getValueToStore(), valueToStore); + vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad); + valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(), + op.getValueToStore(), valueToStore); valueToStore = - rewriter.create(loc, newType, valueToStore); + vector::BitCastOp::create(rewriter, loc, newType, valueToStore); 
rewriter.replaceOpWithNewOp( op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0), @@ -990,7 +998,7 @@ struct ConvertVectorLoad final : OpConversionPattern { bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); OpFoldResult linearizedIndices; memref::LinearizedMemRefInfo linearizedInfo; @@ -1016,8 +1024,8 @@ struct ConvertVectorLoad final : OpConversionPattern { numElements, emulatedElemTy, containerElemTy); if (!foldedIntraVectorOffset) { - auto resultVector = rewriter.create( - loc, op.getType(), rewriter.getZeroAttr(op.getType())); + auto resultVector = arith::ConstantOp::create( + rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType())); result = dynamicallyExtractSubVector( rewriter, loc, dyn_cast>(result), resultVector, linearizedInfo.intraDataOffset, origElements); @@ -1111,7 +1119,7 @@ struct ConvertVectorMaskedLoad final bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); OpFoldResult linearizedIndices; memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndices) = @@ -1142,8 +1150,8 @@ struct ConvertVectorMaskedLoad final auto newBitcastType = VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy); - auto emptyVector = rewriter.create( - loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); + auto emptyVector = arith::ConstantOp::create( + rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); if (!foldedIntraVectorOffset) { passthru = dynamicallyInsertSubVector( rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset, @@ -1153,25 +1161,26 @@ struct ConvertVectorMaskedLoad final *foldedIntraVectorOffset); } auto newPassThru = - rewriter.create(loc, 
loadType, passthru); + vector::BitCastOp::create(rewriter, loc, loadType, passthru); // Generating the new masked load. - auto newLoad = rewriter.create( - loc, loadType, adaptor.getBase(), + auto newLoad = vector::MaskedLoadOp::create( + rewriter, loc, loadType, adaptor.getBase(), getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices), newMask.value()->getResult(0), newPassThru); // Setting the part that originally was not effectively loaded from memory // to pass through. auto bitCast = - rewriter.create(loc, newBitcastType, newLoad); + vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad); Value mask = op.getMask(); auto newSelectMaskType = VectorType::get( numElements * emulatedPerContainerElem, rewriter.getI1Type()); // TODO: try to fold if op's mask is constant - auto emptyMask = rewriter.create( - loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType)); + auto emptyMask = + arith::ConstantOp::create(rewriter, loc, newSelectMaskType, + rewriter.getZeroAttr(newSelectMaskType)); if (!foldedIntraVectorOffset) { mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask, linearizedInfo.intraDataOffset, @@ -1182,7 +1191,7 @@ struct ConvertVectorMaskedLoad final } Value result = - rewriter.create(loc, mask, bitCast, passthru); + arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru); if (!foldedIntraVectorOffset) { result = dynamicallyExtractSubVector( rewriter, loc, result, op.getPassThru(), @@ -1272,17 +1281,17 @@ struct ConvertVectorTransferRead final // thus their values don't matter. 
Value padding = adaptor.getPadding(); if (!padding.getType().isInteger()) { - padding = rewriter.create( - loc, + padding = arith::BitcastOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), padding.getType().getIntOrFloatBitWidth()), padding); } auto newPadding = - rewriter.create(loc, containerElemTy, padding); + arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding); auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); OpFoldResult linearizedIndices; memref::LinearizedMemRefInfo linearizedInfo; @@ -1303,20 +1312,21 @@ struct ConvertVectorTransferRead final auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements, emulatedPerContainerElem); - auto newRead = rewriter.create( - loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(), + auto newRead = vector::TransferReadOp::create( + rewriter, loc, VectorType::get(numElements, containerElemTy), + adaptor.getBase(), getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices), newPadding); - auto bitCast = rewriter.create( - loc, + auto bitCast = vector::BitCastOp::create( + rewriter, loc, VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy), newRead); Value result = bitCast->getResult(0); if (!foldedIntraVectorOffset) { - auto zeros = rewriter.create( - loc, op.getType(), rewriter.getZeroAttr(op.getType())); + auto zeros = arith::ConstantOp::create( + rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType())); result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros, linearizedInfo.intraDataOffset, origElements); @@ -1689,32 +1699,33 @@ Value BitCastRewriter::genericRewriteStep( PatternRewriter &rewriter, Location loc, Value initialValue, Value runningResult, const BitCastRewriter::Metadata &metadata) { // Create vector.shuffle from the metadata. 
- auto shuffleOp = rewriter.create( - loc, initialValue, initialValue, metadata.shuffles); + auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue, + initialValue, metadata.shuffles); // Intersect with the mask. VectorType shuffledVectorType = shuffleOp.getResultVectorType(); - auto constOp = rewriter.create( - loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks)); - Value andValue = rewriter.create(loc, shuffleOp, constOp); + auto constOp = arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(shuffledVectorType, metadata.masks)); + Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp); // Align right on 0. - auto shiftRightConstantOp = rewriter.create( - loc, + auto shiftRightConstantOp = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts)); Value shiftedRight = - rewriter.create(loc, andValue, shiftRightConstantOp); + arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp); // Shift bits left into their final position. - auto shiftLeftConstantOp = rewriter.create( - loc, + auto shiftLeftConstantOp = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts)); Value shiftedLeft = - rewriter.create(loc, shiftedRight, shiftLeftConstantOp); + arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp); runningResult = runningResult - ? rewriter.create(loc, runningResult, shiftedLeft) + ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft) : shiftedLeft; return runningResult; @@ -1737,7 +1748,7 @@ static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc, // Adjust last dimension of the vector, so the total size remains the same. 
vecShape.back() = vecShape.back() / numSrcElemsPerByte; auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type()); - return rewriter.create(loc, i8VecType, subByteVec); + return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec); } /// Extracts a signed N-bit sequence from each element of a vector of bytes, @@ -1765,15 +1776,15 @@ static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter, assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 && "Invalid bitIdx range"); if (bitsToShiftLeft != 0) { - Value shiftLeftValues = rewriter.create( - loc, DenseElementsAttr::get(srcType, bitsToShiftLeft)); - shl = rewriter.create(loc, src, shiftLeftValues); + Value shiftLeftValues = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft)); + shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues); } int8_t bitsToShiftRight = 8 - numBits; - Value shiftRightValues = rewriter.create( - loc, DenseElementsAttr::get(srcType, bitsToShiftRight)); - Value shr = rewriter.create(loc, shl, shiftRightValues); + Value shiftRightValues = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight)); + Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues); return shr; } @@ -1807,17 +1818,17 @@ static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter, int8_t bitsToShiftRight = bitIdx; Value shr = src; if (bitsToShiftRight != 0) { - Value shiftRightValues = rewriter.create( - loc, DenseElementsAttr::get(srcType, bitsToShiftRight)); - shr = rewriter.create(loc, src, shiftRightValues); + Value shiftRightValues = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight)); + shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues); } if (bitIdx + numBits == 8) { return shr; } uint8_t lowBitsMask = (1 << numBits) - 1; - Value lowBitsMaskValues = rewriter.create( - loc, 
DenseElementsAttr::get(srcType, lowBitsMask)); - return rewriter.create(loc, shr, lowBitsMaskValues); + Value lowBitsMaskValues = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask)); + return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues); } using ExtractNBitsFn = @@ -1840,7 +1851,7 @@ static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc, Value high = extFn(rewriter, loc, i8Vector, 4, 4); // 3. Interleave low and high i8 elements. - return rewriter.create(loc, low, high); + return vector::InterleaveOp::create(rewriter, loc, low, high); } /// Rewrite the i2 -> i8 extension into a sequence of shuffles and @@ -1873,9 +1884,10 @@ static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc, // 02 = [0,2,0,2,0,2,0,2],... // 13 = [1,3,1,3,1,3,1,3],... // 0213 = [0,1,2,3,...],... - Value interleave02 = rewriter.create(loc, vec0, vec2); - Value interleave13 = rewriter.create(loc, vec1, vec3); - return rewriter.create(loc, interleave02, interleave13); + Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2); + Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3); + return vector::InterleaveOp::create(rewriter, loc, interleave02, + interleave13); } /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise @@ -1887,29 +1899,29 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, "Expected i8 type"); // 1. De-interleave low and high i8 elements. - auto deinterleaveOp = rewriter.create(loc, srcValue); + auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue); // 2. Zero out the upper side of each low i8 element. 
constexpr int8_t i8LowBitMask = 0x0F; VectorType deinterI8VecType = deinterleaveOp.getResultVectorType(); - Value zeroOutMask = rewriter.create( - loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask)); - Value zeroOutLow = rewriter.create( - loc, deinterleaveOp.getRes1(), zeroOutMask); + Value zeroOutMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask)); + Value zeroOutLow = arith::AndIOp::create( + rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask); // 3. Move high i4 values to upper side of the byte. constexpr int8_t bitsToShift = 4; - auto shiftValues = rewriter.create( - loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift)); - Value shlHigh = rewriter.create(loc, deinterleaveOp.getRes2(), - shiftValues); + auto shiftValues = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift)); + Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(), + shiftValues); // 4. Merge high and low i4 values. - auto mergedHiLowOp = rewriter.create(loc, zeroOutLow, shlHigh); + auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh); // 5. Generate a bitcast vector -> vector<2Xxi4>. auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type()); - return rewriter.create(loc, i4VecType, mergedHiLowOp); + return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp); } namespace { @@ -2151,7 +2163,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern { Location loc = truncOp.getLoc(); auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type()); Value i8TruncVal = - rewriter.create(loc, i8VecType, srcValue); + arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue); // Rewrite the i8 -> i4 truncation part. Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal); @@ -2199,10 +2211,10 @@ struct RewriteVectorTranspose : OpRewritePattern { // support is available. 
auto srcNativeVecType = srcSubByteVecType.cloneWith( std::nullopt, rewriter.getIntegerType(minNativeBitwidth)); - Value extOp = rewriter.create(loc, srcNativeVecType, - transposeOp.getVector()); - Value newTranspose = rewriter.create( - loc, extOp, transposeOp.getPermutation()); + Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType, + transposeOp.getVector()); + Value newTranspose = vector::TransposeOp::create( + rewriter, loc, extOp, transposeOp.getPermutation()); VectorType dstSubByteVecType = transposeOp.getResultVectorType(); rewriter.replaceOpWithNewOp(transposeOp, dstSubByteVecType, newTranspose); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index d834a99076834..72352d72bfe77 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -53,15 +53,15 @@ class DecomposeDifferentRankInsertStridedSlice int64_t rankRest = dstType.getRank() - rankDiff; // Extract / insert the subvector of matching rank and InsertStridedSlice // on it. - Value extracted = rewriter.create( - loc, op.getDest(), - getI64SubArray(op.getOffsets(), /*dropFront=*/0, - /*dropBack=*/rankRest)); + Value extracted = + ExtractOp::create(rewriter, loc, op.getDest(), + getI64SubArray(op.getOffsets(), /*dropFront=*/0, + /*dropBack=*/rankRest)); // A different pattern will kick in for InsertStridedSlice with matching // ranks. 
- auto stridedSliceInnerOp = rewriter.create( - loc, op.getValueToStore(), extracted, + auto stridedSliceInnerOp = InsertStridedSliceOp::create( + rewriter, loc, op.getValueToStore(), extracted, getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff), getI64SubArray(op.getStrides(), /*dropFront=*/0)); @@ -131,8 +131,8 @@ class ConvertSameRankInsertStridedSliceIntoShuffle SmallVector offsets(nDest, 0); for (int64_t i = 0; i < nSrc; ++i) offsets[i] = i; - Value scaledSource = rewriter.create( - loc, op.getValueToStore(), op.getValueToStore(), offsets); + Value scaledSource = ShuffleOp::create( + rewriter, loc, op.getValueToStore(), op.getValueToStore(), offsets); // 2. Create a mask where we take the value from scaledSource of dest // depending on the offset. @@ -156,21 +156,21 @@ class ConvertSameRankInsertStridedSliceIntoShuffle off += stride, ++idx) { // 1. extract the proper subvector (or element) from source Value extractedSource = - rewriter.create(loc, op.getValueToStore(), idx); + ExtractOp::create(rewriter, loc, op.getValueToStore(), idx); if (isa(extractedSource.getType())) { // 2. If we have a vector, extract the proper subvector from destination // Otherwise we are at the element level and no need to recurse. Value extractedDest = - rewriter.create(loc, op.getDest(), off); + ExtractOp::create(rewriter, loc, op.getDest(), off); // 3. Reduce the problem to lowering a new InsertStridedSlice op with // smaller rank. - extractedSource = rewriter.create( - loc, extractedSource, extractedDest, + extractedSource = InsertStridedSliceOp::create( + rewriter, loc, extractedSource, extractedDest, getI64SubArray(op.getOffsets(), /* dropFront=*/1), getI64SubArray(op.getStrides(), /* dropFront=*/1)); } // 4. Insert the extractedSource into the res vector. 
- res = rewriter.create(loc, extractedSource, res, off); + res = InsertOp::create(rewriter, loc, extractedSource, res, off); } rewriter.replaceOp(op, res); @@ -250,12 +250,12 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final SmallVector elements; elements.reserve(size); for (int64_t i = offset, e = offset + size * stride; i < e; i += stride) - elements.push_back(rewriter.create(loc, op.getVector(), i)); + elements.push_back(ExtractOp::create(rewriter, loc, op.getVector(), i)); - Value result = rewriter.create( - loc, rewriter.getZeroAttr(op.getType())); + Value result = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(op.getType())); for (int64_t i = 0; i < size; ++i) - result = rewriter.create(loc, elements[i], result, i); + result = InsertOp::create(rewriter, loc, elements[i], result, i); rewriter.replaceOp(op, result); return success(); @@ -301,17 +301,17 @@ class DecomposeNDExtractStridedSlice return failure(); // Extract/insert on a lower ranked extract strided slice op. 
- Value zero = rewriter.create( - loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = rewriter.create(loc, dstType, zero); + Value zero = arith::ConstantOp::create(rewriter, loc, elemType, + rewriter.getZeroAttr(elemType)); + Value res = SplatOp::create(rewriter, loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { - Value one = rewriter.create(loc, op.getVector(), off); - Value extracted = rewriter.create( - loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1), + Value one = ExtractOp::create(rewriter, loc, op.getVector(), off); + Value extracted = ExtractStridedSliceOp::create( + rewriter, loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1), getI64SubArray(op.getSizes(), /* dropFront=*/1), getI64SubArray(op.getStrides(), /* dropFront=*/1)); - res = rewriter.create(loc, extracted, res, idx); + res = InsertOp::create(rewriter, loc, extracted, res, idx); } rewriter.replaceOp(op, res); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index fe17b3c0b2cfc..491b448e9e1e9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -658,7 +658,7 @@ struct LinearizeVectorCreateMask final // The result of the comparison is then multiplied with // the second operand of create_mask to get the 1D mask. 
auto firstOperand = adaptor.getOperands().front(); - auto zero = rewriter.create(loc, 0); + auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0); auto isNonZero = rewriter.createOrFold( loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); auto isNonZeroIndex = rewriter.createOrFold( @@ -668,7 +668,7 @@ struct LinearizeVectorCreateMask final loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); auto newMask = - rewriter.create(loc, dstTy, maskSize); + mlir::vector::CreateMaskOp::create(rewriter, loc, dstTy, maskSize); rewriter.replaceOp(createMaskOp, newMask); return success(); } @@ -710,8 +710,9 @@ struct LinearizeVectorLoad final : public OpConversionPattern { auto linearTy = typeConverter->convertType(vecTy); - auto newLoad = rewriter.create( - loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices()); + auto newLoad = + vector::LoadOp::create(rewriter, loadOp.getLoc(), linearTy, + adaptor.getBase(), adaptor.getIndices()); rewriter.replaceOp(loadOp, newLoad.getResult()); return success(); } @@ -832,7 +833,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, if (!isa(type) || !isa(value.getType())) return nullptr; - return builder.create(loc, type, value); + return vector::ShapeCastOp::create(builder, loc, type, value); }; typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp index a7403250a069b..8a181a429e41c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp @@ -82,8 +82,8 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter, // Replace createMaskOp with an all-true constant. This should result in the // mask being removed in most cases (as xfer ops + vector.mask have folds to // remove all-true masks). 
- auto allTrue = rewriter.create( - createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue); + auto allTrue = vector::ConstantMaskOp::create( + rewriter, createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue); rewriter.replaceAllUsesWith(createMaskOp, allTrue); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c20a1b355996c..2676d254c9b64 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -286,8 +286,8 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, if (resultType.canonicalizeStridedLayout() == inputType.canonicalizeStridedLayout()) return input; - return rewriter.create(loc, resultType, input, offsets, - sizes, strides); + return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets, + sizes, strides); } /// Returns the number of dims that aren't unit dims. 
@@ -395,13 +395,13 @@ class TransferReadDropUnitDimsPattern Value reducedShapeSource = rankReducingSubviewDroppingUnitDims(rewriter, loc, source); - Value c0 = rewriter.create(loc, 0); + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); SmallVector inBounds(reducedVectorType.getRank(), true); - Operation *newTransferReadOp = rewriter.create( - loc, reducedVectorType, reducedShapeSource, zeros, identityMap, - transferReadOp.getPadding(), maskOp, + Operation *newTransferReadOp = vector::TransferReadOp::create( + rewriter, loc, reducedVectorType, reducedShapeSource, zeros, + identityMap, transferReadOp.getPadding(), maskOp, rewriter.getBoolArrayAttr(inBounds)); if (maskingOp) { @@ -477,15 +477,15 @@ class TransferWriteDropUnitDimsPattern } Value reducedShapeSource = rankReducingSubviewDroppingUnitDims(rewriter, loc, source); - Value c0 = rewriter.create(loc, 0); + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); SmallVector inBounds(reducedVectorType.getRank(), true); auto shapeCastSrc = rewriter.createOrFold( loc, reducedVectorType, vector); - Operation *newXferWrite = rewriter.create( - loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap, - maskOp, rewriter.getBoolArrayAttr(inBounds)); + Operation *newXferWrite = vector::TransferWriteOp::create( + rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros, + identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds)); if (maskingOp) { auto shapeCastMask = rewriter.createOrFold( @@ -520,7 +520,7 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i) collapsedIndices.push_back(i); reassociation.push_back(collapsedIndices); - return rewriter.create(loc, input, reassociation); + return 
memref::CollapseShapeOp::create(rewriter, loc, input, reassociation); } /// Returns the new indices that collapses the inner dimensions starting from @@ -559,7 +559,7 @@ static SmallVector getCollapsedIndices(RewriterBase &rewriter, // one would get the following offset: // %offset = %arg0 * 43 OpFoldResult collapsedOffset = - rewriter.create(loc, 0).getResult(); + arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); auto collapsedStrides = computeSuffixProduct( ArrayRef(shape.begin() + firstDimToCollapse, shape.end())); @@ -573,8 +573,8 @@ static SmallVector getCollapsedIndices(RewriterBase &rewriter, if (auto value = dyn_cast(collapsedOffset)) { indicesAfterCollapsing.push_back(value); } else { - indicesAfterCollapsing.push_back(rewriter.create( - loc, *getConstantIntValue(collapsedOffset))); + indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create( + rewriter, loc, *getConstantIntValue(collapsedOffset))); } return indicesAfterCollapsing; @@ -659,8 +659,8 @@ class FlattenContiguousRowMajorTransferReadPattern // 3. 
Create new vector.transfer_read that reads from the collapsed memref VectorType flatVectorType = VectorType::get({vectorType.getNumElements()}, vectorType.getElementType()); - vector::TransferReadOp flatRead = rewriter.create( - loc, flatVectorType, collapsedSource, collapsedIndices, + vector::TransferReadOp flatRead = vector::TransferReadOp::create( + rewriter, loc, flatVectorType, collapsedSource, collapsedIndices, transferReadOp.getPadding(), collapsedMap); flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); @@ -757,10 +757,10 @@ class FlattenContiguousRowMajorTransferWritePattern VectorType flatVectorType = VectorType::get({vectorType.getNumElements()}, vectorType.getElementType()); Value flatVector = - rewriter.create(loc, flatVectorType, vector); - vector::TransferWriteOp flatWrite = - rewriter.create( - loc, flatVector, collapsedSource, collapsedIndices, collapsedMap); + vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector); + vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create( + rewriter, loc, flatVector, collapsedSource, collapsedIndices, + collapsedMap); flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); // 4. Replace the old transfer_write with the new one writing the @@ -846,8 +846,8 @@ class RewriteScalarExtractOfTransferRead if (auto value = dyn_cast(composedIdx)) { newIndices[idx] = value; } else { - newIndices[idx] = rewriter.create( - extractOp.getLoc(), *getConstantIntValue(composedIdx)); + newIndices[idx] = arith::ConstantIndexOp::create( + rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx)); } } if (isa(xferOp.getBase().getType())) { @@ -883,8 +883,8 @@ class RewriteScalarWrite : public OpRewritePattern { if (!xferOp.getPermutationMap().isMinorIdentity()) return failure(); // Only float and integer element types are supported. 
- Value scalar = - rewriter.create(xferOp.getLoc(), xferOp.getVector()); + Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(), + xferOp.getVector()); // Construct a scalar store. if (isa(xferOp.getBase().getType())) { rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index eee090d495c17..05b00744beea2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -57,12 +57,12 @@ static Value createInBoundsCond(RewriterBase &b, if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz) return; Value cond = - b.create(loc, arith::CmpIPredicate::sle, - getValueOrCreateConstantIndexOp(b, loc, sum), - getValueOrCreateConstantIndexOp(b, loc, dimSz)); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle, + getValueOrCreateConstantIndexOp(b, loc, sum), + getValueOrCreateConstantIndexOp(b, loc, dimSz)); // Conjunction over all dims for which we are in-bounds. 
if (inBoundsCond) - inBoundsCond = b.create(loc, inBoundsCond, cond); + inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond); else inBoundsCond = cond; }); @@ -170,11 +170,12 @@ static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, sourceType = MemRefType::get( sourceType.getShape(), sourceType.getElementType(), sourceType.getLayout(), compatibleMemRefType.getMemorySpace()); - res = b.create(memref.getLoc(), sourceType, res); + res = + memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType, res); } if (sourceType == compatibleMemRefType) return res; - return b.create(memref.getLoc(), compatibleMemRefType, res); + return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res); } /// Operates under a scoped context to build the intersection between the @@ -196,16 +197,17 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { using MapList = ArrayRef>; Value dimMemRef = - b.create(xferOp.getLoc(), xferOp.getBase(), indicesIdx); - Value dimAlloc = b.create(loc, alloc, resultIdx); + memref::DimOp::create(b, xferOp.getLoc(), xferOp.getBase(), indicesIdx); + Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx); Value index = xferOp.getIndices()[indicesIdx]; AffineExpr i, j, k; bindDims(xferOp.getContext(), i, j, k); SmallVector maps = AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext()); // affine_min(%dimMemRef - %index, %dimAlloc) - Value affineMin = b.create( - loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc}); + Value affineMin = + affine::AffineMinOp::create(b, loc, index.getType(), maps[0], + ValueRange{dimMemRef, index, dimAlloc}); sizes.push_back(affineMin); }); @@ -213,10 +215,10 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; })); SmallVector destIndices(memrefRank, 
b.getIndexAttr(0)); SmallVector strides(memrefRank, b.getIndexAttr(1)); - auto copySrc = b.create( - loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides); - auto copyDest = b.create( - loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides); + auto copySrc = memref::SubViewOp::create( + b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides); + auto copyDest = memref::SubViewOp::create( + b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides); return std::make_pair(copySrc, copyDest); } @@ -244,32 +246,32 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) { Location loc = xferOp.getLoc(); - Value zero = b.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); Value memref = xferOp.getBase(); - return b.create( - loc, inBoundsCond, + return scf::IfOp::create( + b, loc, inBoundsCond, [&](OpBuilder &b, Location loc) { Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType); scf::ValueVector viewAndIndices{res}; llvm::append_range(viewAndIndices, xferOp.getIndices()); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { - b.create(loc, ValueRange{xferOp.getPadding()}, - ValueRange{alloc}); + linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()}, + ValueRange{alloc}); // Take partial subview of memref which guarantees no dimension // overflows. 
IRRewriter rewriter(b); std::pair copyArgs = createSubViewIntersection( rewriter, cast(xferOp.getOperation()), alloc); - b.create(loc, copyArgs.first, copyArgs.second); + memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second); Value casted = castToCompatibleMemRefType(b, alloc, compatibleMemRefType); scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }); } @@ -297,30 +299,30 @@ static scf::IfOp createFullPartialVectorTransferRead( Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) { Location loc = xferOp.getLoc(); scf::IfOp fullPartialIfOp; - Value zero = b.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); Value memref = xferOp.getBase(); - return b.create( - loc, inBoundsCond, + return scf::IfOp::create( + b, loc, inBoundsCond, [&](OpBuilder &b, Location loc) { Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType); scf::ValueVector viewAndIndices{res}; llvm::append_range(viewAndIndices, xferOp.getIndices()); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { Operation *newXfer = b.clone(*xferOp.getOperation()); Value vector = cast(newXfer).getVector(); - b.create( - loc, vector, - b.create( - loc, MemRefType::get({}, vector.getType()), alloc)); + memref::StoreOp::create( + b, loc, vector, + vector::TypeCastOp::create( + b, loc, MemRefType::get({}, vector.getType()), alloc)); Value casted = castToCompatibleMemRefType(b, alloc, compatibleMemRefType); scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }); } @@ -344,7 +346,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, TypeRange returnTypes, Value inBoundsCond, 
MemRefType compatibleMemRefType, Value alloc) { Location loc = xferOp.getLoc(); - Value zero = b.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); Value memref = xferOp.getBase(); return b .create( @@ -354,7 +356,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, castToCompatibleMemRefType(b, memref, compatibleMemRefType); scf::ValueVector viewAndIndices{res}; llvm::append_range(viewAndIndices, xferOp.getIndices()); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { Value casted = @@ -362,7 +364,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }) ->getResults(); } @@ -384,15 +386,15 @@ static void createFullPartialLinalgCopy(RewriterBase &b, vector::TransferWriteOp xferOp, Value inBoundsCond, Value alloc) { Location loc = xferOp.getLoc(); - auto notInBounds = b.create( - loc, inBoundsCond, b.create(loc, true, 1)); - b.create(loc, notInBounds, [&](OpBuilder &b, Location loc) { + auto notInBounds = arith::XOrIOp::create( + b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1)); + scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) { IRRewriter rewriter(b); std::pair copyArgs = createSubViewIntersection( rewriter, cast(xferOp.getOperation()), alloc); - b.create(loc, copyArgs.first, copyArgs.second); - b.create(loc, ValueRange{}); + memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second); + scf::YieldOp::create(b, loc, ValueRange{}); }); } @@ -413,18 +415,18 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b, Value inBoundsCond, Value alloc) { Location loc = xferOp.getLoc(); - auto notInBounds = b.create( - loc, inBoundsCond, b.create(loc, true, 1)); - b.create(loc, 
notInBounds, [&](OpBuilder &b, Location loc) { + auto notInBounds = arith::XOrIOp::create( + b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1)); + scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) { IRMapping mapping; - Value load = b.create( - loc, - b.create( - loc, MemRefType::get({}, xferOp.getVector().getType()), alloc), + Value load = memref::LoadOp::create( + b, loc, + vector::TypeCastOp::create( + b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc), ValueRange()); mapping.map(xferOp.getVector(), load); b.clone(*xferOp.getOperation(), mapping); - b.create(loc, ValueRange{}); + scf::YieldOp::create(b, loc, ValueRange{}); }); } @@ -554,9 +556,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( b.setInsertionPointToStart(&scope->getRegion(0).front()); auto shape = xferOp.getVectorType().getShape(); Type elementType = xferOp.getVectorType().getElementType(); - alloc = b.create(scope->getLoc(), - MemRefType::get(shape, elementType), - ValueRange{}, b.getI64IntegerAttr(32)); + alloc = memref::AllocaOp::create(b, scope->getLoc(), + MemRefType::get(shape, elementType), + ValueRange{}, b.getI64IntegerAttr(32)); } MemRefType compatibleMemRefType = diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index fe2707629d82e..73ca327bb49c5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -381,8 +381,8 @@ FailureOr combineContractAndBroadcast(vector::ContractionOp contractOp, if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) return failure(); - Operation *newOp = rewriter.create( - contractOp.getLoc(), lhs, rhs, contractOp.getAcc(), + Operation *newOp = vector::ContractionOp::create( + rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(), rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); // Handle the mask. 
@@ -534,8 +534,8 @@ struct ReorderElementwiseOpsOnTranspose final // This is a constant. Create a reverse transpose op for it. auto vectorType = srcType.clone(cast(operand.getType()).getElementType()); - srcValues.push_back(rewriter.create( - operand.getLoc(), vectorType, operand, invOrder)); + srcValues.push_back(vector::TransposeOp::create( + rewriter, operand.getLoc(), vectorType, operand, invOrder)); } } @@ -608,20 +608,20 @@ struct BubbleDownVectorBitCastForExtract // Get the single scalar (as a vector) in the source value that packs the // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> Location loc = extractOp.getLoc(); - Value packedValue = rewriter.create( - loc, castOp.getSource(), index / expandRatio); + Value packedValue = vector::ExtractOp::create( + rewriter, loc, castOp.getSource(), index / expandRatio); Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType()); - Value zero = rewriter.create( - loc, packedVecType, rewriter.getZeroAttr(packedVecType)); - packedValue = rewriter.create(loc, packedValue, zero, - /*position=*/0); + Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType, + rewriter.getZeroAttr(packedVecType)); + packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero, + /*position=*/0); // Cast it to a vector with the desired scalar's type. // E.g. f32 -> vector<2xf16> VectorType packedType = VectorType::get({expandRatio}, castDstType.getElementType()); Value castedValue = - rewriter.create(loc, packedType, packedValue); + vector::BitCastOp::create(rewriter, loc, packedType, packedValue); // Finally extract the desired scalar. 
rewriter.replaceOpWithNewOp(extractOp, castedValue, @@ -700,9 +700,9 @@ struct BubbleDownBitCastForStridedSliceExtract VectorType newExtractType = VectorType::get(dims, castSrcType.getElementType()); - auto newExtractOp = rewriter.create( - extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets, - newSizes, extractOp.getStrides()); + auto newExtractOp = vector::ExtractStridedSliceOp::create( + rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(), + newOffsets, newSizes, extractOp.getStrides()); rewriter.replaceOpWithNewOp( extractOp, extractOp.getType(), newExtractOp); @@ -761,8 +761,9 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern { isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio; VectorType newCastSrcType = VectorType::get(srcDims, castDstType.getElementType()); - auto newCastSrcOp = rewriter.create( - bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore()); + auto newCastSrcOp = + vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType, + insertOp.getValueToStore()); SmallVector dstDims(insertOp.getDestVectorType().getShape()); dstDims.back() = @@ -771,8 +772,8 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern { VectorType::get(dstDims, castDstType.getElementType()); // Bitcast the destination. - auto newCastDstOp = rewriter.create( - bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); + auto newCastDstOp = vector::BitCastOp::create( + rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); // Generate new insert. 
rewriter.replaceOpWithNewOp( @@ -852,8 +853,9 @@ struct BubbleUpBitCastForStridedSliceInsert VectorType newCastSrcType = VectorType::get(srcDims, castDstType.getElementType()); - auto newCastSrcOp = rewriter.create( - bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore()); + auto newCastSrcOp = + vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType, + insertOp.getValueToStore()); SmallVector dstDims = llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); @@ -861,8 +863,8 @@ struct BubbleUpBitCastForStridedSliceInsert VectorType newCastDstType = VectorType::get(dstDims, castDstType.getElementType()); - auto newCastDstOp = rewriter.create( - bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); + auto newCastDstOp = vector::BitCastOp::create( + rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); rewriter.replaceOpWithNewOp( bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, @@ -936,9 +938,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern { Type elemType = castDstType.getElementType(); assert(elemType.isSignlessIntOrIndexOrFloat()); - Value zero = rewriter.create( - loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = rewriter.create(loc, castDstType, zero); + Value zero = arith::ConstantOp::create(rewriter, loc, elemType, + rewriter.getZeroAttr(elemType)); + Value res = SplatOp::create(rewriter, loc, castDstType, zero); SmallVector sliceShape = {castDstLastDim}; SmallVector strides = {1}; @@ -947,13 +949,13 @@ struct BreakDownVectorBitCast : public OpRewritePattern { castDstType.getElementType()); for (int i = 0, e = shrinkRatio; i < e; ++i) { - Value extracted = rewriter.create( - loc, bitcastOp.getSource(), ArrayRef{i * castDstLastDim}, - sliceShape, strides); + Value extracted = ExtractStridedSliceOp::create( + rewriter, loc, bitcastOp.getSource(), + ArrayRef{i * castDstLastDim}, sliceShape, strides); Value bitcast = - rewriter.create(loc, newCastDstType, extracted); - res 
= rewriter.create( - loc, bitcast, res, + BitCastOp::create(rewriter, loc, newCastDstType, extracted); + res = InsertStridedSliceOp::create( + rewriter, loc, bitcast, res, ArrayRef{i * castDstLastDim / shrinkRatio}, strides); } rewriter.replaceOp(bitcastOp, res); @@ -1103,7 +1105,7 @@ class ExtractOpFromElementwise final Location loc = eltwise->getLoc(); SmallVector pos = op.getMixedPosition(); for (Value arg : eltwise->getOperands()) { - Value newArg = rewriter.create(loc, arg, pos); + Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos); mapping.map(arg, newArg); } @@ -1292,19 +1294,19 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, indicesAttr = rewriter.getI64VectorAttr( llvm::to_vector<4>(llvm::seq(0, dim))); } - Value indices = rewriter.create(loc, indicesAttr); + Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr); // Add in an offset if requested. if (off) { Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = rewriter.create(loc, indices.getType(), o); - indices = rewriter.create(loc, ov, indices); + Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o); + indices = arith::AddIOp::create(rewriter, loc, ov, indices); } // Construct the vector comparison. 
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); Value bounds = - rewriter.create(loc, indices.getType(), bound); - return rewriter.create(loc, arith::CmpIPredicate::slt, indices, - bounds); + vector::SplatOp::create(rewriter, loc, indices.getType(), bound); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + indices, bounds); } template @@ -1335,15 +1337,15 @@ struct MaterializeTransferMask : public OpRewritePattern { Value off = xferOp.getIndices()[lastIndex]; Value dim = vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex); - Value b = rewriter.create(loc, dim.getType(), dim, off); - Value mask = rewriter.create( - loc, + Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off); + Value mask = vector::CreateMaskOp::create( + rewriter, loc, VectorType::get(vtp.getShape(), rewriter.getI1Type(), vtp.getScalableDims()), b); if (xferOp.getMask()) { // Intersect the in-bounds with the mask specified as an op parameter. 
- mask = rewriter.create(loc, mask, xferOp.getMask()); + mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask()); } rewriter.modifyOpInPlace(xferOp, [&]() { @@ -1548,12 +1550,13 @@ class DropInnerMostUnitDimsTransferRead strides); ArrayAttr inBoundsAttr = rewriter.getArrayAttr( readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); - Value rankedReducedView = rewriter.create( - loc, resultMemrefType, readOp.getBase(), offsets, sizes, strides); + Value rankedReducedView = + memref::SubViewOp::create(rewriter, loc, resultMemrefType, + readOp.getBase(), offsets, sizes, strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); - Value result = rewriter.create( - loc, resultTargetVecType, rankedReducedView, + Value result = vector::TransferReadOp::create( + rewriter, loc, resultTargetVecType, rankedReducedView, readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), readOp.getPadding(), // TODO: support mask. 
@@ -1639,8 +1642,9 @@ class DropInnerMostUnitDimsTransferWrite ArrayAttr inBoundsAttr = rewriter.getArrayAttr( writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); - Value rankedReducedView = rewriter.create( - loc, resultMemrefType, writeOp.getBase(), offsets, sizes, strides); + Value rankedReducedView = + memref::SubViewOp::create(rewriter, loc, resultMemrefType, + writeOp.getBase(), offsets, sizes, strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); @@ -1708,21 +1712,21 @@ struct CanonicalizeContractMatmulToMMT final auto createTranspose = [&rewriter, loc](Value mat) -> Value { if (auto sext = mat.getDefiningOp()) { Value trans = - rewriter.create(loc, sext.getIn(), perm); + vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm); VectorType newType = cast(trans.getType()) .clone(cast(mat.getType()).getElementType()); - return rewriter.create(loc, newType, trans); + return arith::ExtSIOp::create(rewriter, loc, newType, trans); } if (auto zext = mat.getDefiningOp()) { Value trans = - rewriter.create(loc, zext.getIn(), perm); + vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm); VectorType newType = VectorType::get(cast(trans.getType()).getShape(), cast(mat.getType()).getElementType()); - return rewriter.create(loc, newType, trans); + return arith::ExtUIOp::create(rewriter, loc, newType, trans); } - return rewriter.create(loc, mat, perm); + return vector::TransposeOp::create(rewriter, loc, mat, perm); }; if (maps == infer({{m, k}, {k, n}, {m, n}})) { @@ -1836,8 +1840,8 @@ struct ChainedReduction final : OpRewritePattern { vAdd = rewriter.createOrFold( loc, parentReduction.getVector(), op.getVector()); } else { - vAdd = rewriter.create(loc, parentReduction.getVector(), - op.getVector()); + vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(), + op.getVector()); } rewriter.replaceOpWithNewOp(op, op.getKind(), vAdd, parentReduction.getAcc()); @@ -1925,7 +1929,7 @@ 
struct DropUnitDimFromElementwiseOps final if (newVType == opVectorType) return rewriter.notifyMatchFailure(op, "No unit dimension to remove."); - auto opSC = rewriter.create(loc, newVType, operand); + auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand); newOperands.push_back(opSC); } @@ -2004,11 +2008,11 @@ struct DropUnitDimsFromTransposeOp final Location loc = op.getLoc(); // Drop the unit dims via shape_cast. - auto dropDimsShapeCast = rewriter.create( - loc, sourceTypeWithoutUnitDims, op.getVector()); + auto dropDimsShapeCast = vector::ShapeCastOp::create( + rewriter, loc, sourceTypeWithoutUnitDims, op.getVector()); // Create the new transpose. auto transposeWithoutUnitDims = - rewriter.create(loc, dropDimsShapeCast, newPerm); + vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm); // Restore the unit dims via shape cast. rewriter.replaceOpWithNewOp( op, op.getResultVectorType(), transposeWithoutUnitDims); @@ -2059,7 +2063,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern { // Create a new ForOp with that iter operand replaced. 
auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) { - return b.create(loc, type, source); + return vector::ShapeCastOp::create(b, loc, type, source); }; Value replacement = @@ -2111,8 +2115,8 @@ struct ReduceRedundantZero final : OpRewritePattern { if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat())) return failure(); - auto newAdd = rewriter.create(vAdd.getLoc(), addLhs.getLhs(), - vAdd.getRhs()); + auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(), + addLhs.getLhs(), vAdd.getRhs()); rewriter.replaceOpWithNewOp(op, op.getKind(), newAdd, op.getAcc()); return success(); @@ -2154,8 +2158,8 @@ struct BreakDownVectorReduction final : OpRewritePattern { Location loc = op.getLoc(); SmallVector extracted(numElems, nullptr); for (auto [idx, extractedElem] : llvm::enumerate(extracted)) - extractedElem = rewriter.create( - loc, op.getVector(), static_cast(idx)); + extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(), + static_cast(idx)); Value res = extracted.front(); for (auto extractedElem : llvm::drop_begin(extracted)) @@ -2234,8 +2238,8 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern { if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs)) return failure(); - return rewriter.create( - mulOp->getLoc(), resType, broadcastedLhs.getSource(), + return vector::OuterProductOp::create( + rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(), broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD); }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 693f4f955994d..fceba65fa3e3a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -49,7 +49,7 @@ static SmallVector sliceTransferIndices(ArrayRef elementOffsets, getAffineConstantExpr(elementOffsets[dim.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); 
slicedIndices[pos] = - builder.create(loc, map, indices[pos]); + affine::AffineApplyOp::create(builder, loc, map, indices[pos]); } return slicedIndices; } @@ -68,9 +68,9 @@ static SmallVector sliceLoadStoreIndices(PatternRewriter &rewriter, auto start = indices.size() - offsets.size(); for (auto [i, offset] : llvm::enumerate(offsets)) { if (offset != 0) { - indices[start + i] = rewriter.create( - loc, originalIndices[start + i], - rewriter.create(loc, offset)); + indices[start + i] = arith::AddIOp::create( + rewriter, loc, originalIndices[start + i], + arith::ConstantIndexOp::create(rewriter, loc, offset)); } } return indices; @@ -172,8 +172,9 @@ struct UnrollTransferReadPattern ArrayRef originalSize = readOp.getVectorType().getShape(); // Prepare the result vector; - Value result = rewriter.create( - loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); + Value result = + arith::ConstantOp::create(rewriter, loc, sourceVectorType, + rewriter.getZeroAttr(sourceVectorType)); auto targetType = VectorType::get(*targetShape, sourceVectorType.getElementType()); SmallVector originalIndices(readOp.getIndices().begin(), @@ -185,8 +186,8 @@ struct UnrollTransferReadPattern SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, readOp.getPermutationMap(), loc, rewriter); - auto slicedRead = rewriter.create( - loc, targetType, readOp.getBase(), indices, + auto slicedRead = vector::TransferReadOp::create( + rewriter, loc, targetType, readOp.getBase(), indices, readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr()); @@ -236,9 +237,10 @@ struct UnrollTransferWritePattern SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, writeOp.getPermutationMap(), loc, rewriter); - Operation *slicedWrite = rewriter.create( - loc, slicedVector, resultTensor ? 
resultTensor : writeOp.getBase(), - indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); + Operation *slicedWrite = vector::TransferWriteOp::create( + rewriter, loc, slicedVector, + resultTensor ? resultTensor : writeOp.getBase(), indices, + writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); // For the tensor case update the destination for the next transfer write. if (!slicedWrite->getResults().empty()) resultTensor = slicedWrite->getResult(0); @@ -348,8 +350,8 @@ struct UnrollContractionPattern accCache[dstOffets] = newOp->getResult(0); } // Assemble back the accumulator into a single vector. - Value result = rewriter.create( - loc, dstVecType, rewriter.getZeroAttr(dstVecType)); + Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, + rewriter.getZeroAttr(dstVecType)); for (const auto &it : accCache) { SmallVector dstStrides(it.first.size(), 1); result = rewriter.createOrFold( @@ -427,8 +429,8 @@ struct UnrollMultiReductionPattern accCache[destOffset] = result; } // Assemble back the accumulator into a single vector. - Value result = rewriter.create( - loc, reductionOp.getDestType(), + Value result = arith::ConstantOp::create( + rewriter, loc, reductionOp.getDestType(), rewriter.getZeroAttr(reductionOp.getDestType())); for (const auto &it : accCache) { SmallVector dstStrides(it.first.size(), 1); @@ -468,8 +470,8 @@ struct UnrollElementwisePattern : public RewritePattern { op, "expected input vector rank to match target shape rank"); Location loc = op->getLoc(); // Prepare the result vector. 
- Value result = rewriter.create( - loc, dstVecType, rewriter.getZeroAttr(dstVecType)); + Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, + rewriter.getZeroAttr(dstVecType)); SmallVector strides(targetShape->size(), 1); VectorType newVecType = VectorType::get(*targetShape, dstVecType.getElementType()); @@ -567,8 +569,9 @@ struct UnrollTransposePattern : public OpRewritePattern { ArrayRef originalSize = originalVectorType.getShape(); // Prepare the result vector; - Value result = rewriter.create( - loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); + Value result = + arith::ConstantOp::create(rewriter, loc, originalVectorType, + rewriter.getZeroAttr(originalVectorType)); ArrayRef permutation = transposeOp.getPermutation(); // Unroll the computation. @@ -618,8 +621,9 @@ struct UnrollGatherPattern : public OpRewritePattern { ArrayRef originalSize = gatherOp.getVectorType().getShape(); // Prepare the result vector; - Value result = rewriter.create( - loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); + Value result = + arith::ConstantOp::create(rewriter, loc, sourceVectorType, + rewriter.getZeroAttr(sourceVectorType)); auto targetType = VectorType::get(*targetShape, sourceVectorType.getElementType()); @@ -638,8 +642,8 @@ struct UnrollGatherPattern : public OpRewritePattern { rewriter.createOrFold( loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides); - auto slicedGather = rewriter.create( - loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), + auto slicedGather = vector::GatherOp::create( + rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), indexSubVec, maskSubVec, passThruSubVec); result = rewriter.createOrFold( @@ -671,8 +675,8 @@ struct UnrollLoadPattern : public OpRewritePattern { ArrayRef originalShape = vecType.getShape(); SmallVector strides(targetShape->size(), 1); - Value result = rewriter.create( - loc, vecType, rewriter.getZeroAttr(vecType)); + Value result = 
arith::ConstantOp::create(rewriter, loc, vecType, + rewriter.getZeroAttr(vecType)); SmallVector loopOrder = getUnrollOrder(originalShape.size(), loadOp, options); @@ -684,8 +688,8 @@ struct UnrollLoadPattern : public OpRewritePattern { StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { SmallVector indices = sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets); - Value slicedLoad = rewriter.create( - loc, targetVecType, loadOp.getBase(), indices); + Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType, + loadOp.getBase(), indices); result = rewriter.createOrFold( loc, slicedLoad, result, offsets, strides); } @@ -727,7 +731,7 @@ struct UnrollStorePattern : public OpRewritePattern { sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets); Value slice = rewriter.createOrFold( loc, vector, offsets, *targetShape, strides); - rewriter.create(loc, slice, base, indices); + vector::StoreOp::create(rewriter, loc, slice, base, indices); } rewriter.eraseOp(storeOp); return success(); @@ -755,8 +759,8 @@ struct UnrollBroadcastPattern : public OpRewritePattern { VectorType resType = broadcastOp.getResultVectorType(); VectorType targetType = resType.cloneWith(*targetShape, resType.getElementType()); - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); + Value result = arith::ConstantOp::create(rewriter, loc, resType, + rewriter.getZeroAttr(resType)); SmallVector originalShape = *broadcastOp.getShapeForUnroll(); SmallVector strides(originalShape.size(), 1); diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 7e4984582b373..ac542ddd7d1c4 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -331,7 +331,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, assert(padValue.getType() == sourceShapedType.getElementType() && "expected same pad element type to 
match source element type"); int64_t readRank = inputVectorSizes.size(); - auto zero = builder.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); SmallVector inBoundsVal(readRank, true); if (useInBoundsInsteadOfMasking) { @@ -341,8 +341,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) && ShapedType::isStatic(sourceShape[i]); } - auto transferReadOp = builder.create( - loc, + auto transferReadOp = vector::TransferReadOp::create( + builder, loc, /*vectorType=*/vectorType, /*source=*/source, /*indices=*/SmallVector(readRank, zero), @@ -356,7 +356,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type()); Value mask = - builder.create(loc, maskType, mixedSourceDims); + vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims); return mlir::vector::maskOperation(builder, transferReadOp, mask) ->getResult(0); }