Skip to content

[mlir][NFC] update mlir/Dialect create APIs (19/n) #149926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2025

Conversation

makslevental
Copy link
Contributor

See #147168 for more info.

Copy link

github-actions bot commented Jul 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir-index
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-quant

Author: Maksim Levental (makslevental)

Changes

See #147168 for more info.


Patch is 130.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149926.diff

17 Files Affected:

  • (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+6-6)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+3-2)
  • (modified) mlir/lib/Dialect/MPI/IR/MPIOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+14-13)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+131-130)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+305-296)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+13-13)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+15-11)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+27-24)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+47-46)
  • (modified) mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp (+14-13)
  • (modified) mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp (+7-6)
  • (modified) mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp (+49-47)
  • (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+2-2)
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index bab9e2852a460..a3e1542e6a947 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
   if (auto boolValue = dyn_cast<BoolAttr>(value)) {
     if (!type.isSignlessInteger(1))
       return nullptr;
-    return b.create<BoolConstantOp>(loc, type, boolValue);
+    return BoolConstantOp::create(b, loc, type, boolValue);
   }
 
   // Materialize integer attributes as `index`.
@@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
       return nullptr;
     assert(indexValue.getValue().getBitWidth() ==
            IndexType::kInternalStorageBitWidth);
-    return b.create<ConstantOp>(loc, indexValue);
+    return ConstantOp::create(b, loc, indexValue);
   }
 
   return nullptr;
@@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
 
   index::CmpOp newCmp;
   if (rhsIsZero)
-    newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
-                                           subOp.getLhs(), subOp.getRhs());
+    newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+                                  subOp.getLhs(), subOp.getRhs());
   else
-    newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
-                                           subOp.getRhs(), subOp.getLhs());
+    newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+                                  subOp.getRhs(), subOp.getLhs());
   rewriter.replaceOp(op, newCmp);
   return success();
 }
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index ff6af63eee531..364e4d385fd62 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -135,8 +135,9 @@ struct GlobalStoreOpInterface
     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
 
     auto loc = globalStoreOp.getLoc();
-    auto targetMemref = rewriter.create<memref::GetGlobalOp>(
-        loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+    auto targetMemref = memref::GetGlobalOp::create(
+        rewriter, loc, memrefType,
+        globalStoreOp.getGlobalAttr().getLeafReference());
 
     auto sourceMemref =
         getBuffer(rewriter, globalStoreOp.getValue(), options, state);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 7940ff60a48e7..f52c3f99189d2 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
     if (!isa<IntegerAttr>(dltiAttr.value()))
       return op->emitError()
              << "Expected an integer attribute for MPI:comm_world_rank";
-    Value res = b.create<arith::ConstantIndexOp>(
-        op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+    Value res = arith::ConstantIndexOp::create(
+        b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
     if (Value retVal = op.getRetval())
       b.replaceOp(op, {retVal, res});
     else
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 26441a9d78658..a21631cbf8510 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
   if (auto poison = dyn_cast<ub::PoisonAttr>(value))
-    return builder.create<ub::PoisonOp>(loc, type, poison);
+    return ub::PoisonOp::create(builder, loc, type, poison);
 
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 13e2a4b5541b2..31785eb20a642 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&](Value value) -> Value {
     if (auto vec = dyn_cast<VectorType>(op.getType()))
-      return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
+      return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
     return value;
   };
 
@@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   // Replace `pow(x, 3.0)` with `x * x * x`.
   if (isExponentValue(3.0)) {
     Value square =
-        rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
+        arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
     return success();
   }
 
   // Replace `pow(x, -1.0)` with `1.0 / x`.
   if (isExponentValue(-1.0)) {
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+    Value one = arith::ConstantOp::create(
+        rewriter, loc,
+        rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
     rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
     return success();
   }
@@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
   if (isExponentValue(0.75)) {
-    Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
-    Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
+    Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
+    Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
                                                ValueRange{powHalf, powQuarter});
     return success();
@@ -168,18 +169,18 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
     if (auto vec = dyn_cast<VectorType>(op.getType()))
-      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+      return vector::BroadcastOp::create(rewriter, loc, vec, value);
     return value;
   };
 
   Value one;
   Type opType = getElementTypeOrSelf(op.getType());
   if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
-    one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getFloatAttr(opType, 1.0));
+    one = arith::ConstantOp::create(rewriter, loc,
+                                    rewriter.getFloatAttr(opType, 1.0));
   else
-    one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(opType, 1));
+    one = arith::ConstantOp::create(rewriter, loc,
+                                    rewriter.getIntegerAttr(opType, 1));
 
   // Replace `[fi]powi(x, 0)` with `1`.
   if (exponentValue == 0) {
@@ -208,12 +209,12 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   //     with:
   //       (1 / x) * (1 / x) * (1 / x) * ...
   for (unsigned i = 1; i < exponentValue; ++i)
-    result = rewriter.create<MulOpTy>(loc, result, base);
+    result = MulOpTy::create(rewriter, loc, result, base);
 
   // Inverse the base for negative exponent, i.e. for
   // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
   if (exponentIsNegative)
-    result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+    result = DivOpTy::create(rewriter, loc, bcast(one), result);
 
   rewriter.replaceOp(op, result);
   return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bccd486def4bf..5edb6e28fb018 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -32,11 +32,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
                 APFloat::rmNearestTiesToEven, &losesInfo);
   auto attr = b.getFloatAttr(eltType, value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
-    return b.create<arith::ConstantOp>(loc,
-                                       DenseElementsAttr::get(shapedTy, attr));
+    return arith::ConstantOp::create(b, loc,
+                                     DenseElementsAttr::get(shapedTy, attr));
   }
 
