Skip to content

[mlir][NFC] update mlir/Dialect create APIs (24/n) #149931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 56 additions & 55 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ static MaskFormat getMaskFormat(Value mask) {
/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<vector::YieldOp>(loc);
vector::YieldOp::create(builder, loc);
}

// Helper for verifying combining kinds in contractions and reductions.
Expand Down Expand Up @@ -596,16 +596,16 @@ struct ElideUnitDimsInMultiDimReduction
VectorType newMaskType =
VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
dstVecType.getScalableDims());
mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
}
cast = rewriter.create<vector::ShapeCastOp>(
loc, reductionOp.getDestType(), reductionOp.getSource());
cast = vector::ShapeCastOp::create(
rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
if (mask)
mask = rewriter.create<vector::ExtractOp>(loc, mask);
cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
mask = vector::ExtractOp::create(rewriter, loc, mask);
cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
}

Value result =
Expand Down Expand Up @@ -672,36 +672,36 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
switch (op) {
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::addi:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::ADD, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::ADD, vector);
case arith::AtomicRMWKind::mulf:
case arith::AtomicRMWKind::muli:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MUL, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MUL, vector);
case arith::AtomicRMWKind::minimumf:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MINIMUMF, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MINIMUMF, vector);
case arith::AtomicRMWKind::mins:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MINSI, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MINSI, vector);
case arith::AtomicRMWKind::minu:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MINUI, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MINUI, vector);
case arith::AtomicRMWKind::maximumf:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MAXIMUMF, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MAXIMUMF, vector);
case arith::AtomicRMWKind::maxs:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MAXSI, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MAXSI, vector);
case arith::AtomicRMWKind::maxu:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MAXUI, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MAXUI, vector);
case arith::AtomicRMWKind::andi:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::AND, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::AND, vector);
case arith::AtomicRMWKind::ori:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::OR, vector);
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::OR, vector);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
Expand Down Expand Up @@ -740,8 +740,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {

Location loc = reductionOp.getLoc();
if (mask)
mask = rewriter.create<ExtractOp>(loc, mask);
Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
mask = ExtractOp::create(rewriter, loc, mask);
Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
Expand Down Expand Up @@ -4172,9 +4172,9 @@ class StridedSliceCreateMaskFolder final
// greater than the vector dim size.
IntegerAttr offsetAttr =
rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
Value sliceMaskDimSize =
rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
sliceMaskDimSizes.push_back(sliceMaskDimSize);
}
// Add unchanged dimensions.
Expand Down Expand Up @@ -4289,8 +4289,8 @@ class StridedSliceBroadcast final
sizes[i] = 1;
}
}
source = rewriter.create<ExtractStridedSliceOp>(
op->getLoc(), source, offsets, sizes,
source = ExtractStridedSliceOp::create(
rewriter, op->getLoc(), source, offsets, sizes,
getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
}
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
Expand Down Expand Up @@ -4382,8 +4382,8 @@ class ContiguousExtractStridedSliceToExtract final

SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
extractOffsets);
Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
extractOffsets);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
return success();
}
Expand Down Expand Up @@ -4413,7 +4413,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,

Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
*padding, /*mask=*/Value(), inBoundsAttr);
}
Expand All @@ -4431,7 +4431,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
SmallVector<bool>(vectorType.getRank(), false));
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, *padding,
permutationMapAttr, inBoundsAttr);
}
Expand All @@ -4450,7 +4450,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
SmallVector<bool>(vectorType.getRank(), false));
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
*padding,
/*mask=*/Value(), inBoundsAttr);
Expand Down Expand Up @@ -4975,7 +4975,7 @@ struct TransferReadAfterWriteToBroadcast
VectorType broadcastedType = VectorType::get(
broadcastShape, defWrite.getVectorType().getElementType(),
broadcastScalableFlags);
vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
transposePerm);
Expand Down Expand Up @@ -5453,13 +5453,14 @@ struct SwapExtractSliceOfTransferWrite
// Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
// Set all in_bounds to false and let the folder infer them.
SmallVector<bool> newInBounds(vectorShape.size(), false);
auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
insertOp.getMixedStrides());
auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
transferOp.getIndices(), transferOp.getPermutationMapAttr(),
auto newExtractOp = tensor::ExtractSliceOp::create(
rewriter, extractOp.getLoc(), insertOp.getSourceType(),
insertOp.getDest(), insertOp.getMixedOffsets(),
insertOp.getMixedSizes(), insertOp.getMixedStrides());
auto newTransferWriteOp = TransferWriteOp::create(
rewriter, transferOp.getLoc(), transferOp.getVector(),
newExtractOp.getResult(), transferOp.getIndices(),
transferOp.getPermutationMapAttr(),
rewriter.getBoolArrayAttr(newInBounds));
rewriter.modifyOpInPlace(insertOp, [&]() {
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
Expand Down Expand Up @@ -6983,7 +6984,7 @@ void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
opBuilder.setInsertionPointToEnd(&block);
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}

LogicalResult MaskOp::verify() {
Expand Down Expand Up @@ -7318,7 +7319,7 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder,
// Create a block and move the op to that block.
insBlock->getOperations().splice(
insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}

/// Creates a vector.mask operation around a maskable operation. Returns the
Expand All @@ -7330,12 +7331,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder,
if (!mask)
return maskableOp;
if (passthru)
return builder.create<MaskOp>(maskableOp->getLoc(),
maskableOp->getResultTypes(), mask, passthru,
maskableOp, createMaskOpRegion);
return builder.create<MaskOp>(maskableOp->getLoc(),
maskableOp->getResultTypes(), mask, maskableOp,
createMaskOpRegion);
return MaskOp::create(builder, maskableOp->getLoc(),
maskableOp->getResultTypes(), mask, passthru,
maskableOp, createMaskOpRegion);
return MaskOp::create(builder, maskableOp->getLoc(),
maskableOp->getResultTypes(), mask, maskableOp,
createMaskOpRegion);
}

/// Creates a vector select operation that picks values from `newValue` or
Expand All @@ -7350,8 +7351,8 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
if (!mask)
return newValue;

return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
mask, newValue, passthru);
return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
mask, newValue, passthru);
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ struct TransferWriteOpInterface
getBuffer(rewriter, writeOp.getBase(), options, state);
if (failed(resultBuffer))
return failure();
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
vector::TransferWriteOp::create(
rewriter, writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
writeOp.getIndices(), writeOp.getPermutationMapAttr(),
writeOp.getMask(), writeOp.getInBoundsAttr());
replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
Expand Down Expand Up @@ -241,8 +241,9 @@ struct MaskOpInterface
// Create a new vector.mask op.
ValueRange newYieldedValuesRange(newYieldedValues);
TypeRange newResultTypes(newYieldedValuesRange);
auto newOp = rewriter.create<vector::MaskOp>(
op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
auto newOp = vector::MaskOp::create(
rewriter, op->getLoc(), newResultTypes, maskOp.getMask(),
maskOp.getPassthru(),
/*maskableOp=*/nullptr,
/*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
newOp.getRegion().takeBody(maskOp.getMaskRegion());
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
VectorType::get(shape, resultType.getElementType(), scalableDims);

Location loc = op.getLoc();
Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
Value result = ub::PoisonOp::create(rewriter, loc, resultType);
for (auto position : *unrollIterator) {
Value extract =
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
Value bitcast =
rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
vector::BitCastOp::create(rewriter, loc, bitcastResType, extract);
result =
rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
vector::InsertOp::create(rewriter, loc, bitcast, result, position);
}

rewriter.replaceOp(op, result);
Expand Down
22 changes: 11 additions & 11 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {

// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
return success();
}
Expand All @@ -70,10 +70,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// Duplication.
VectorType resType = VectorType::Builder(dstType).dropDim(0);
Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
vector::BroadcastOp::create(rewriter, loc, resType, op.getSource());
Value result = ub::PoisonOp::create(rewriter, loc, dstType);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
rewriter.replaceOp(op, result);
return success();
}
Expand Down Expand Up @@ -111,23 +111,23 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType,
dstType.getScalableDims().drop_front());
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
Value result = ub::PoisonOp::create(rewriter, loc, dstType);
if (m == 0) {
// Stetch at start.
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0);
Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
} else {
// Stetch not at start.
if (dstType.getScalableDims()[0]) {
// TODO: For scalable vectors we should emit an scf.for loop.
return failure();
}
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
}
}
rewriter.replaceOp(op, result);
Expand Down
Loading
Loading