Skip to content

Commit f904cdd

Browse files
authored
[mlir][NFC] update mlir/Dialect create APIs (24/n) (#149931)
See #147168 for more info.
1 parent 972ac59 commit f904cdd

26 files changed

+825
-776
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 56 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ static MaskFormat getMaskFormat(Value mask) {
124124
/// Default callback to build a region with a 'vector.yield' terminator with no
125125
/// arguments.
126126
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
127-
builder.create<vector::YieldOp>(loc);
127+
vector::YieldOp::create(builder, loc);
128128
}
129129

130130
// Helper for verifying combining kinds in contractions and reductions.
@@ -596,16 +596,16 @@ struct ElideUnitDimsInMultiDimReduction
596596
VectorType newMaskType =
597597
VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
598598
dstVecType.getScalableDims());
599-
mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
599+
mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
600600
}
601-
cast = rewriter.create<vector::ShapeCastOp>(
602-
loc, reductionOp.getDestType(), reductionOp.getSource());
601+
cast = vector::ShapeCastOp::create(
602+
rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
603603
} else {
604604
// This means we are reducing all the dimensions, and all reduction
605605
// dimensions are of size 1. So a simple extraction would do.
606606
if (mask)
607-
mask = rewriter.create<vector::ExtractOp>(loc, mask);
608-
cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
607+
mask = vector::ExtractOp::create(rewriter, loc, mask);
608+
cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
609609
}
610610

611611
Value result =
@@ -672,36 +672,36 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
672672
switch (op) {
673673
case arith::AtomicRMWKind::addf:
674674
case arith::AtomicRMWKind::addi:
675-
return builder.create<vector::ReductionOp>(vector.getLoc(),
676-
CombiningKind::ADD, vector);
675+
return vector::ReductionOp::create(builder, vector.getLoc(),
676+
CombiningKind::ADD, vector);
677677
case arith::AtomicRMWKind::mulf:
678678
case arith::AtomicRMWKind::muli:
679-
return builder.create<vector::ReductionOp>(vector.getLoc(),
680-
CombiningKind::MUL, vector);
679+
return vector::ReductionOp::create(builder, vector.getLoc(),
680+
CombiningKind::MUL, vector);
681681
case arith::AtomicRMWKind::minimumf:
682-
return builder.create<vector::ReductionOp>(vector.getLoc(),
683-
CombiningKind::MINIMUMF, vector);
682+
return vector::ReductionOp::create(builder, vector.getLoc(),
683+
CombiningKind::MINIMUMF, vector);
684684
case arith::AtomicRMWKind::mins:
685-
return builder.create<vector::ReductionOp>(vector.getLoc(),
686-
CombiningKind::MINSI, vector);
685+
return vector::ReductionOp::create(builder, vector.getLoc(),
686+
CombiningKind::MINSI, vector);
687687
case arith::AtomicRMWKind::minu:
688-
return builder.create<vector::ReductionOp>(vector.getLoc(),
689-
CombiningKind::MINUI, vector);
688+
return vector::ReductionOp::create(builder, vector.getLoc(),
689+
CombiningKind::MINUI, vector);
690690
case arith::AtomicRMWKind::maximumf:
691-
return builder.create<vector::ReductionOp>(vector.getLoc(),
692-
CombiningKind::MAXIMUMF, vector);
691+
return vector::ReductionOp::create(builder, vector.getLoc(),
692+
CombiningKind::MAXIMUMF, vector);
693693
case arith::AtomicRMWKind::maxs:
694-
return builder.create<vector::ReductionOp>(vector.getLoc(),
695-
CombiningKind::MAXSI, vector);
694+
return vector::ReductionOp::create(builder, vector.getLoc(),
695+
CombiningKind::MAXSI, vector);
696696
case arith::AtomicRMWKind::maxu:
697-
return builder.create<vector::ReductionOp>(vector.getLoc(),
698-
CombiningKind::MAXUI, vector);
697+
return vector::ReductionOp::create(builder, vector.getLoc(),
698+
CombiningKind::MAXUI, vector);
699699
case arith::AtomicRMWKind::andi:
700-
return builder.create<vector::ReductionOp>(vector.getLoc(),
701-
CombiningKind::AND, vector);
700+
return vector::ReductionOp::create(builder, vector.getLoc(),
701+
CombiningKind::AND, vector);
702702
case arith::AtomicRMWKind::ori:
703-
return builder.create<vector::ReductionOp>(vector.getLoc(),
704-
CombiningKind::OR, vector);
703+
return vector::ReductionOp::create(builder, vector.getLoc(),
704+
CombiningKind::OR, vector);
705705
// TODO: Add remaining reduction operations.
706706
default:
707707
(void)emitOptionalError(loc, "Reduction operation type not supported");
@@ -740,8 +740,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
740740

741741
Location loc = reductionOp.getLoc();
742742
if (mask)
743-
mask = rewriter.create<ExtractOp>(loc, mask);
744-
Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
743+
mask = ExtractOp::create(rewriter, loc, mask);
744+
Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
745745

746746
if (Value acc = reductionOp.getAcc())
747747
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
@@ -4172,9 +4172,9 @@ class StridedSliceCreateMaskFolder final
41724172
// greater than the vector dim size.
41734173
IntegerAttr offsetAttr =
41744174
rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4175-
Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
4175+
Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
41764176
Value sliceMaskDimSize =
4177-
rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
4177+
arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
41784178
sliceMaskDimSizes.push_back(sliceMaskDimSize);
41794179
}
41804180
// Add unchanged dimensions.
@@ -4289,8 +4289,8 @@ class StridedSliceBroadcast final
42894289
sizes[i] = 1;
42904290
}
42914291
}
4292-
source = rewriter.create<ExtractStridedSliceOp>(
4293-
op->getLoc(), source, offsets, sizes,
4292+
source = ExtractStridedSliceOp::create(
4293+
rewriter, op->getLoc(), source, offsets, sizes,
42944294
getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
42954295
}
42964296
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
@@ -4382,8 +4382,8 @@ class ContiguousExtractStridedSliceToExtract final
43824382

43834383
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
43844384
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4385-
Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
4386-
extractOffsets);
4385+
Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4386+
extractOffsets);
43874387
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
43884388
return success();
43894389
}
@@ -4413,7 +4413,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
44134413