-  return b.create<arith::ConstantOp>(loc, attr);
+  return arith::ConstantOp::create(b, loc, attr);
 }
 
 static Value createFloatConst(Location loc, Type type, double value,
@@ -49,11 +49,11 @@ static Value createIntConst(Location loc, Type type, int64_t value,
                             OpBuilder &b) {
   auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
-    return b.create<arith::ConstantOp>(loc,
-                                       DenseElementsAttr::get(shapedTy, attr));
+    return arith::ConstantOp::create(b, loc,
+                                     DenseElementsAttr::get(shapedTy, attr));
   }
 
-  return b.create<arith::ConstantOp>(loc, attr);
+  return arith::ConstantOp::create(b, loc, attr);
 }
 
 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
@@ -61,11 +61,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
   Type i64Ty = b.getI64Type();
   if (auto shapedTy = dyn_cast<ShapedType>(opType))
     i64Ty = shapedTy.clone(i64Ty);
-  Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
-  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+  Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
+  Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
   // The truncation does not preserve the sign when the truncated
   // value is -0. So here the sign is copied again.
-  return b.create<math::CopySignOp>(fpFixedConvert, operand);
+  return math::CopySignOp::create(b, fpFixedConvert, operand);
 }
 
 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
@@ -74,12 +74,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
   Value operand = op.getOperand();
   Type opType = operand.getType();
 
-  Value exp = b.create<math::ExpOp>(operand);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value nexp = b.create<math::ExpOp>(neg);
-  Value sub = b.create<arith::SubFOp>(exp, nexp);
+  Value exp = math::ExpOp::create(b, operand);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value nexp = math::ExpOp::create(b, neg);
+  Value sub = arith::SubFOp::create(b, exp, nexp);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(sub, half);
+  Value res = arith::MulFOp::create(b, sub, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -90,12 +90,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
   Value operand = op.getOperand();
   Type opType = operand.getType();
 
-  Value exp = b.create<math::ExpOp>(operand);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value nexp = b.create<math::ExpOp>(neg);
-  Value add = b.create<arith::AddFOp>(exp, nexp);
+  Value exp = math::ExpOp::create(b, operand);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value nexp = math::ExpOp::create(b, neg);
+  Value add = arith::AddFOp::create(b, exp, nexp);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(add, half);
+  Value res = arith::MulFOp::create(b, add, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -116,23 +116,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
 
   // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
-  Value isNegative = rewriter.create<arith::CmpFOp>(
-      loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+  Value isNegative = arith::CmpFOp::create(
+      rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
   Value isNegativeFloat =
-      rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+      arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
   Value isNegativeTimesNegTwo =
-      rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
-  Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
+      arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
+  Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
 
   // Normalize input to positive value: y = sign(x) * x
-  Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
+  Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
 
   // Decompose on normalized input
-  Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
-  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
-  Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
-  Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
-  Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+  Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
+  Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
+  Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
+  Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
+  Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
 
   // Multiply result by sign(x) to retain signs from negative inputs
   rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
@@ -145,9 +145,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type type = operand.getType();
-  Value sin = b.create<math::SinOp>(type, operand);
-  Value cos = b.create<math::CosOp>(type, operand);
-  Value div = b.create<arith::DivFOp>(type, sin, cos);
+  Value sin = math::SinOp::create(b, type, operand);
+  Value cos = math::CosOp::create(b, type, operand);
+  Value div = arith::DivFOp::create(b, type, sin, cos);
   rewriter.replaceOp(op, div);
   return success();
 }
@@ -160,10 +160,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op,
   Type opType = operand.getType();
 
   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value fma = b.create<math::FmaOp>(operand, operand, one);
-  Value sqrt = b.create<math::SqrtOp>(fma);
-  Value add = b.create<arith::AddFOp>(operand, sqrt);
-  Value res = b.create<math::LogOp>(add);
+  Value fma = math::FmaOp::create(b, operand, operand, one);
+  Value sqrt = math::SqrtOp::create(b, fma);
+  Value add = arith::AddFOp::create(b, operand, sqrt);
+  Value res = math::LogOp::create(b, add);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -176,10 +176,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op,
   Type opType = operand.getType();
 
   Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
-  Value fma = b.create<math::FmaOp>(operand, operand, negOne);
-  Value sqrt = b.create<math::SqrtOp>(fma);
-  Value add = b.create<arith::AddFOp>(operand, sqrt);
-  Value res = b.create<math::LogOp>(add);
+  Value fma = math::FmaOp::create(b, operand, operand, negOne);
+  Value sqrt = math::SqrtOp::create(b, fma);
+  Value add = arith::AddFOp::create(b, operand, sqrt);
+  Value res = math::LogOp::create(b, add);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -192,13 +192,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op,
   Type opType = operand.getType();
 
   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value add = b.create<arith::AddFOp>(operand, one);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value sub = b.create<arith::AddFOp>(neg, one);
-  Value div = b.create<arith::DivFOp>(add, sub);
-  Value log = b.create<math::LogOp>(div);
+  Value add = arith::AddFOp::create(b, operand, one);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value sub = arith::AddFOp::create(b, neg, one);
+  Value div = arith::DivFOp::create(b, add, sub);
+  Value log = math::LogOp::create(b, div);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(log, half);
+  Value res = arith::MulFOp::create(b, log, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -209,8 +209,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Value operandC = op.getOperand(2);
   Type type = op.getType();
-  Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
-  Value add = b.create<arith::AddFOp>(type, mult, operandC);
+  Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+  Value add = arith::AddFOp::create(b, type, mult, operandC);
   rewriter.replaceOp(op, add);
   return success();
 }
@@ -235,11 +235,12 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
 
-  Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
-                                          fpFixedConvert);
-  Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
+  Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
+                                        fpFixedConvert);
+  Value incrValue =
+      arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+  Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
   rewriter.replaceOp(op, ret);
   return success();
 }
