[mlir][NFC] update mlir/Dialect create APIs (24/n) #149931
Merged: makslevental merged 1 commit into llvm:main from makslevental:makslevental/update-create-24n on Jul 22, 2025
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.

See llvm#147168 for more info.

05d29b0 to 950a185
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes: See #147168 for more info.

Patch is 198.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149931.diff

26 Files Affected:
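The change is purely mechanical: every builder.create&lt;OpTy&gt;(loc, ...) call is rewritten to the equivalent static OpTy::create(builder, loc, ...) form, with the builder passed as the first argument. A minimal sketch of the pattern, based on the ub.poison padding case in the diff below; the wrapper function buildPadding is hypothetical and only for illustration, and the include paths are assumed to follow MLIR's standard layout:

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Both spellings construct the same ub.poison op; only the call syntax
// changes, which is why the patch is tagged NFC.
static Value buildPadding(OpBuilder &builder, Location loc, Type elemType) {
  // Old spelling (pre-migration):
  //   return builder.create<ub::PoisonOp>(loc, elemType);
  // New spelling used throughout this patch:
  return ub::PoisonOp::create(builder, loc, elemType);
}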
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 56f748fbbe1d6..4c00fb58e4d30 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -124,7 +124,7 @@ static MaskFormat getMaskFormat(Value mask) {
/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
- builder.create<vector::YieldOp>(loc);
+ vector::YieldOp::create(builder, loc);
}
// Helper for verifying combining kinds in contractions and reductions.
@@ -596,16 +596,16 @@ struct ElideUnitDimsInMultiDimReduction
VectorType newMaskType =
VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
dstVecType.getScalableDims());
- mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+ mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
}
- cast = rewriter.create<vector::ShapeCastOp>(
- loc, reductionOp.getDestType(), reductionOp.getSource());
+ cast = vector::ShapeCastOp::create(
+ rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
if (mask)
- mask = rewriter.create<vector::ExtractOp>(loc, mask);
- cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
+ mask = vector::ExtractOp::create(rewriter, loc, mask);
+ cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
}
Value result =
@@ -672,36 +672,36 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
switch (op) {
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::addi:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::ADD, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::ADD, vector);
case arith::AtomicRMWKind::mulf:
case arith::AtomicRMWKind::muli:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MUL, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MUL, vector);
case arith::AtomicRMWKind::minimumf:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MINIMUMF, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MINIMUMF, vector);
case arith::AtomicRMWKind::mins:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MINSI, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MINSI, vector);
case arith::AtomicRMWKind::minu:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MINUI, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MINUI, vector);
case arith::AtomicRMWKind::maximumf:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MAXIMUMF, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MAXIMUMF, vector);
case arith::AtomicRMWKind::maxs:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MAXSI, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MAXSI, vector);
case arith::AtomicRMWKind::maxu:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MAXUI, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MAXUI, vector);
case arith::AtomicRMWKind::andi:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::AND, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::AND, vector);
case arith::AtomicRMWKind::ori:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::OR, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::OR, vector);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
@@ -740,8 +740,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
Location loc = reductionOp.getLoc();
if (mask)
- mask = rewriter.create<ExtractOp>(loc, mask);
- Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
+ mask = ExtractOp::create(rewriter, loc, mask);
+ Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
@@ -4172,9 +4172,9 @@ class StridedSliceCreateMaskFolder final
// greater than the vector dim size.
IntegerAttr offsetAttr =
rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
- Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+ Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
Value sliceMaskDimSize =
- rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+ arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
sliceMaskDimSizes.push_back(sliceMaskDimSize);
}
// Add unchanged dimensions.
@@ -4289,8 +4289,8 @@ class StridedSliceBroadcast final
sizes[i] = 1;
}
}
- source = rewriter.create<ExtractStridedSliceOp>(
- op->getLoc(), source, offsets, sizes,
+ source = ExtractStridedSliceOp::create(
+ rewriter, op->getLoc(), source, offsets, sizes,
getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
}
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
@@ -4382,8 +4382,8 @@ class ContiguousExtractStridedSliceToExtract final
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
- Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
- extractOffsets);
+ Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
+ extractOffsets);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
return success();
}
@@ -4413,7 +4413,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
- padding = builder.create<ub::PoisonOp>(result.location, elemType);
+ padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
*padding, /*mask=*/Value(), inBoundsAttr);
}
@@ -4431,7 +4431,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
SmallVector<bool>(vectorType.getRank(), false));
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
- padding = builder.create<ub::PoisonOp>(result.location, elemType);
+ padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, *padding,
permutationMapAttr, inBoundsAttr);
}
@@ -4450,7 +4450,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
SmallVector<bool>(vectorType.getRank(), false));
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
- padding = builder.create<ub::PoisonOp>(result.location, elemType);
+ padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
*padding,
/*mask=*/Value(), inBoundsAttr);
@@ -4975,7 +4975,7 @@ struct TransferReadAfterWriteToBroadcast
VectorType broadcastedType = VectorType::get(
broadcastShape, defWrite.getVectorType().getElementType(),
broadcastScalableFlags);
- vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
+ vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
transposePerm);
@@ -5453,13 +5453,14 @@ struct SwapExtractSliceOfTransferWrite
// Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
// Set all in_bounds to false and let the folder infer them.
SmallVector<bool> newInBounds(vectorShape.size(), false);
- auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
- extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
- insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
- insertOp.getMixedStrides());
- auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
- transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
- transferOp.getIndices(), transferOp.getPermutationMapAttr(),
+ auto newExtractOp = tensor::ExtractSliceOp::create(
+ rewriter, extractOp.getLoc(), insertOp.getSourceType(),
+ insertOp.getDest(), insertOp.getMixedOffsets(),
+ insertOp.getMixedSizes(), insertOp.getMixedStrides());
+ auto newTransferWriteOp = TransferWriteOp::create(
+ rewriter, transferOp.getLoc(), transferOp.getVector(),
+ newExtractOp.getResult(), transferOp.getIndices(),
+ transferOp.getPermutationMapAttr(),
rewriter.getBoolArrayAttr(newInBounds));
rewriter.modifyOpInPlace(insertOp, [&]() {
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
@@ -6983,7 +6984,7 @@ void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
opBuilder.setInsertionPointToEnd(&block);
- opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
+ vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
@@ -7318,7 +7319,7 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder,
// Create a block and move the op to that block.
insBlock->getOperations().splice(
insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
- builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
+ YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}
/// Creates a vector.mask operation around a maskable operation. Returns the
@@ -7330,12 +7331,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder,
if (!mask)
return maskableOp;
if (passthru)
- return builder.create<MaskOp>(maskableOp->getLoc(),
- maskableOp->getResultTypes(), mask, passthru,
- maskableOp, createMaskOpRegion);
- return builder.create<MaskOp>(maskableOp->getLoc(),
- maskableOp->getResultTypes(), mask, maskableOp,
- createMaskOpRegion);
+ return MaskOp::create(builder, maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, passthru,
+ maskableOp, createMaskOpRegion);
+ return MaskOp::create(builder, maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, maskableOp,
+ createMaskOpRegion);
}
/// Creates a vector select operation that picks values from `newValue` or
@@ -7350,8 +7351,8 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
if (!mask)
return newValue;
- return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
- mask, newValue, passthru);
+ return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
+ mask, newValue, passthru);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 9da051150e409..66196194b0585 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -116,8 +116,8 @@ struct TransferWriteOpInterface
getBuffer(rewriter, writeOp.getBase(), options, state);
if (failed(resultBuffer))
return failure();
- rewriter.create<vector::TransferWriteOp>(
- writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
+ vector::TransferWriteOp::create(
+ rewriter, writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
writeOp.getIndices(), writeOp.getPermutationMapAttr(),
writeOp.getMask(), writeOp.getInBoundsAttr());
replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
@@ -241,8 +241,9 @@ struct MaskOpInterface
// Create a new vector.mask op.
ValueRange newYieldedValuesRange(newYieldedValues);
TypeRange newResultTypes(newYieldedValuesRange);
- auto newOp = rewriter.create<vector::MaskOp>(
- op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
+ auto newOp = vector::MaskOp::create(
+ rewriter, op->getLoc(), newResultTypes, maskOp.getMask(),
+ maskOp.getPassthru(),
/*maskableOp=*/nullptr,
/*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
newOp.getRegion().takeBody(maskOp.getMaskRegion());
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
index 89930a6bd35fa..4c3a04cfb5bfa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -64,14 +64,14 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
VectorType::get(shape, resultType.getElementType(), scalableDims);
Location loc = op.getLoc();
- Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+ Value result = ub::PoisonOp::create(rewriter, loc, resultType);
for (auto position : *unrollIterator) {
Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+ vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
Value bitcast =
- rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
+ vector::BitCastOp::create(rewriter, loc, bitcastResType, extract);
result =
- rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
+ vector::InsertOp::create(rewriter, loc, bitcast, result, position);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 11dcfe421e0c4..cb8e566869cfd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -52,7 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
return success();
}
@@ -70,10 +70,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// Duplication.
VectorType resType = VectorType::Builder(dstType).dropDim(0);
Value bcst =
- rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
- Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
+ vector::BroadcastOp::create(rewriter, loc, resType, op.getSource());
+ Value result = ub::PoisonOp::create(rewriter, loc, dstType);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
rewriter.replaceOp(op, result);
return success();
}
@@ -111,13 +111,13 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType,
dstType.getScalableDims().drop_front());
- Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
+ Value result = ub::PoisonOp::create(rewriter, loc, dstType);
if (m == 0) {
// Stetch at start.
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
- Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0);
+ Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
} else {
// Stetch not at start.
if (dstType.getScalableDims()[0]) {
@@ -125,9 +125,9 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
return failure();
}
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
- Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
+ Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
+ result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
}
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index fc6c90f5132c7..65702ffa152d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -81,17 +81,17 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
// At extraction dimension?
if (index == 0)
- return rewriter.create<vector::ExtractOp>(loc, val, pos);
+ return vector::ExtractOp::create(rewriter, loc, val, pos);
// Unroll leading dimensions.
VectorType vType = VectorType::Builder(type).dropDim(0);
VectorType resType = VectorType::Builder(type).dropDim(index);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
+ Value result = arith::ConstantOp::create(rewriter, loc, resType,
+ ...
[truncated]
Groverkss approved these changes on Jul 22, 2025