44144414
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
44154415
if (!padding)
4416-
padding = builder.create<ub::PoisonOp>(result.location, elemType);
4416+
padding = ub::PoisonOp::create(builder, result.location, elemType);
44174417
build(builder, result, vectorType, source, indices, permutationMapAttr,
44184418
*padding, /*mask=*/Value(), inBoundsAttr);
44194419
}
@@ -4431,7 +4431,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
44314431
SmallVector<bool>(vectorType.getRank(), false));
44324432
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
44334433
if (!padding)
4434-
padding = builder.create<ub::PoisonOp>(result.location, elemType);
4434+
padding = ub::PoisonOp::create(builder, result.location, elemType);
44354435
build(builder, result, vectorType, source, indices, *padding,
44364436
permutationMapAttr, inBoundsAttr);
44374437
}
@@ -4450,7 +4450,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
44504450
SmallVector<bool>(vectorType.getRank(), false));
44514451
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
44524452
if (!padding)
4453-
padding = builder.create<ub::PoisonOp>(result.location, elemType);
4453+
padding = ub::PoisonOp::create(builder, result.location, elemType);
44544454
build(builder, result, vectorType, source, indices, permutationMapAttr,
44554455
*padding,
44564456
/*mask=*/Value(), inBoundsAttr);
@@ -4975,7 +4975,7 @@ struct TransferReadAfterWriteToBroadcast
49754975
VectorType broadcastedType = VectorType::get(
49764976
broadcastShape, defWrite.getVectorType().getElementType(),
49774977
broadcastScalableFlags);
4978-
vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4978+
vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
49794979
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
49804980
rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
49814981
transposePerm);
@@ -5453,13 +5453,14 @@ struct SwapExtractSliceOfTransferWrite
54535453
// Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
54545454
// Set all in_bounds to false and let the folder infer them.
54555455
SmallVector<bool> newInBounds(vectorShape.size(), false);
5456-
auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
5457-
extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
5458-
insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
5459-
insertOp.getMixedStrides());
5460-
auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
5461-
transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5462-
transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5456+
auto newExtractOp = tensor::ExtractSliceOp::create(
5457+
rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5458+
insertOp.getDest(), insertOp.getMixedOffsets(),
5459+
insertOp.getMixedSizes(), insertOp.getMixedStrides());
5460+
auto newTransferWriteOp = TransferWriteOp::create(
5461+
rewriter, transferOp.getLoc(), transferOp.getVector(),
5462+
newExtractOp.getResult(), transferOp.getIndices(),
5463+
transferOp.getPermutationMapAttr(),
54635464
rewriter.getBoolArrayAttr(newInBounds));
54645465
rewriter.modifyOpInPlace(insertOp, [&]() {
54655466
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
@@ -6983,7 +6984,7 @@ void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
69836984
OpBuilder opBuilder(builder.getContext());
69846985
Operation *maskedOp = &block.front();
69856986
opBuilder.setInsertionPointToEnd(&block);
6986-
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6987+
vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
69876988
}
69886989

69896990
LogicalResult MaskOp::verify() {
@@ -7318,7 +7319,7 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder,
73187319
// Create a block and move the op to that block.
73197320
insBlock->getOperations().splice(
73207321
insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7321-
builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
7322+
YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
73227323
}
73237324

73247325
/// Creates a vector.mask operation around a maskable operation. Returns the
@@ -7330,12 +7331,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder,
73307331
if (!mask)
73317332
return maskableOp;
73327333
if (passthru)
7333-
return builder.create<MaskOp>(maskableOp->getLoc(),
7334-
maskableOp->getResultTypes(), mask, passthru,
7335-
maskableOp, createMaskOpRegion);
7336-
return builder.create<MaskOp>(maskableOp->getLoc(),
7337-
maskableOp->getResultTypes(), mask, maskableOp,
7338-
createMaskOpRegion);
7334+
return MaskOp::create(builder, maskableOp->getLoc(),
7335+
maskableOp->getResultTypes(), mask, passthru,
7336+
maskableOp, createMaskOpRegion);
7337+
return MaskOp::create(builder, maskableOp->getLoc(),
7338+
maskableOp->getResultTypes(), mask, maskableOp,
7339+
createMaskOpRegion);
73397340
}
73407341

73417342
/// Creates a vector select operation that picks values from `newValue` or
@@ -7350,8 +7351,8 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
73507351
if (!mask)
73517352
return newValue;
73527353

7353-
return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
7354-
mask, newValue, passthru);
7354+
return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
7355+
mask, newValue, passthru);
73557356
}
73567357