@@ -257,9 +258,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   auto convertFPowItoPowf = [&]() -> LogicalResult {
     Value castPowerToFp =
-        rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
-    Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
-                                              castPowerToFp);
+        arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
+    Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
+                                     castPowerToFp);
     rewriter.replaceOp(op, res);
     return success();
   };
@@ -280,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   while (absPower > 0) {
     if (absPower & 1)
-      res = b.create<arith::MulFOp>(baseType, base, res);
+      res = arith::MulFOp::create(b, baseType, base, res);
     absPower >>= 1;
-    base = b.create<arith::MulFOp>(baseType, base, base);
+    base = arith::MulFOp::create(b, baseType, base, base);
   }
 
   // Make sure not to introduce UB in case of negative power.
@@ -302,14 +303,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
         createFloatConst(op->getLoc(), baseType,
                          APFloat::getInf(sem, /*Negative=*/true), rewriter);
     Value zeroEqCheck =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
     Value negZeroEqCheck =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
-    res = b.create<arith::DivFOp>(baseType, one, res);
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
+    res = arith::DivFOp::create(b, baseType, one, res);
     res =
-        b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
-    res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
-                 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir-nvgpu

Author: Maksim Levental (makslevental)

Changes

See #147168 for more info.


Patch is 130.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149926.diff

17 Files Affected:

  • (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+6-6)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+3-2)
  • (modified) mlir/lib/Dialect/MPI/IR/MPIOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+14-13)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+131-130)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+305-296)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+13-13)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+15-11)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+27-24)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+47-46)
  • (modified) mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp (+14-13)
  • (modified) mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp (+7-6)
  • (modified) mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp (+49-47)
  • (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+2-2)
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index bab9e2852a460..a3e1542e6a947 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
   if (auto boolValue = dyn_cast<BoolAttr>(value)) {
     if (!type.isSignlessInteger(1))
       return nullptr;
-    return b.create<BoolConstantOp>(loc, type, boolValue);
+    return BoolConstantOp::create(b, loc, type, boolValue);
   }
 
   // Materialize integer attributes as `index`.
@@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
       return nullptr;
     assert(indexValue.getValue().getBitWidth() ==
            IndexType::kInternalStorageBitWidth);
-    return b.create<ConstantOp>(loc, indexValue);
+    return ConstantOp::create(b, loc, indexValue);
   }
 
   return nullptr;
@@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
 
   index::CmpOp newCmp;
   if (rhsIsZero)
-    newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
-                                           subOp.getLhs(), subOp.getRhs());
+    newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+                                  subOp.getLhs(), subOp.getRhs());
   else
-    newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
-                                           subOp.getRhs(), subOp.getLhs());
+    newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+                                  subOp.getRhs(), subOp.getLhs());
   rewriter.replaceOp(op, newCmp);
   return success();
 }
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index ff6af63eee531..364e4d385fd62 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -135,8 +135,9 @@ struct GlobalStoreOpInterface
     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
 
     auto loc = globalStoreOp.getLoc();
