Skip to content

[mlir][NFC] update mlir/Dialect create APIs (18/n) #149925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
// TODO: support more types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](MemRefType t) {
return builder.create<memref::AllocaOp>(getLoc(), t);
return memref::AllocaOp::create(builder, getLoc(), t);
})
.Default([&](Type t) {
return builder.create<arith::ConstantOp>(getLoc(), t,
builder.getZeroAttr(t));
return arith::ConstantOp::create(builder, getLoc(), t,
builder.getZeroAttr(t));
});
}

Expand Down Expand Up @@ -135,7 +135,7 @@ DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
for (Attribute usedIndex : usedIndices) {
Type elemType = memrefType.getTypeAtIndex(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr);
newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(usedIndex,
{subAlloca.getResult(), elemType});
Expand Down
38 changes: 20 additions & 18 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

// Create and insert the alloc op for the new memref.
auto newAlloc = rewriter.create<AllocLikeOp>(
alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
alloc.getAlignmentAttr());
auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
dynamicSizes, alloc.getSymbolOperands(),
alloc.getAlignmentAttr());
// Insert a cast so we have the same type as the old alloc.
rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
return success();
Expand Down Expand Up @@ -797,7 +797,7 @@ void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
int64_t index) {
auto loc = result.location;
Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
build(builder, result, source, indexValue);
}

Expand Down Expand Up @@ -1044,9 +1044,9 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
rewriter.setInsertionPointAfter(reshape);
Location loc = dim.getLoc();
Value load =
rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
if (load.getType() != dim.getType())
load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
rewriter.replaceOp(dim, load);
return success();
}
Expand Down Expand Up @@ -1319,8 +1319,9 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
assert(isa<Attribute>(maybeConstant) &&
"The constified value should be either unchanged (i.e., == result) "
"or a constant");
Value constantVal = rewriter.create<arith::ConstantIndexOp>(
loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
Value constantVal = arith::ConstantIndexOp::create(
rewriter, loc,
llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
// modifyOpInPlace: lambda cannot capture structured bindings in C++17
// yet.
Expand Down Expand Up @@ -2548,8 +2549,9 @@ struct CollapseShapeOpMemRefCastFolder
rewriter.modifyOpInPlace(
op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
} else {
Value newOp = rewriter.create<CollapseShapeOp>(
op->getLoc(), cast.getSource(), op.getReassociationIndices());
Value newOp =
CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
op.getReassociationIndices());
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
}
return success();
Expand Down Expand Up @@ -3006,15 +3008,15 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
Value offset =
op.isDynamicOffset(idx)
? op.getDynamicOffset(idx)
: b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
: arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
Value size =
op.isDynamicSize(idx)
? op.getDynamicSize(idx)
: b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
: arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
Value stride =
op.isDynamicStride(idx)
? op.getDynamicStride(idx)
: b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
: arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
res.emplace_back(Range{offset, size, stride});
}
return res;
Expand Down Expand Up @@ -3173,8 +3175,8 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
if (!resultType)
return failure();

Value newSubView = rewriter.create<SubViewOp>(
subViewOp.getLoc(), resultType, castOp.getSource(),
Value newSubView = SubViewOp::create(
rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
subViewOp.getStaticStrides());
Expand Down Expand Up @@ -3495,9 +3497,9 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
return failure();

// Create new ViewOp.
auto newViewOp = rewriter.create<ViewOp>(
viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
viewOp.getByteShift(), newOperands);
auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
viewOp.getOperand(0), viewOp.getByteShift(),
newOperands);
// Insert a cast so we have the same type as the old memref type.
rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
return success();
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
Type resultType = alloca.getResult().getType();
OpBuilder builder(rewriter.getContext());
// TODO: Add a better builder for this.
globalOp = builder.create<memref::GlobalOp>(
loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
globalOp = memref::GlobalOp::create(
builder, loc, StringAttr::get(ctx, "alloca"),
StringAttr::get(ctx, "private"), TypeAttr::get(resultType),
Attribute{}, UnitAttr{}, IntegerAttr{});
symbolTable.insert(globalOp);
}

Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ struct DefaultAllocationInterface
DefaultAllocationInterface, memref::AllocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value alloc) {
return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
.getOperation();
}
static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
return bufferization::CloneOp::create(builder, alloc.getLoc(), alloc)
.getResult();
}
static ::mlir::HoistingKind getHoistingKind() {
Expand All @@ -35,8 +35,9 @@ struct DefaultAllocationInterface
static ::std::optional<::mlir::Operation *>
buildPromotedAlloc(OpBuilder &builder, Value alloc) {
Operation *definingOp = alloc.getDefiningOp();
return builder.create<memref::AllocaOp>(
definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
return memref::AllocaOp::create(
builder, definingOp->getLoc(),
cast<MemRefType>(definingOp->getResultTypes()[0]),
definingOp->getOperands(), definingOp->getAttrs());
}
};
Expand All @@ -52,7 +53,7 @@ struct DefaultReallocationInterface
DefaultAllocationInterface, memref::ReallocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value realloc) {
return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
return memref::DeallocOp::create(builder, realloc.getLoc(), realloc)
.getOperation();
}
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
}

AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
Value result = rewriter.create<affine::AffineApplyOp>(
op.getLoc(), map, affineApplyOperands);
Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(), map,
affineApplyOperands);
offsets.push_back(result);
}
}
Expand Down
62 changes: 30 additions & 32 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
}

/// When writing a subbyte size, masked bitwise operations are used to only
Expand All @@ -112,14 +112,14 @@ static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
auto dstIntegerType = builder.getIntegerType(dstBits);
auto maskRightAlignedAttr =
builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
Value maskRightAligned = builder.create<arith::ConstantOp>(
loc, dstIntegerType, maskRightAlignedAttr);
Value maskRightAligned = arith::ConstantOp::create(
builder, loc, dstIntegerType, maskRightAlignedAttr);
Value writeMaskInverse =
builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
Value flipVal =
builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
Expand All @@ -141,7 +141,7 @@ getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
const SmallVector<OpFoldResult> &indices,
Value memref) {
auto stridedMetadata =
builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
memref::ExtractStridedMetadataOp::create(builder, loc, memref);
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
Expand Down Expand Up @@ -298,24 +298,24 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// Special case 0-rank memref loads.
Value bitsLoad;
if (convertedType.getRank() == 0) {
bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
ValueRange{});
bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
ValueRange{});
} else {
// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
OpFoldResult linearizedIndices = getLinearizedSrcIndices(
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

Value newLoad = rewriter.create<memref::LoadOp>(
loc, adaptor.getMemref(),
Value newLoad = memref::LoadOp::create(
rewriter, loc, adaptor.getMemref(),
getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
dstBits));

// Get the offset and shift the bits to the rightmost.
// Note, currently only the big-endian is supported.
Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
srcBits, dstBits, rewriter);
bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
}

// Get the corresponding bits. If the arith computation bitwidth equals
Expand All @@ -331,17 +331,17 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
: IntegerType::get(rewriter.getContext(),
resultTy.getIntOrFloatBitWidth());
if (conversionTy == convertedElementType) {
auto mask = rewriter.create<arith::ConstantOp>(
loc, convertedElementType,
auto mask = arith::ConstantOp::create(
rewriter, loc, convertedElementType,
rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
} else {
result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
}

if (conversionTy != resultTy) {
result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
}

rewriter.replaceOp(op, result);
Expand Down Expand Up @@ -428,20 +428,20 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
// Pad the input value with 0s on the left.
Value input = adaptor.getValue();
if (!input.getType().isInteger()) {
input = rewriter.create<arith::BitcastOp>(
loc,
input = arith::BitcastOp::create(
rewriter, loc,
IntegerType::get(rewriter.getContext(),
input.getType().getIntOrFloatBitWidth()),
input);
}
Value extendedInput =
rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);
arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);

// Special case 0-rank memref stores. No need for masking.
if (convertedType.getRank() == 0) {
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
extendedInput, adaptor.getMemref(),
ValueRange{});
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
extendedInput, adaptor.getMemref(),
ValueRange{});
rewriter.eraseOp(op);
return success();
}
Expand All @@ -456,16 +456,14 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
dstBits, bitwidthOffset, rewriter);
// Align the value to write with the destination bits
Value alignedVal =
rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);

// Clear destination bits
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
writeMask, adaptor.getMemref(),
storeIndices);
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
writeMask, adaptor.getMemref(), storeIndices);
// Write srcs bits to destination
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
alignedVal, adaptor.getMemref(),
storeIndices);
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
alignedVal, adaptor.getMemref(), storeIndices);
rewriter.eraseOp(op);
return success();
}
Expand Down Expand Up @@ -525,8 +523,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}

// Transform the offsets, sizes and strides according to the emulation.
auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
loc, subViewOp.getViewSource());
auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
rewriter, loc, subViewOp.getViewSource());

OpFoldResult linearizedIndices;
auto strides = stridedMetadata.getConstifiedMixedStrides();
Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
Value size;
// Load dynamic sizes from the shape input, use constants for static dims.
if (op.getType().isDynamicDim(i)) {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
Value index = arith::ConstantIndexOp::create(rewriter, loc, i);
size = memref::LoadOp::create(rewriter, loc, op.getShape(), index);
if (!isa<IndexType>(size.getType()))
size = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), size);
size = arith::IndexCastOp::create(rewriter, loc,
rewriter.getIndexType(), size);
sizes[i] = size;
} else {
auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
sizes[i] = sizeAttr;
}
if (stride)
Expand All @@ -66,10 +66,11 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {

if (i > 0) {
if (stride) {
stride = rewriter.create<arith::MulIOp>(loc, stride, size);
stride = arith::MulIOp::create(rewriter, loc, stride, size);
} else if (op.getType().isDynamicDim(i)) {
stride = rewriter.create<arith::MulIOp>(
loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
stride = arith::MulIOp::create(
rewriter, loc,
arith::ConstantIndexOp::create(rewriter, loc, staticStride),
size);
} else {
staticStride *= op.getType().getDimSize(i);
Expand Down
Loading