-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[mlir][NFC] update mlir/Dialect
create APIs (19/n)
#149926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][NFC] update mlir/Dialect
create APIs (19/n)
#149926
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
See llvm#147168 for more info.
36bf788
to
06607e0
Compare
@llvm/pr-subscribers-mlir-index @llvm/pr-subscribers-mlir-quant Author: Maksim Levental (makslevental) ChangesSee #147168 for more info. Patch is 130.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149926.diff 17 Files Affected:
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index bab9e2852a460..a3e1542e6a947 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
if (!type.isSignlessInteger(1))
return nullptr;
- return b.create<BoolConstantOp>(loc, type, boolValue);
+ return BoolConstantOp::create(b, loc, type, boolValue);
}
// Materialize integer attributes as `index`.
@@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
return nullptr;
assert(indexValue.getValue().getBitWidth() ==
IndexType::kInternalStorageBitWidth);
- return b.create<ConstantOp>(loc, indexValue);
+ return ConstantOp::create(b, loc, indexValue);
}
return nullptr;
@@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
index::CmpOp newCmp;
if (rhsIsZero)
- newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
- subOp.getLhs(), subOp.getRhs());
+ newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+ subOp.getLhs(), subOp.getRhs());
else
- newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
- subOp.getRhs(), subOp.getLhs());
+ newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+ subOp.getRhs(), subOp.getLhs());
rewriter.replaceOp(op, newCmp);
return success();
}
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index ff6af63eee531..364e4d385fd62 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -135,8 +135,9 @@ struct GlobalStoreOpInterface
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
auto loc = globalStoreOp.getLoc();
- auto targetMemref = rewriter.create<memref::GetGlobalOp>(
- loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+ auto targetMemref = memref::GetGlobalOp::create(
+ rewriter, loc, memrefType,
+ globalStoreOp.getGlobalAttr().getLeafReference());
auto sourceMemref =
getBuffer(rewriter, globalStoreOp.getValue(), options, state);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 7940ff60a48e7..f52c3f99189d2 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
if (!isa<IntegerAttr>(dltiAttr.value()))
return op->emitError()
<< "Expected an integer attribute for MPI:comm_world_rank";
- Value res = b.create<arith::ConstantIndexOp>(
- op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+ Value res = arith::ConstantIndexOp::create(
+ b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
if (Value retVal = op.getRetval())
b.replaceOp(op, {retVal, res});
else
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 26441a9d78658..a21631cbf8510 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 13e2a4b5541b2..31785eb20a642 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
+ return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
return value;
};
@@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 3.0)` with `x * x * x`.
if (isExponentValue(3.0)) {
Value square =
- rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
+ arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
return success();
}
// Replace `pow(x, -1.0)` with `1.0 / x`.
if (isExponentValue(-1.0)) {
- Value one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+ Value one = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
return success();
}
@@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
if (isExponentValue(0.75)) {
- Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
- Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
+ Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
+ Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
ValueRange{powHalf, powQuarter});
return success();
@@ -168,18 +169,18 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+ return vector::BroadcastOp::create(rewriter, loc, vec, value);
return value;
};
Value one;
Type opType = getElementTypeOrSelf(op.getType());
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
- one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(opType, 1.0));
+ one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getFloatAttr(opType, 1.0));
else
- one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(opType, 1));
+ one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(opType, 1));
// Replace `[fi]powi(x, 0)` with `1`.
if (exponentValue == 0) {
@@ -208,12 +209,12 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// with:
// (1 / x) * (1 / x) * (1 / x) * ...
for (unsigned i = 1; i < exponentValue; ++i)
- result = rewriter.create<MulOpTy>(loc, result, base);
+ result = MulOpTy::create(rewriter, loc, result, base);
// Inverse the base for negative exponent, i.e. for
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
if (exponentIsNegative)
- result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+ result = DivOpTy::create(rewriter, loc, bcast(one), result);
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bccd486def4bf..5edb6e28fb018 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -32,11 +32,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
APFloat::rmNearestTiesToEven, &losesInfo);
auto attr = b.getFloatAttr(eltType, value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return b.create<arith::ConstantOp>(loc,
- DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(b, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return b.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(b, loc, attr);
}
static Value createFloatConst(Location loc, Type type, double value,
@@ -49,11 +49,11 @@ static Value createIntConst(Location loc, Type type, int64_t value,
OpBuilder &b) {
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return b.create<arith::ConstantOp>(loc,
- DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(b, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return b.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(b, loc, attr);
}
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
@@ -61,11 +61,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
Type i64Ty = b.getI64Type();
if (auto shapedTy = dyn_cast<ShapedType>(opType))
i64Ty = shapedTy.clone(i64Ty);
- Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
- Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+ Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
+ Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
// The truncation does not preserve the sign when the truncated
// value is -0. So here the sign is copied again.
- return b.create<math::CopySignOp>(fpFixedConvert, operand);
+ return math::CopySignOp::create(b, fpFixedConvert, operand);
}
// sinhf(float x) -> (exp(x) - exp(-x)) / 2
@@ -74,12 +74,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value neg = b.create<arith::NegFOp>(operand);
- Value nexp = b.create<math::ExpOp>(neg);
- Value sub = b.create<arith::SubFOp>(exp, nexp);
+ Value exp = math::ExpOp::create(b, operand);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value nexp = math::ExpOp::create(b, neg);
+ Value sub = arith::SubFOp::create(b, exp, nexp);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(sub, half);
+ Value res = arith::MulFOp::create(b, sub, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -90,12 +90,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value neg = b.create<arith::NegFOp>(operand);
- Value nexp = b.create<math::ExpOp>(neg);
- Value add = b.create<arith::AddFOp>(exp, nexp);
+ Value exp = math::ExpOp::create(b, operand);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value nexp = math::ExpOp::create(b, neg);
+ Value add = arith::AddFOp::create(b, exp, nexp);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(add, half);
+ Value res = arith::MulFOp::create(b, add, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -116,23 +116,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
- Value isNegative = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+ Value isNegative = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
Value isNegativeFloat =
- rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+ arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
Value isNegativeTimesNegTwo =
- rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
- Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
+ arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
+ Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
// Normalize input to positive value: y = sign(x) * x
- Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
+ Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
// Decompose on normalized input
- Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
- Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
- Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
- Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
- Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+ Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
+ Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
+ Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
+ Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
+ Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
// Multiply result by sign(x) to retain signs from negative inputs
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
@@ -145,9 +145,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type type = operand.getType();
- Value sin = b.create<math::SinOp>(type, operand);
- Value cos = b.create<math::CosOp>(type, operand);
- Value div = b.create<arith::DivFOp>(type, sin, cos);
+ Value sin = math::SinOp::create(b, type, operand);
+ Value cos = math::CosOp::create(b, type, operand);
+ Value div = arith::DivFOp::create(b, type, sin, cos);
rewriter.replaceOp(op, div);
return success();
}
@@ -160,10 +160,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op,
Type opType = operand.getType();
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value fma = b.create<math::FmaOp>(operand, operand, one);
- Value sqrt = b.create<math::SqrtOp>(fma);
- Value add = b.create<arith::AddFOp>(operand, sqrt);
- Value res = b.create<math::LogOp>(add);
+ Value fma = math::FmaOp::create(b, operand, operand, one);
+ Value sqrt = math::SqrtOp::create(b, fma);
+ Value add = arith::AddFOp::create(b, operand, sqrt);
+ Value res = math::LogOp::create(b, add);
rewriter.replaceOp(op, res);
return success();
}
@@ -176,10 +176,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op,
Type opType = operand.getType();
Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
- Value fma = b.create<math::FmaOp>(operand, operand, negOne);
- Value sqrt = b.create<math::SqrtOp>(fma);
- Value add = b.create<arith::AddFOp>(operand, sqrt);
- Value res = b.create<math::LogOp>(add);
+ Value fma = math::FmaOp::create(b, operand, operand, negOne);
+ Value sqrt = math::SqrtOp::create(b, fma);
+ Value add = arith::AddFOp::create(b, operand, sqrt);
+ Value res = math::LogOp::create(b, add);
rewriter.replaceOp(op, res);
return success();
}
@@ -192,13 +192,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op,
Type opType = operand.getType();
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value add = b.create<arith::AddFOp>(operand, one);
- Value neg = b.create<arith::NegFOp>(operand);
- Value sub = b.create<arith::AddFOp>(neg, one);
- Value div = b.create<arith::DivFOp>(add, sub);
- Value log = b.create<math::LogOp>(div);
+ Value add = arith::AddFOp::create(b, operand, one);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value sub = arith::AddFOp::create(b, neg, one);
+ Value div = arith::DivFOp::create(b, add, sub);
+ Value log = math::LogOp::create(b, div);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(log, half);
+ Value res = arith::MulFOp::create(b, log, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -209,8 +209,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Value operandC = op.getOperand(2);
Type type = op.getType();
- Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
- Value add = b.create<arith::AddFOp>(type, mult, operandC);
+ Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+ Value add = arith::AddFOp::create(b, type, mult, operandC);
rewriter.replaceOp(op, add);
return success();
}
@@ -235,11 +235,12 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
- fpFixedConvert);
- Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
+ Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
+ fpFixedConvert);
+ Value incrValue =
+ arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+ Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
rewriter.replaceOp(op, ret);
return success();
}
@@ -257,9 +258,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
auto convertFPowItoPowf = [&]() -> LogicalResult {
Value castPowerToFp =
- rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
- Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
- castPowerToFp);
+ arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
+ Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
+ castPowerToFp);
rewriter.replaceOp(op, res);
return success();
};
@@ -280,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
while (absPower > 0) {
if (absPower & 1)
- res = b.create<arith::MulFOp>(baseType, base, res);
+ res = arith::MulFOp::create(b, baseType, base, res);
absPower >>= 1;
- base = b.create<arith::MulFOp>(baseType, base, base);
+ base = arith::MulFOp::create(b, baseType, base, base);
}
// Make sure not to introduce UB in case of negative power.
@@ -302,14 +303,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
createFloatConst(op->getLoc(), baseType,
APFloat::getInf(sem, /*Negative=*/true), rewriter);
Value zeroEqCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
- res = b.create<arith::DivFOp>(baseType, one, res);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
+ res = arith::DivFOp::create(b, baseType, one, res);
res =
- b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
- res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
- ...
[truncated]
|
@llvm/pr-subscribers-mlir-nvgpu Author: Maksim Levental (makslevental) ChangesSee #147168 for more info. Patch is 130.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149926.diff 17 Files Affected:
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index bab9e2852a460..a3e1542e6a947 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
if (!type.isSignlessInteger(1))
return nullptr;
- return b.create<BoolConstantOp>(loc, type, boolValue);
+ return BoolConstantOp::create(b, loc, type, boolValue);
}
// Materialize integer attributes as `index`.
@@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
return nullptr;
assert(indexValue.getValue().getBitWidth() ==
IndexType::kInternalStorageBitWidth);
- return b.create<ConstantOp>(loc, indexValue);
+ return ConstantOp::create(b, loc, indexValue);
}
return nullptr;
@@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
index::CmpOp newCmp;
if (rhsIsZero)
- newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
- subOp.getLhs(), subOp.getRhs());
+ newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+ subOp.getLhs(), subOp.getRhs());
else
- newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
- subOp.getRhs(), subOp.getLhs());
+ newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+ subOp.getRhs(), subOp.getLhs());
rewriter.replaceOp(op, newCmp);
return success();
}
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index ff6af63eee531..364e4d385fd62 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -135,8 +135,9 @@ struct GlobalStoreOpInterface
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
auto loc = globalStoreOp.getLoc();
- auto targetMemref = rewriter.create<memref::GetGlobalOp>(
- loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+ auto targetMemref = memref::GetGlobalOp::create(
+ rewriter, loc, memrefType,
+ globalStoreOp.getGlobalAttr().getLeafReference());
auto sourceMemref =
getBuffer(rewriter, globalStoreOp.getValue(), options, state);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 7940ff60a48e7..f52c3f99189d2 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
if (!isa<IntegerAttr>(dltiAttr.value()))
return op->emitError()
<< "Expected an integer attribute for MPI:comm_world_rank";
- Value res = b.create<arith::ConstantIndexOp>(
- op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+ Value res = arith::ConstantIndexOp::create(
+ b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
if (Value retVal = op.getRetval())
b.replaceOp(op, {retVal, res});
else
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 26441a9d78658..a21631cbf8510 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 13e2a4b5541b2..31785eb20a642 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
+ return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
return value;
};
@@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 3.0)` with `x * x * x`.
if (isExponentValue(3.0)) {
Value square =
- rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
+ arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
return success();
}
// Replace `pow(x, -1.0)` with `1.0 / x`.
if (isExponentValue(-1.0)) {
- Value one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+ Value one = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
return success();
}
@@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
if (isExponentValue(0.75)) {
- Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
- Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
+ Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
+ Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
ValueRange{powHalf, powQuarter});
return success();
@@ -168,18 +169,18 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+ return vector::BroadcastOp::create(rewriter, loc, vec, value);
return value;
};
Value one;
Type opType = getElementTypeOrSelf(op.getType());
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
- one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(opType, 1.0));
+ one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getFloatAttr(opType, 1.0));
else
- one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(opType, 1));
+ one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(opType, 1));
// Replace `[fi]powi(x, 0)` with `1`.
if (exponentValue == 0) {
@@ -208,12 +209,12 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// with:
// (1 / x) * (1 / x) * (1 / x) * ...
for (unsigned i = 1; i < exponentValue; ++i)
- result = rewriter.create<MulOpTy>(loc, result, base);
+ result = MulOpTy::create(rewriter, loc, result, base);
// Inverse the base for negative exponent, i.e. for
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
if (exponentIsNegative)
- result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+ result = DivOpTy::create(rewriter, loc, bcast(one), result);
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bccd486def4bf..5edb6e28fb018 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -32,11 +32,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
APFloat::rmNearestTiesToEven, &losesInfo);
auto attr = b.getFloatAttr(eltType, value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return b.create<arith::ConstantOp>(loc,
- DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(b, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return b.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(b, loc, attr);
}
static Value createFloatConst(Location loc, Type type, double value,
@@ -49,11 +49,11 @@ static Value createIntConst(Location loc, Type type, int64_t value,
OpBuilder &b) {
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return b.create<arith::ConstantOp>(loc,
- DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(b, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return b.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(b, loc, attr);
}
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
@@ -61,11 +61,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
Type i64Ty = b.getI64Type();
if (auto shapedTy = dyn_cast<ShapedType>(opType))
i64Ty = shapedTy.clone(i64Ty);
- Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
- Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+ Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
+ Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
// The truncation does not preserve the sign when the truncated
// value is -0. So here the sign is copied again.
- return b.create<math::CopySignOp>(fpFixedConvert, operand);
+ return math::CopySignOp::create(b, fpFixedConvert, operand);
}
// sinhf(float x) -> (exp(x) - exp(-x)) / 2
@@ -74,12 +74,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value neg = b.create<arith::NegFOp>(operand);
- Value nexp = b.create<math::ExpOp>(neg);
- Value sub = b.create<arith::SubFOp>(exp, nexp);
+ Value exp = math::ExpOp::create(b, operand);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value nexp = math::ExpOp::create(b, neg);
+ Value sub = arith::SubFOp::create(b, exp, nexp);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(sub, half);
+ Value res = arith::MulFOp::create(b, sub, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -90,12 +90,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value neg = b.create<arith::NegFOp>(operand);
- Value nexp = b.create<math::ExpOp>(neg);
- Value add = b.create<arith::AddFOp>(exp, nexp);
+ Value exp = math::ExpOp::create(b, operand);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value nexp = math::ExpOp::create(b, neg);
+ Value add = arith::AddFOp::create(b, exp, nexp);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(add, half);
+ Value res = arith::MulFOp::create(b, add, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -116,23 +116,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
- Value isNegative = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+ Value isNegative = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
Value isNegativeFloat =
- rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+ arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
Value isNegativeTimesNegTwo =
- rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
- Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
+ arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
+ Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
// Normalize input to positive value: y = sign(x) * x
- Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
+ Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
// Decompose on normalized input
- Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
- Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
- Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
- Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
- Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+ Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
+ Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
+ Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
+ Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
+ Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
// Multiply result by sign(x) to retain signs from negative inputs
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
@@ -145,9 +145,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type type = operand.getType();
- Value sin = b.create<math::SinOp>(type, operand);
- Value cos = b.create<math::CosOp>(type, operand);
- Value div = b.create<arith::DivFOp>(type, sin, cos);
+ Value sin = math::SinOp::create(b, type, operand);
+ Value cos = math::CosOp::create(b, type, operand);
+ Value div = arith::DivFOp::create(b, type, sin, cos);
rewriter.replaceOp(op, div);
return success();
}
@@ -160,10 +160,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op,
Type opType = operand.getType();
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value fma = b.create<math::FmaOp>(operand, operand, one);
- Value sqrt = b.create<math::SqrtOp>(fma);
- Value add = b.create<arith::AddFOp>(operand, sqrt);
- Value res = b.create<math::LogOp>(add);
+ Value fma = math::FmaOp::create(b, operand, operand, one);
+ Value sqrt = math::SqrtOp::create(b, fma);
+ Value add = arith::AddFOp::create(b, operand, sqrt);
+ Value res = math::LogOp::create(b, add);
rewriter.replaceOp(op, res);
return success();
}
@@ -176,10 +176,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op,
Type opType = operand.getType();
Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
- Value fma = b.create<math::FmaOp>(operand, operand, negOne);
- Value sqrt = b.create<math::SqrtOp>(fma);
- Value add = b.create<arith::AddFOp>(operand, sqrt);
- Value res = b.create<math::LogOp>(add);
+ Value fma = math::FmaOp::create(b, operand, operand, negOne);
+ Value sqrt = math::SqrtOp::create(b, fma);
+ Value add = arith::AddFOp::create(b, operand, sqrt);
+ Value res = math::LogOp::create(b, add);
rewriter.replaceOp(op, res);
return success();
}
@@ -192,13 +192,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op,
Type opType = operand.getType();
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value add = b.create<arith::AddFOp>(operand, one);
- Value neg = b.create<arith::NegFOp>(operand);
- Value sub = b.create<arith::AddFOp>(neg, one);
- Value div = b.create<arith::DivFOp>(add, sub);
- Value log = b.create<math::LogOp>(div);
+ Value add = arith::AddFOp::create(b, operand, one);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value sub = arith::AddFOp::create(b, neg, one);
+ Value div = arith::DivFOp::create(b, add, sub);
+ Value log = math::LogOp::create(b, div);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(log, half);
+ Value res = arith::MulFOp::create(b, log, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -209,8 +209,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Value operandC = op.getOperand(2);
Type type = op.getType();
- Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
- Value add = b.create<arith::AddFOp>(type, mult, operandC);
+ Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+ Value add = arith::AddFOp::create(b, type, mult, operandC);
rewriter.replaceOp(op, add);
return success();
}
@@ -235,11 +235,12 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
- fpFixedConvert);
- Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
+ Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
+ fpFixedConvert);
+ Value incrValue =
+ arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+ Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
rewriter.replaceOp(op, ret);
return success();
}
@@ -257,9 +258,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
auto convertFPowItoPowf = [&]() -> LogicalResult {
Value castPowerToFp =
- rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
- Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
- castPowerToFp);
+ arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
+ Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
+ castPowerToFp);
rewriter.replaceOp(op, res);
return success();
};
@@ -280,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
while (absPower > 0) {
if (absPower & 1)
- res = b.create<arith::MulFOp>(baseType, base, res);
+ res = arith::MulFOp::create(b, baseType, base, res);
absPower >>= 1;
- base = b.create<arith::MulFOp>(baseType, base, base);
+ base = arith::MulFOp::create(b, baseType, base, base);
}
// Make sure not to introduce UB in case of negative power.
@@ -302,14 +303,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
createFloatConst(op->getLoc(), baseType,
APFloat::getInf(sem, /*Negative=*/true), rewriter);
Value zeroEqCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
- res = b.create<arith::DivFOp>(baseType, one, res);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
+ res = arith::DivFOp::create(b, baseType, one, res);
res =
- b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
- res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
- ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesSee #147168 for more info. Patch is 130.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149926.diff 17 Files Affected:
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index bab9e2852a460..a3e1542e6a947 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
if (!type.isSignlessInteger(1))
return nullptr;
- return b.create<BoolConstantOp>(loc, type, boolValue);
+ return BoolConstantOp::create(b, loc, type, boolValue);
}
// Materialize integer attributes as `index`.
@@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
return nullptr;
assert(indexValue.getValue().getBitWidth() ==
IndexType::kInternalStorageBitWidth);
- return b.create<ConstantOp>(loc, indexValue);
+ return ConstantOp::create(b, loc, indexValue);
}
return nullptr;
@@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
index::CmpOp newCmp;
if (rhsIsZero)
- newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
- subOp.getLhs(), subOp.getRhs());
+ newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+ subOp.getLhs(), subOp.getRhs());
else
- newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
- subOp.getRhs(), subOp.getLhs());
+ newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+ subOp.getRhs(), subOp.getLhs());
rewriter.replaceOp(op, newCmp);
return success();
}
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index ff6af63eee531..364e4d385fd62 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -135,8 +135,9 @@ struct GlobalStoreOpInterface
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
auto loc = globalStoreOp.getLoc();
- auto targetMemref = rewriter.create<memref::GetGlobalOp>(
- loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+ auto targetMemref = memref::GetGlobalOp::create(
+ rewriter, loc, memrefType,
+ globalStoreOp.getGlobalAttr().getLeafReference());
auto sourceMemref =
getBuffer(rewriter, globalStoreOp.getValue(), options, state);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 7940ff60a48e7..f52c3f99189d2 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
if (!isa<IntegerAttr>(dltiAttr.value()))
return op->emitError()
<< "Expected an integer attribute for MPI:comm_world_rank";
- Value res = b.create<arith::ConstantIndexOp>(
- op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+ Value res = arith::ConstantIndexOp::create(
+ b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
if (Value retVal = op.getRetval())
b.replaceOp(op, {retVal, res});
else
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 26441a9d78658..a21631cbf8510 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 13e2a4b5541b2..31785eb20a642 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
+ return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
return value;
};
@@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 3.0)` with `x * x * x`.
if (isExponentValue(3.0)) {
Value square =
- rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
+ arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
return success();
}
// Replace `pow(x, -1.0)` with `1.0 / x`.
if (isExponentValue(-1.0)) {
- Value one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+ Value one = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
return success();
}
@@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
if (isExponentValue(0.75)) {
- Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
- Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
+ Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
+ Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
ValueRange{powHalf, powQuarter});
return success();
@@ -168,18 +169,18 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+ return vector::BroadcastOp::create(rewriter, loc, vec, value);
return value;
};
Value one;
Type opType = getElementTypeOrSelf(op.getType());
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
- one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(opType, 1.0));
+ one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getFloatAttr(opType, 1.0));
else
- one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(opType, 1));
+ one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(opType, 1));
// Replace `[fi]powi(x, 0)` with `1`.
if (exponentValue == 0) {
@@ -208,12 +209,12 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// with:
// (1 / x) * (1 / x) * (1 / x) * ...
for (unsigned i = 1; i < exponentValue; ++i)
- result = rewriter.create<MulOpTy>(loc, result, base);
+ result = MulOpTy::create(rewriter, loc, result, base);
// Inverse the base for negative exponent, i.e. for
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
if (exponentIsNegative)
- result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+ result = DivOpTy::create(rewriter, loc, bcast(one), result);
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bccd486def4bf..5edb6e28fb018 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -32,11 +32,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
APFloat::rmNearestTiesToEven, &losesInfo);
auto attr = b.getFloatAttr(eltType, value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return b.create<arith::ConstantOp>(loc,
- DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(b, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return b.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(b, loc, attr);
}
static Value createFloatConst(Location loc, Type type, double value,
@@ -49,11 +49,11 @@ static Value createIntConst(Location loc, Type type, int64_t value,
OpBuilder &b) {
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return b.create<arith::ConstantOp>(loc,
- DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(b, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return b.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(b, loc, attr);
}
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
@@ -61,11 +61,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
Type i64Ty = b.getI64Type();
if (auto shapedTy = dyn_cast<ShapedType>(opType))
i64Ty = shapedTy.clone(i64Ty);
- Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
- Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+ Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
+ Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
// The truncation does not preserve the sign when the truncated
// value is -0. So here the sign is copied again.
- return b.create<math::CopySignOp>(fpFixedConvert, operand);
+ return math::CopySignOp::create(b, fpFixedConvert, operand);
}
// sinhf(float x) -> (exp(x) - exp(-x)) / 2
@@ -74,12 +74,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value neg = b.create<arith::NegFOp>(operand);
- Value nexp = b.create<math::ExpOp>(neg);
- Value sub = b.create<arith::SubFOp>(exp, nexp);
+ Value exp = math::ExpOp::create(b, operand);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value nexp = math::ExpOp::create(b, neg);
+ Value sub = arith::SubFOp::create(b, exp, nexp);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(sub, half);
+ Value res = arith::MulFOp::create(b, sub, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -90,12 +90,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value neg = b.create<arith::NegFOp>(operand);
- Value nexp = b.create<math::ExpOp>(neg);
- Value add = b.create<arith::AddFOp>(exp, nexp);
+ Value exp = math::ExpOp::create(b, operand);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value nexp = math::ExpOp::create(b, neg);
+ Value add = arith::AddFOp::create(b, exp, nexp);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(add, half);
+ Value res = arith::MulFOp::create(b, add, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -116,23 +116,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
- Value isNegative = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+ Value isNegative = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
Value isNegativeFloat =
- rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+ arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
Value isNegativeTimesNegTwo =
- rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
- Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
+ arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
+ Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
// Normalize input to positive value: y = sign(x) * x
- Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
+ Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
// Decompose on normalized input
- Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
- Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
- Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
- Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
- Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+ Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
+ Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
+ Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
+ Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
+ Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
// Multiply result by sign(x) to retain signs from negative inputs
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
@@ -145,9 +145,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type type = operand.getType();
- Value sin = b.create<math::SinOp>(type, operand);
- Value cos = b.create<math::CosOp>(type, operand);
- Value div = b.create<arith::DivFOp>(type, sin, cos);
+ Value sin = math::SinOp::create(b, type, operand);
+ Value cos = math::CosOp::create(b, type, operand);
+ Value div = arith::DivFOp::create(b, type, sin, cos);
rewriter.replaceOp(op, div);
return success();
}
@@ -160,10 +160,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op,
Type opType = operand.getType();
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value fma = b.create<math::FmaOp>(operand, operand, one);
- Value sqrt = b.create<math::SqrtOp>(fma);
- Value add = b.create<arith::AddFOp>(operand, sqrt);
- Value res = b.create<math::LogOp>(add);
+ Value fma = math::FmaOp::create(b, operand, operand, one);
+ Value sqrt = math::SqrtOp::create(b, fma);
+ Value add = arith::AddFOp::create(b, operand, sqrt);
+ Value res = math::LogOp::create(b, add);
rewriter.replaceOp(op, res);
return success();
}
@@ -176,10 +176,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op,
Type opType = operand.getType();
Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
- Value fma = b.create<math::FmaOp>(operand, operand, negOne);
- Value sqrt = b.create<math::SqrtOp>(fma);
- Value add = b.create<arith::AddFOp>(operand, sqrt);
- Value res = b.create<math::LogOp>(add);
+ Value fma = math::FmaOp::create(b, operand, operand, negOne);
+ Value sqrt = math::SqrtOp::create(b, fma);
+ Value add = arith::AddFOp::create(b, operand, sqrt);
+ Value res = math::LogOp::create(b, add);
rewriter.replaceOp(op, res);
return success();
}
@@ -192,13 +192,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op,
Type opType = operand.getType();
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value add = b.create<arith::AddFOp>(operand, one);
- Value neg = b.create<arith::NegFOp>(operand);
- Value sub = b.create<arith::AddFOp>(neg, one);
- Value div = b.create<arith::DivFOp>(add, sub);
- Value log = b.create<math::LogOp>(div);
+ Value add = arith::AddFOp::create(b, operand, one);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value sub = arith::AddFOp::create(b, neg, one);
+ Value div = arith::DivFOp::create(b, add, sub);
+ Value log = math::LogOp::create(b, div);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(log, half);
+ Value res = arith::MulFOp::create(b, log, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -209,8 +209,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Value operandC = op.getOperand(2);
Type type = op.getType();
- Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
- Value add = b.create<arith::AddFOp>(type, mult, operandC);
+ Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+ Value add = arith::AddFOp::create(b, type, mult, operandC);
rewriter.replaceOp(op, add);
return success();
}
@@ -235,11 +235,12 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
- fpFixedConvert);
- Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
+ Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
+ fpFixedConvert);
+ Value incrValue =
+ arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+ Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
rewriter.replaceOp(op, ret);
return success();
}
@@ -257,9 +258,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
auto convertFPowItoPowf = [&]() -> LogicalResult {
Value castPowerToFp =
- rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
- Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
- castPowerToFp);
+ arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
+ Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
+ castPowerToFp);
rewriter.replaceOp(op, res);
return success();
};
@@ -280,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
while (absPower > 0) {
if (absPower & 1)
- res = b.create<arith::MulFOp>(baseType, base, res);
+ res = arith::MulFOp::create(b, baseType, base, res);
absPower >>= 1;
- base = b.create<arith::MulFOp>(baseType, base, base);
+ base = arith::MulFOp::create(b, baseType, base, base);
}
// Make sure not to introduce UB in case of negative power.
@@ -302,14 +303,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
createFloatConst(op->getLoc(), baseType,
APFloat::getInf(sem, /*Negative=*/true), rewriter);
Value zeroEqCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
- res = b.create<arith::DivFOp>(baseType, one, res);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
+ res = arith::DivFOp::create(b, baseType, one, res);
res =
- b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
- res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
- ...
[truncated]
|
Did you use a sed command for this, can you post it here for posterity? I'm curious why you chose to break this down into quite so many PRs (26?) , wouldn't it be more straightforward to run the regex on the whole project, once, and post a single PR? btw thanks a million for this change, I'm looking forward to not using the convoluted |
To make each chunk reviewable so that if the regex did something wrong we had a chance of catching it.
👍 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regex looks like it worked as expected, LGTM.
OpName::createOrFold
coming later?
Maybe? I haven't quite decided whether it's too much codegen at that point for a much less frequently used API. And under "shared object" build it wouldn't be DCEd by the linker so 🤷♂️ |
See #147168 for more info.