-    auto targetMemref = rewriter.create<memref::GetGlobalOp>(
-        loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+    auto targetMemref = memref::GetGlobalOp::create(
+        rewriter, loc, memrefType,
+        globalStoreOp.getGlobalAttr().getLeafReference());
 
     auto sourceMemref =
         getBuffer(rewriter, globalStoreOp.getValue(), options, state);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 7940ff60a48e7..f52c3f99189d2 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
     if (!isa<IntegerAttr>(dltiAttr.value()))
       return op->emitError()
              << "Expected an integer attribute for MPI:comm_world_rank";
-    Value res = b.create<arith::ConstantIndexOp>(
-        op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+    Value res = arith::ConstantIndexOp::create(
+        b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
     if (Value retVal = op.getRetval())
       b.replaceOp(op, {retVal, res});
     else
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 26441a9d78658..a21631cbf8510 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
   if (auto poison = dyn_cast<ub::PoisonAttr>(value))
-    return builder.create<ub::PoisonOp>(loc, type, poison);
+    return ub::PoisonOp::create(builder, loc, type, poison);
 
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 13e2a4b5541b2..31785eb20a642 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&](Value value) -> Value {
     if (auto vec = dyn_cast<VectorType>(op.getType()))
-      return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
+      return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
     return value;
   };
 
@@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   // Replace `pow(x, 3.0)` with `x * x * x`.
   if (isExponentValue(3.0)) {
     Value square =
-        rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
+        arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
     return success();
   }
 
   // Replace `pow(x, -1.0)` with `1.0 / x`.
   if (isExponentValue(-1.0)) {
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+    Value one = arith::ConstantOp::create(
+        rewriter, loc,
+        rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
     rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
     return success();
   }
@@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
   if (isExponentValue(0.75)) {
-    Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
-    Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
+    Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
+    Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
                                                ValueRange{powHalf, powQuarter});
     return success();
@@ -168,18 +169,18 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
     if (auto vec = dyn_cast<VectorType>(op.getType()))
-      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+      return vector::BroadcastOp::create(rewriter, loc, vec, value);
     return value;
   };
 
   Value one;
   Type opType = getElementTypeOrSelf(op.getType());
   if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
-    one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getFloatAttr(opType, 1.0));
+    one = arith::ConstantOp::create(rewriter, loc,
+                                    rewriter.getFloatAttr(opType, 1.0));
   else
-    one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(opType, 1));
+    one = arith::ConstantOp::create(rewriter, loc,
+                                    rewriter.getIntegerAttr(opType, 1));
 
   // Replace `[fi]powi(x, 0)` with `1`.
   if (exponentValue == 0) {
@@ -208,12 +209,12 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   //     with:
   //       (1 / x) * (1 / x) * (1 / x) * ...
   for (unsigned i = 1; i < exponentValue; ++i)
-    result = rewriter.create<MulOpTy>(loc, result, base);
+    result = MulOpTy::create(rewriter, loc, result, base);
 
   // Inverse the base for negative exponent, i.e. for
   // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
   if (exponentIsNegative)
-    result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+    result = DivOpTy::create(rewriter, loc, bcast(one), result);
 
   rewriter.replaceOp(op, result);
   return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bccd486def4bf..5edb6e28fb018 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -32,11 +32,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
                 APFloat::rmNearestTiesToEven, &losesInfo);
   auto attr = b.getFloatAttr(eltType, value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
-    return b.create<arith::ConstantOp>(loc,
-                                       DenseElementsAttr::get(shapedTy, attr));
+    return arith::ConstantOp::create(b, loc,
+                                     DenseElementsAttr::get(shapedTy, attr));
   }
 
-  return b.create<arith::ConstantOp>(loc, attr);
+  return arith::ConstantOp::create(b, loc, attr);
 }
 
 static Value createFloatConst(Location loc, Type type, double value,
@@ -49,11 +49,11 @@ static Value createIntConst(Location loc, Type type, int64_t value,
                             OpBuilder &b) {
   auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
-    return b.create<arith::ConstantOp>(loc,
-                                       DenseElementsAttr::get(shapedTy, attr));
+    return arith::ConstantOp::create(b, loc,
+                                     DenseElementsAttr::get(shapedTy, attr));
   }
 
-  return b.create<arith::ConstantOp>(loc, attr);
+  return arith::ConstantOp::create(b, loc, attr);
 }
 
 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
@@ -61,11 +61,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
   Type i64Ty = b.getI64Type();
   if (auto shapedTy = dyn_cast<ShapedType>(opType))
     i64Ty = shapedTy.clone(i64Ty);
-  Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
-  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+  Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
+  Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
   // The truncation does not preserve the sign when the truncated
   // value is -0. So here the sign is copied again.
-  return b.create<math::CopySignOp>(fpFixedConvert, operand);
+  return math::CopySignOp::create(b, fpFixedConvert, operand);
 }
 
 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
@@ -74,12 +74,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
   Value operand = op.getOperand();
   Type opType = operand.getType();
 
