Skip to content

Commit b0312be

Browse files
authored
[mlir][NFC] update mlir/Dialect create APIs (19/n) (#149926)
See #147168 for more info.
1 parent dc87a14 commit b0312be

File tree

17 files changed

+641
-617
lines changed

17 files changed

+641
-617
lines changed

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
3636
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
3737
if (!type.isSignlessInteger(1))
3838
return nullptr;
39-
return b.create<BoolConstantOp>(loc, type, boolValue);
39+
return BoolConstantOp::create(b, loc, type, boolValue);
4040
}
4141

4242
// Materialize integer attributes as `index`.
@@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
4646
return nullptr;
4747
assert(indexValue.getValue().getBitWidth() ==
4848
IndexType::kInternalStorageBitWidth);
49-
return b.create<ConstantOp>(loc, indexValue);
49+
return ConstantOp::create(b, loc, indexValue);
5050
}
5151

5252
return nullptr;
@@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
715715

716716
index::CmpOp newCmp;
717717
if (rhsIsZero)
718-
newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
719-
subOp.getLhs(), subOp.getRhs());
718+
newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
719+
subOp.getLhs(), subOp.getRhs());
720720
else
721-
newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
722-
subOp.getRhs(), subOp.getLhs());
721+
newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
722+
subOp.getRhs(), subOp.getLhs());
723723
rewriter.replaceOp(op, newCmp);
724724
return success();
725725
}

mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,9 @@ struct GlobalStoreOpInterface
135135
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
136136

137137
auto loc = globalStoreOp.getLoc();
138-
auto targetMemref = rewriter.create<memref::GetGlobalOp>(
139-
loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
138+
auto targetMemref = memref::GetGlobalOp::create(
139+
rewriter, loc, memrefType,
140+
globalStoreOp.getGlobalAttr().getLeafReference());
140141

141142
auto sourceMemref =
142143
getBuffer(rewriter, globalStoreOp.getValue(), options, state);

mlir/lib/Dialect/MPI/IR/MPIOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
6060
if (!isa<IntegerAttr>(dltiAttr.value()))
6161
return op->emitError()
6262
<< "Expected an integer attribute for MPI:comm_world_rank";
63-
Value res = b.create<arith::ConstantIndexOp>(
64-
op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
63+
Value res = arith::ConstantIndexOp::create(
64+
b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
6565
if (Value retVal = op.getRetval())
6666
b.replaceOp(op, {retVal, res});
6767
else

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
746746
Attribute value, Type type,
747747
Location loc) {
748748
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
749-
return builder.create<ub::PoisonOp>(loc, type, poison);
749+
return ub::PoisonOp::create(builder, loc, type, poison);
750750

751751
return arith::ConstantOp::materialize(builder, value, type, loc);
752752
}

mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
6565
// Maybe broadcasts scalar value into vector type compatible with `op`.
6666
auto bcast = [&](Value value) -> Value {
6767
if (auto vec = dyn_cast<VectorType>(op.getType()))
68-
return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
68+
return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
6969
return value;
7070
};
7171

@@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
8484
// Replace `pow(x, 3.0)` with `x * x * x`.
8585
if (isExponentValue(3.0)) {
8686
Value square =
87-
rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
87+
arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
8888
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
8989
return success();
9090
}
9191

9292
// Replace `pow(x, -1.0)` with `1.0 / x`.
9393
if (isExponentValue(-1.0)) {
94-
Value one = rewriter.create<arith::ConstantOp>(
95-
loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
94+
Value one = arith::ConstantOp::create(
95+
rewriter, loc,
96+
rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
9697
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
9798
return success();
9899
}
@@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
111112

112113
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
113114
if (isExponentValue(0.75)) {
114-
Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
115-
Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
115+
Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
116+
Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
116117
rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
117118
ValueRange{powHalf, powQuarter});
118119
return success();
@@ -168,18 +169,18 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
168169
// Maybe broadcasts scalar value into vector type compatible with `op`.
169170
auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
170171
if (auto vec = dyn_cast<VectorType>(op.getType()))
171-
return rewriter.create<vector::BroadcastOp>(loc, vec, value);
172+
return vector::BroadcastOp::create(rewriter, loc, vec, value);
172173
return value;
173174
};
174175

175176
Value one;
176177
Type opType = getElementTypeOrSelf(op.getType());
177178
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
178-
one = rewriter.create<arith::ConstantOp>(
179-
loc, rewriter.getFloatAttr(opType, 1.0));
179+
one = arith::ConstantOp::create(rewriter, loc,
180+
rewriter.getFloatAttr(opType, 1.0));
180181
else
181-
one = rewriter.create<arith::ConstantOp>(
182-
loc, rewriter.getIntegerAttr(opType, 1));
182+
one = arith::ConstantOp::create(rewriter, loc,
183+
rewriter.getIntegerAttr(opType, 1));
183184

184185
// Replace `[fi]powi(x, 0)` with `1`.
185186
if (exponentValue == 0) {
@@ -208,12 +209,12 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
208209
// with:
209210
// (1 / x) * (1 / x) * (1 / x) * ...
210211
for (unsigned i = 1; i < exponentValue; ++i)
211-
result = rewriter.create<MulOpTy>(loc, result, base);
212+
result = MulOpTy::create(rewriter, loc, result, base);
212213

213214
// Inverse the base for negative exponent, i.e. for
214215
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
215216
if (exponentIsNegative)
216-
result = rewriter.create<DivOpTy>(loc, bcast(one), result);
217+
result = DivOpTy::create(rewriter, loc, bcast(one), result);
217218

218219
rewriter.replaceOp(op, result);
219220
return success();

0 commit comments

Comments
 (0)