73577358
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ struct TransferWriteOpInterface
116116
getBuffer(rewriter, writeOp.getBase(), options, state);
117117
if (failed(resultBuffer))
118118
return failure();
119-
rewriter.create<vector::TransferWriteOp>(
120-
writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
119+
vector::TransferWriteOp::create(
120+
rewriter, writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
121121
writeOp.getIndices(), writeOp.getPermutationMapAttr(),
122122
writeOp.getMask(), writeOp.getInBoundsAttr());
123123
replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
@@ -241,8 +241,9 @@ struct MaskOpInterface
241241
// Create a new vector.mask op.
242242
ValueRange newYieldedValuesRange(newYieldedValues);
243243
TypeRange newResultTypes(newYieldedValuesRange);
244-
auto newOp = rewriter.create<vector::MaskOp>(
245-
op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
244+
auto newOp = vector::MaskOp::create(
245+
rewriter, op->getLoc(), newResultTypes, maskOp.getMask(),
246+
maskOp.getPassthru(),
246247
/*maskableOp=*/nullptr,
247248
/*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
248249
newOp.getRegion().takeBody(maskOp.getMaskRegion());

mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
6464
VectorType::get(shape, resultType.getElementType(), scalableDims);
6565

6666
Location loc = op.getLoc();
67-
Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
67+
Value result = ub::PoisonOp::create(rewriter, loc, resultType);
6868
for (auto position : *unrollIterator) {
6969
Value extract =
70-
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
70+
vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
7171
Value bitcast =
72-
rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
72+
vector::BitCastOp::create(rewriter, loc, bitcastResType, extract);
7373
result =
74-
rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
74+
vector::InsertOp::create(rewriter, loc, bitcast, result, position);
7575
}
7676

7777
rewriter.replaceOp(op, result);

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
5252

5353
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
5454
if (srcRank <= 1 && dstRank == 1) {
55-
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
55+
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
5656
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
5757
return success();
5858
}
@@ -70,10 +70,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
7070
// Duplication.
7171
VectorType resType = VectorType::Builder(dstType).dropDim(0);
7272
Value bcst =
73-
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
74-
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
73+
vector::BroadcastOp::create(rewriter, loc, resType, op.getSource());
74+
Value result = ub::PoisonOp::create(rewriter, loc, dstType);
7575
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
76-
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
76+
result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
7777
rewriter.replaceOp(op, result);
7878
return success();
7979
}
@@ -111,23 +111,23 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
111111
VectorType resType =
112112
VectorType::get(dstType.getShape().drop_front(), eltType,
113113
dstType.getScalableDims().drop_front());
114-
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
114+
Value result = ub::PoisonOp::create(rewriter, loc, dstType);
115115
if (m == 0) {
116116
// Stetch at start.
117-
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
118-
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
117+
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0);
118+
Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
119119
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
120-
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
120+
result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
121121
} else {
122122
// Stetch not at start.
123123
if (dstType.getScalableDims()[0]) {
124124
// TODO: For scalable vectors we should emit an scf.for loop.
125125
return failure();
126126
}
127127
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
128-
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
129-
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
130-
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
128+
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
129+
Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
130+
result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
131131
}
132132
}
133133
rewriter.replaceOp(op, result);

0 commit comments

Comments
 (0)