-  Value exp = b.create<math::ExpOp>(operand);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value nexp = b.create<math::ExpOp>(neg);
-  Value sub = b.create<arith::SubFOp>(exp, nexp);
+  Value exp = math::ExpOp::create(b, operand);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value nexp = math::ExpOp::create(b, neg);
+  Value sub = arith::SubFOp::create(b, exp, nexp);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(sub, half);
+  Value res = arith::MulFOp::create(b, sub, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -90,12 +90,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
   Value operand = op.getOperand();
   Type opType = operand.getType();
 
-  Value exp = b.create<math::ExpOp>(operand);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value nexp = b.create<math::ExpOp>(neg);
-  Value add = b.create<arith::AddFOp>(exp, nexp);
+  Value exp = math::ExpOp::create(b, operand);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value nexp = math::ExpOp::create(b, neg);
+  Value add = arith::AddFOp::create(b, exp, nexp);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(add, half);
+  Value res = arith::MulFOp::create(b, add, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -116,23 +116,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
 
   // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
-  Value isNegative = rewriter.create<arith::CmpFOp>(
-      loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+  Value isNegative = arith::CmpFOp::create(
+      rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
   Value isNegativeFloat =
-      rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+      arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
   Value isNegativeTimesNegTwo =
-      rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
-  Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
+      arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
+  Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
 
   // Normalize input to positive value: y = sign(x) * x
-  Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
+  Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
 
   // Decompose on normalized input
-  Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
-  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
-  Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
-  Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
-  Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+  Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
+  Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
+  Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
+  Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
+  Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
 
   // Multiply result by sign(x) to retain signs from negative inputs
   rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
@@ -145,9 +145,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type type = operand.getType();
-  Value sin = b.create<math::SinOp>(type, operand);
-  Value cos = b.create<math::CosOp>(type, operand);
-  Value div = b.create<arith::DivFOp>(type, sin, cos);
+  Value sin = math::SinOp::create(b, type, operand);
+  Value cos = math::CosOp::create(b, type, operand);
+  Value div = arith::DivFOp::create(b, type, sin, cos);
   rewriter.replaceOp(op, div);
   return success();
 }
@@ -160,10 +160,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op,
   Type opType = operand.getType();
 
   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value fma = b.create<math::FmaOp>(operand, operand, one);
-  Value sqrt = b.create<math::SqrtOp>(fma);
-  Value add = b.create<arith::AddFOp>(operand, sqrt);
-  Value res = b.create<math::LogOp>(add);
+  Value fma = math::FmaOp::create(b, operand, operand, one);
+  Value sqrt = math::SqrtOp::create(b, fma);
+  Value add = arith::AddFOp::create(b, operand, sqrt);
+  Value res = math::LogOp::create(b, add);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -176,10 +176,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op,
   Type opType = operand.getType();
 
   Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
-  Value fma = b.create<math::FmaOp>(operand, operand, negOne);
-  Value sqrt = b.create<math::SqrtOp>(fma);
-  Value add = b.create<arith::AddFOp>(operand, sqrt);
-  Value res = b.create<math::LogOp>(add);
+  Value fma = math::FmaOp::create(b, operand, operand, negOne);
+  Value sqrt = math::SqrtOp::create(b, fma);
+  Value add = arith::AddFOp::create(b, operand, sqrt);
+  Value res = math::LogOp::create(b, add);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -192,13 +192,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op,
   Type opType = operand.getType();
 
   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value add = b.create<arith::AddFOp>(operand, one);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value sub = b.create<arith::AddFOp>(neg, one);
-  Value div = b.create<arith::DivFOp>(add, sub);
-  Value log = b.create<math::LogOp>(div);
+  Value add = arith::AddFOp::create(b, operand, one);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value sub = arith::AddFOp::create(b, neg, one);
+  Value div = arith::DivFOp::create(b, add, sub);
+  Value log = math::LogOp::create(b, div);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(log, half);
+  Value res = arith::MulFOp::create(b, log, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -209,8 +209,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Value operandC = op.getOperand(2);
   Type type = op.getType();
-  Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
-  Value add = b.create<arith::AddFOp>(type, mult, operandC);
+  Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+  Value add = arith::AddFOp::create(b, type, mult, operandC);
   rewriter.replaceOp(op, add);
   return success();
 }
@@ -235,11 +235,12 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
 
-  Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
-                                          fpFixedConvert);
-  Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
+  Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
+                                        fpFixedConvert);
+  Value incrValue =
+      arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+  Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
   rewriter.replaceOp(op, ret);
   return success();
 }
@@ -257,9 +258,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   auto convertFPowItoPowf = [&]() -> LogicalResult {
     Value castPowerToFp =
-        rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
-    Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
-                                              castPowerToFp);
+        arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
+    Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
+                                     castPowerToFp);
     rewriter.replaceOp(op, res);
     return success();
   };
@@ -280,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   while (absPower > 0) {
     if (absPower & 1)
-      res = b.create<arith::MulFOp>(baseType, base, res);
+      res = arith::MulFOp::create(b, baseType, base, res);
     absPower >>= 1;
-    base = b.create<arith::MulFOp>(baseType, base, base);
+    base = arith::MulFOp::create(b, baseType, base, base);
   }
 
   // Make sure not to introduce UB in case of negative power.
@@ -302,14 +303,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
         createFloatConst(op->getLoc(), baseType,
                          APFloat::getInf(sem, /*Negative=*/true), rewriter);
     Value zeroEqCheck =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
     Value negZeroEqCheck =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
-    res = b.create<arith::DivFOp>(baseType, one, res);
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
+    res = arith::DivFOp::create(b, baseType, one, res);
     res =
-        b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
-    res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
-                 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

See #147168 for more info.


Patch is 130.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149926.diff

17 Files Affected:

  • (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+6-6)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+3-2)
  • (modified) mlir/lib/Dialect/MPI/IR/MPIOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+14-13)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+131-130)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+305-296)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+13-13)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+15-11)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+27-24)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+47-46)
  • (modified) mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp (+14-13)
  • (modified) mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp (+7-6)
  • (modified) mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp (+49-47)
  • (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+2-2)
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index bab9e2852a460..a3e1542e6a947 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
   if (auto boolValue = dyn_cast<BoolAttr>(value)) {
     if (!type.isSignlessInteger(1))
       return nullptr;
-    return b.create<BoolConstantOp>(loc, type, boolValue);
+    return BoolConstantOp::create(b, loc, type, boolValue);
   }
 
   // Materialize integer attributes as `index`.
@@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
       return nullptr;
     assert(indexValue.getValue().getBitWidth() ==
            IndexType::kInternalStorageBitWidth);
-    return b.create<ConstantOp>(loc, indexValue);
+    return ConstantOp::create(b, loc, indexValue);
   }
 
   return nullptr;
@@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
 
   index::CmpOp newCmp;
   if (rhsIsZero)
-    newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
-                                           subOp.getLhs(), subOp.getRhs());
+    newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+                                  subOp.getLhs(), subOp.getRhs());
   else
-    newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
-                                           subOp.getRhs(), subOp.getLhs());
+    newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+                                  subOp.getRhs(), subOp.getLhs());
   rewriter.replaceOp(op, newCmp);
   return success();
 }
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index ff6af63eee531..364e4d385fd62 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -135,8 +135,9 @@ struct GlobalStoreOpInterface
     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
 
     auto loc = globalStoreOp.getLoc();
-    auto targetMemref = rewriter.create<memref::GetGlobalOp>(
-        loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+    auto targetMemref = memref::GetGlobalOp::create(
+        rewriter, loc, memrefType,
+        globalStoreOp.getGlobalAttr().getLeafReference());
 
     auto sourceMemref =
         getBuffer(rewriter, globalStoreOp.getValue(), options, state);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 7940ff60a48e7..f52c3f99189d2 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
     if (!isa<IntegerAttr>(dltiAttr.value()))
       return op->emitError()
              << "Expected an integer attribute for MPI:comm_world_rank";
-    Value res = b.create<arith::ConstantIndexOp>(
-        op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+    Value res = arith::ConstantIndexOp::create(
+        b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
     if (Value retVal = op.getRetval())
       b.replaceOp(op, {retVal, res});
     else
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 26441a9d78658..a21631cbf8510 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
   if (auto poison = dyn_cast<ub::PoisonAttr>(value))
-    return builder.create<ub::PoisonOp>(loc, type, poison);
+    return ub::PoisonOp::create(builder, loc, type, poison);
 
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 13e2a4b5541b2..31785eb20a642 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&](Value value) -> Value {
     if (auto vec = dyn_cast<VectorType>(op.getType()))
-      return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
+      return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
     return value;
   };
 
@@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   // Replace `pow(x, 3.0)` with `x * x * x`.
   if (isExponentValue(3.0)) {
     Value square =
-        rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
+        arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
     return success();
   }
 
   // Replace `pow(x, -1.0)` with `1.0 / x`.
   if (isExponentValue(-1.0)) {
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+    Value one = arith::ConstantOp::create(
+        rewriter, loc,
+        rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
     rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
     return success();
   }
@@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
   if (isExponentValue(0.75)) {
-    Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
-    Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
+    Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
+    Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
                                                ValueRange{powHalf, powQuarter});
     return success();
@@ -168,18 +169,18 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
     if (auto vec = dyn_cast<VectorType>(op.getType()))
-      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+      return vector::BroadcastOp::create(rewriter, loc, vec, value);
     return value;
   };
 
   Value one;
   Type opType = getElementTypeOrSelf(op.getType());
   if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
-    one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getFloatAttr(opType, 1.0));
+    one = arith::ConstantOp::create(rewriter, loc,
+                                    rewriter.getFloatAttr(opType, 1.0));
   else
-    one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(opType, 1));
+    one = arith::ConstantOp::create(rewriter, loc,
+                                    rewriter.getIntegerAttr(opType, 1));
 
   // Replace `[fi]powi(x, 0)` with `1`.
   if (exponentValue == 0) {
@@ -208,12 +209,12 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   //     with:
   //       (1 / x) * (1 / x) * (1 / x) * ...
   for (unsigned i = 1; i < exponentValue; ++i)
-    result = rewriter.create<MulOpTy>(loc, result, base);
+    result = MulOpTy::create(rewriter, loc, result, base);
 
   // Inverse the base for negative exponent, i.e. for
   // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
   if (exponentIsNegative)
-    result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+    result = DivOpTy::create(rewriter, loc, bcast(one), result);
 
   rewriter.replaceOp(op, result);
   return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bccd486def4bf..5edb6e28fb018 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -32,11 +32,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
                 APFloat::rmNearestTiesToEven, &losesInfo);
   auto attr = b.getFloatAttr(eltType, value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
-    return b.create<arith::ConstantOp>(loc,
-                                       DenseElementsAttr::get(shapedTy, attr));
+    return arith::ConstantOp::create(b, loc,
+                                     DenseElementsAttr::get(shapedTy, attr));
   }
 
-  return b.create<arith::ConstantOp>(loc, attr);
+  return arith::ConstantOp::create(b, loc, attr);
 }
 
 static Value createFloatConst(Location loc, Type type, double value,
@@ -49,11 +49,11 @@ static Value createIntConst(Location loc, Type type, int64_t value,
                             OpBuilder &b) {
   auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
-    return b.create<arith::ConstantOp>(loc,
-                                       DenseElementsAttr::get(shapedTy, attr));
+    return arith::ConstantOp::create(b, loc,
+                                     DenseElementsAttr::get(shapedTy, attr));
   }
 
-  return b.create<arith::ConstantOp>(loc, attr);
+  return arith::ConstantOp::create(b, loc, attr);
 }
 
 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
@@ -61,11 +61,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
   Type i64Ty = b.getI64Type();
   if (auto shapedTy = dyn_cast<ShapedType>(opType))
     i64Ty = shapedTy.clone(i64Ty);
-  Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
-  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+  Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
+  Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
   // The truncation does not preserve the sign when the truncated
   // value is -0. So here the sign is copied again.
-  return b.create<math::CopySignOp>(fpFixedConvert, operand);
+  return math::CopySignOp::create(b, fpFixedConvert, operand);
 }
 
 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
@@ -74,12 +74,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
   Value operand = op.getOperand();
   Type opType = operand.getType();
 
-  Value exp = b.create<math::ExpOp>(operand);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value nexp = b.create<math::ExpOp>(neg);
-  Value sub = b.create<arith::SubFOp>(exp, nexp);
+  Value exp = math::ExpOp::create(b, operand);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value nexp = math::ExpOp::create(b, neg);
+  Value sub = arith::SubFOp::create(b, exp, nexp);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(sub, half);
+  Value res = arith::MulFOp::create(b, sub, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -90,12 +90,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
   Value operand = op.getOperand();
   Type opType = operand.getType();
 
-  Value exp = b.create<math::ExpOp>(operand);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value nexp = b.create<math::ExpOp>(neg);
-  Value add = b.create<arith::AddFOp>(exp, nexp);
+  Value exp = math::ExpOp::create(b, operand);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value nexp = math::ExpOp::create(b, neg);
+  Value add = arith::AddFOp::create(b, exp, nexp);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(add, half);
+  Value res = arith::MulFOp::create(b, add, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -116,23 +116,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
 
   // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
-  Value isNegative = rewriter.create<arith::CmpFOp>(
-      loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+  Value isNegative = arith::CmpFOp::create(
+      rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
   Value isNegativeFloat =
-      rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+      arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
   Value isNegativeTimesNegTwo =
-      rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
-  Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
+      arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
+  Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
 
   // Normalize input to positive value: y = sign(x) * x
-  Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
+  Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
 
   // Decompose on normalized input
-  Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
-  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
-  Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
-  Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
-  Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+  Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
+  Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
+  Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
+  Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
+  Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
 
   // Multiply result by sign(x) to retain signs from negative inputs
   rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
@@ -145,9 +145,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type type = operand.getType();
-  Value sin = b.create<math::SinOp>(type, operand);
-  Value cos = b.create<math::CosOp>(type, operand);
-  Value div = b.create<arith::DivFOp>(type, sin, cos);
+  Value sin = math::SinOp::create(b, type, operand);
+  Value cos = math::CosOp::create(b, type, operand);
+  Value div = arith::DivFOp::create(b, type, sin, cos);
   rewriter.replaceOp(op, div);
   return success();
 }
@@ -160,10 +160,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op,
   Type opType = operand.getType();
 
   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value fma = b.create<math::FmaOp>(operand, operand, one);
-  Value sqrt = b.create<math::SqrtOp>(fma);
-  Value add = b.create<arith::AddFOp>(operand, sqrt);
-  Value res = b.create<math::LogOp>(add);
+  Value fma = math::FmaOp::create(b, operand, operand, one);
+  Value sqrt = math::SqrtOp::create(b, fma);
+  Value add = arith::AddFOp::create(b, operand, sqrt);
+  Value res = math::LogOp::create(b, add);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -176,10 +176,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op,
   Type opType = operand.getType();
 
   Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
-  Value fma = b.create<math::FmaOp>(operand, operand, negOne);
-  Value sqrt = b.create<math::SqrtOp>(fma);
-  Value add = b.create<arith::AddFOp>(operand, sqrt);
-  Value res = b.create<math::LogOp>(add);
+  Value fma = math::FmaOp::create(b, operand, operand, negOne);
+  Value sqrt = math::SqrtOp::create(b, fma);
+  Value add = arith::AddFOp::create(b, operand, sqrt);
+  Value res = math::LogOp::create(b, add);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -192,13 +192,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op,
   Type opType = operand.getType();
 
   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value add = b.create<arith::AddFOp>(operand, one);
-  Value neg = b.create<arith::NegFOp>(operand);
-  Value sub = b.create<arith::AddFOp>(neg, one);
-  Value div = b.create<arith::DivFOp>(add, sub);
-  Value log = b.create<math::LogOp>(div);
+  Value add = arith::AddFOp::create(b, operand, one);
+  Value neg = arith::NegFOp::create(b, operand);
+  Value sub = arith::AddFOp::create(b, neg, one);
+  Value div = arith::DivFOp::create(b, add, sub);
+  Value log = math::LogOp::create(b, div);
   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
-  Value res = b.create<arith::MulFOp>(log, half);
+  Value res = arith::MulFOp::create(b, log, half);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -209,8 +209,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Value operandC = op.getOperand(2);
   Type type = op.getType();
-  Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
-  Value add = b.create<arith::AddFOp>(type, mult, operandC);
+  Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+  Value add = arith::AddFOp::create(b, type, mult, operandC);
   rewriter.replaceOp(op, add);
   return success();
 }
@@ -235,11 +235,12 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
 
-  Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
-                                          fpFixedConvert);
-  Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
+  Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
+                                        fpFixedConvert);
+  Value incrValue =
+      arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+  Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
   rewriter.replaceOp(op, ret);
   return success();
 }
@@ -257,9 +258,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   auto convertFPowItoPowf = [&]() -> LogicalResult {
     Value castPowerToFp =
-        rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
-    Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
-                                              castPowerToFp);
+        arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
+    Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
+                                     castPowerToFp);
     rewriter.replaceOp(op, res);
     return success();
   };
@@ -280,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   while (absPower > 0) {
     if (absPower & 1)
-      res = b.create<arith::MulFOp>(baseType, base, res);
+      res = arith::MulFOp::create(b, baseType, base, res);
     absPower >>= 1;
-    base = b.create<arith::MulFOp>(baseType, base, base);
+    base = arith::MulFOp::create(b, baseType, base, base);
   }
 
   // Make sure not to introduce UB in case of negative power.
@@ -302,14 +303,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
         createFloatConst(op->getLoc(), baseType,
                          APFloat::getInf(sem, /*Negative=*/true), rewriter);
     Value zeroEqCheck =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
     Value negZeroEqCheck =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
-    res = b.create<arith::DivFOp>(baseType, one, res);
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
+    res = arith::DivFOp::create(b, baseType, one, res);
     res =
-        b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
-    res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
-                 ...
[truncated]

@makslevental makslevental requested review from fschlimb and newling July 21, 2025 22:45
@newling
Copy link
Contributor

newling commented Jul 22, 2025

Did you use a sed command for this, can you post it here for posterity? I'm curious why you chose to break this down into quite so many PRs (26?) , wouldn't it be more straightforward to run the regex on the whole project, once, and post a single PR?

btw thanks a million for this change, I'm looking forward to not using the convoluted OpName::build( workaround for getting signatures described in the editor!

@makslevental
Copy link
Contributor Author

makslevental commented Jul 22, 2025

Did you use a sed command for this, can you post it here for posterity?

(\w)\.create<(.*?)>\(

I'm curious why you chose to break this down into quite so many PRs (26?) , wouldn't it be more straightforward to run the regex on the whole project, once, and post a single PR?

To make each chunk reviewable so that if the regex did something wrong we had a chance of catching it.

btw thanks a million for this change, I'm looking forward to not using the convoluted OpName::build( workaround for getting signatures described in the editor!

👍

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regex looks like it worked as expected, LGTM.

OpName::createOrFold coming later?

@makslevental
Copy link
Contributor Author

makslevental commented Jul 22, 2025

OpName::createOrFold coming later?

Maybe? I haven't quite decided whether it's too much codegen at that point for a much less frequently used API. And under "shared object" build it wouldn't be DCEd by the linker so 🤷‍♂️

@makslevental makslevental merged commit b0312be into llvm:main Jul 22, 2025
18 checks passed
@makslevental makslevental deleted the makslevental/update-create-19n branch July 22, 2025 14:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants