From e4855dd7c5e50e2276a8bf7e01d22b0db8beed14 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 21 Jul 2025 18:21:25 -0400 Subject: [PATCH] [mlir][NFC] update `mlir/Dialect` create APIs (22/n) See https://github.com/llvm/llvm-project/pull/147168 for more info. --- mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 16 ++-- .../SPIRV/IR/SPIRVCanonicalization.cpp | 24 ++--- mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 4 +- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 42 ++++----- .../Linking/ModuleCombiner/ModuleCombiner.cpp | 5 +- .../Transforms/LowerABIAttributesPass.cpp | 36 +++---- .../SPIRV/Transforms/RewriteInsertsPass.cpp | 4 +- .../SPIRV/Transforms/SPIRVConversion.cpp | 93 ++++++++++--------- .../Transforms/SPIRVWebGPUTransforms.cpp | 48 +++++----- .../Transforms/UnifyAliasedResourcePass.cpp | 42 ++++----- 10 files changed, 161 insertions(+), 153 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index 371456552b5b5..890406df74e72 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -391,7 +391,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) { builder.createBlock(&getBody()); // Add a spirv.mlir.merge op into the merge block. - builder.create(getLoc()); + spirv::MergeOp::create(builder, getLoc()); } //===----------------------------------------------------------------------===// @@ -543,7 +543,7 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) { builder.createBlock(&getBody()); // Add a spirv.mlir.merge op into the merge block. - builder.create(getLoc()); + spirv::MergeOp::create(builder, getLoc()); } SelectionOp @@ -551,7 +551,7 @@ SelectionOp::createIfThen(Location loc, Value condition, function_ref thenBody, OpBuilder &builder) { auto selectionOp = - builder.create(loc, spirv::SelectionControl::None); + spirv::SelectionOp::create(builder, loc, spirv::SelectionControl::None); selectionOp.addMergeBlock(builder); Block *mergeBlock = selectionOp.getMergeBlock(); @@ -562,17 +562,17 @@ SelectionOp::createIfThen(Location loc, Value condition, OpBuilder::InsertionGuard guard(builder); thenBlock = builder.createBlock(mergeBlock); thenBody(builder); - builder.create(loc, mergeBlock); + spirv::BranchOp::create(builder, loc, mergeBlock); } // Build the header block. { OpBuilder::InsertionGuard guard(builder); builder.createBlock(thenBlock); - builder.create( - loc, condition, thenBlock, - /*trueArguments=*/ArrayRef(), mergeBlock, - /*falseArguments=*/ArrayRef()); + spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock, + /*trueArguments=*/ArrayRef(), + mergeBlock, + /*falseArguments=*/ArrayRef()); } return selectionOp; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 047f8da0cc003..2bde44baf961e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -178,16 +178,16 @@ struct IAddCarryFold final : OpRewritePattern { return failure(); Value addsVal = - rewriter.create(loc, constituentType, adds); + spirv::ConstantOp::create(rewriter, loc, constituentType, adds); Value carrysVal = - rewriter.create(loc, constituentType, carrys); + spirv::ConstantOp::create(rewriter, loc, constituentType, carrys); // Create empty struct - Value undef = rewriter.create(loc, op.getType()); + Value undef = spirv::UndefOp::create(rewriter, loc, op.getType()); // Fill in adds at id 0 Value intermediate = - rewriter.create(loc, addsVal, undef, 0); + spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0); // Fill in carrys at id 1 rewriter.replaceOpWithNewOp(op, carrysVal, intermediate, 1); @@ -260,16 +260,16 @@ struct MulExtendedFold final : OpRewritePattern { return failure(); Value lowBitsVal = - rewriter.create(loc, constituentType, lowBits); + spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits); Value highBitsVal = - rewriter.create(loc, constituentType, highBits); + spirv::ConstantOp::create(rewriter, loc, constituentType, highBits); // Create empty struct - Value undef = rewriter.create(loc, op.getType()); + Value undef = spirv::UndefOp::create(rewriter, loc, op.getType()); // Fill in lowBits at id 0 Value intermediate = - rewriter.create(loc, lowBitsVal, undef, 0); + spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0); // Fill in highBits at id 1 rewriter.replaceOpWithNewOp(op, highBitsVal, intermediate, 1); @@ -1309,11 +1309,11 @@ struct ConvertSelectionOpToSelect final : OpRewritePattern { auto storeOpAttributes = cast(trueBlock->front())->getAttrs(); - auto selectOp = rewriter.create( - selectionOp.getLoc(), trueValue.getType(), + auto selectOp = spirv::SelectOp::create( + rewriter, selectionOp.getLoc(), trueValue.getType(), brConditionalOp.getCondition(), trueValue, falseValue); - rewriter.create(selectOp.getLoc(), ptrValue, - selectOp.getResult(), storeOpAttributes); + spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue, + selectOp.getResult(), storeOpAttributes); // `spirv.mlir.selection` is not needed anymore. rewriter.eraseOp(op); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index f32c53b8f0b9e..c9a8e97bd3296 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -940,12 +940,12 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); if (!spirv::ConstantOp::isBuildableWith(type)) return nullptr; - return builder.create(loc, type, value); + return spirv::ConstantOp::create(builder, loc, type, value); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 656236246b1ad..52c672a05fa43 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -651,26 +651,26 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) - return builder.create(loc, type, - builder.getBoolAttr(false)); - return builder.create( - loc, type, builder.getIntegerAttr(type, APInt(width, 0))); + return spirv::ConstantOp::create(builder, loc, type, + builder.getBoolAttr(false)); + return spirv::ConstantOp::create( + builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0))); } if (auto floatType = llvm::dyn_cast(type)) { - return builder.create( - loc, type, builder.getFloatAttr(floatType, 0.0)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getFloatAttr(floatType, 0.0)); } if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); if (llvm::isa(elemType)) { - return builder.create( - loc, type, + return spirv::ConstantOp::create( + builder, loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 0).getValue())); } if (llvm::isa(elemType)) { - return builder.create( - loc, type, + return spirv::ConstantOp::create( + builder, loc, type, DenseFPElementsAttr::get(vectorType, FloatAttr::get(elemType, 0.0).getValue())); } @@ -684,26 +684,26 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) - return builder.create(loc, type, - builder.getBoolAttr(true)); - return builder.create( - loc, type, builder.getIntegerAttr(type, APInt(width, 1))); + return spirv::ConstantOp::create(builder, loc, type, + builder.getBoolAttr(true)); + return spirv::ConstantOp::create( + builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1))); } if (auto floatType = llvm::dyn_cast(type)) { - return builder.create( - loc, type, builder.getFloatAttr(floatType, 1.0)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getFloatAttr(floatType, 1.0)); } if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); if (llvm::isa(elemType)) { - return builder.create( - loc, type, + return spirv::ConstantOp::create( + builder, loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 1).getValue())); } if (llvm::isa(elemType)) { - return builder.create( - loc, type, + return spirv::ConstantOp::create( + builder, loc, type, DenseFPElementsAttr::get(vectorType, FloatAttr::get(elemType, 1.0).getValue())); } @@ -1985,7 +1985,7 @@ ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, OpBuilder builder(parser.getContext()); builder.setInsertionPointToEnd(&block); - builder.create(wrappedOp->getLoc(), wrappedOp->getResult(0)); + spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0)); result.location = wrappedOp->getLoc(); result.addTypes(wrappedOp->getResult(0).getType()); diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp index 8da688806bade..2b9c7296830dc 100644 --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -105,8 +105,9 @@ OwningOpRef combine(ArrayRef inputModules, } } - auto combinedModule = combinedModuleBuilder.create( - firstModule.getLoc(), addressingModel, memoryModel, vceTriple); + auto combinedModule = + spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(), + addressingModel, memoryModel, vceTriple); combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody()); // In some cases, a symbol in the (current state of the) combined module is diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 85525a5a02fa2..81365b44a3aad 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -70,9 +70,9 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, varType = spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); - return builder.create( - funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(), - abiInfo.getBinding()); + return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType, + varName, abiInfo.getDescriptorSet(), + abiInfo.getBinding()); } /// Gets the global variables that need to be specified as interface variable @@ -146,17 +146,17 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, return funcOp.emitRemark("lower entry point failure: could not select " "execution model based on 'spirv.target_env'"); - builder.create(funcOp.getLoc(), *executionModel, funcOp, - interfaceVars); + spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp, + interfaceVars); // Specifies the spirv.ExecutionModeOp. if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) { std::optional> caps = spirv::getCapabilities(spirv::ExecutionMode::LocalSize); if (!caps || targetEnv.allows(*caps)) { - builder.create(funcOp.getLoc(), funcOp, - spirv::ExecutionMode::LocalSize, - workgroupSizeAttr.asArrayRef()); + spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp, + spirv::ExecutionMode::LocalSize, + workgroupSizeAttr.asArrayRef()); // Erase workgroup size. entryPointAttr = spirv::EntryPointABIAttr::get( entryPointAttr.getContext(), DenseI32ArrayAttr(), @@ -167,9 +167,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, std::optional> caps = spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize); if (!caps || targetEnv.allows(*caps)) { - builder.create(funcOp.getLoc(), funcOp, - spirv::ExecutionMode::SubgroupSize, - *subgroupSize); + spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp, + spirv::ExecutionMode::SubgroupSize, + *subgroupSize); // Erase subgroup size. entryPointAttr = spirv::EntryPointABIAttr::get( entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), @@ -180,8 +180,8 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, std::optional> caps = spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve); if (!caps || targetEnv.allows(*caps)) { - builder.create( - funcOp.getLoc(), funcOp, + spirv::ExecutionModeOp::create( + builder, funcOp.getLoc(), funcOp, spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth); // Erase target width. entryPointAttr = spirv::EntryPointABIAttr::get( @@ -259,7 +259,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( // Insert spirv::AddressOf and spirv::AccessChain operations. Value replacement = - rewriter.create(funcOp.getLoc(), var); + spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var); // Check if the arg is a scalar or vector type. In that case, the value // needs to be loaded into registers. // TODO: This is loading value of the scalar into registers @@ -269,9 +269,9 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( if (cast(argType.value()).isScalarOrVector()) { auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); - auto loadPtr = rewriter.create( - funcOp.getLoc(), replacement, zero.getConstant()); - replacement = rewriter.create(funcOp.getLoc(), loadPtr); + auto loadPtr = spirv::AccessChainOp::create( + rewriter, funcOp.getLoc(), replacement, zero.getConstant()); + replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr); } signatureConverter.remapInput(argType.index(), replacement); } @@ -308,7 +308,7 @@ void LowerABIAttributesPass::runOnOperation() { ValueRange inputs, Location loc) { if (inputs.size() != 1 || !isa(inputs[0].getType())) return Value(); - return builder.create(loc, type, inputs[0]).getResult(); + return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult(); }); RewritePatternSet patterns(context); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp index ab5898d0e3925..38ef547f0769f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp @@ -65,8 +65,8 @@ void RewriteInsertsPass::runOnOperation() { operands.push_back(insertionOp.getObject()); OpBuilder builder(lastCompositeInsertOp); - auto compositeConstructOp = builder.create( - location, compositeType, operands); + auto compositeConstructOp = spirv::CompositeConstructOp::create( + builder, location, compositeType, operands); lastCompositeInsertOp.replaceAllUsesWith( compositeConstructOp->getResult(0)); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index f70b3325f8725..35ec0190b5a61 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -669,21 +669,24 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv, Location loc) { // We can only cast one value in SPIR-V. if (inputs.size() != 1) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = + UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } Value input = inputs.front(); // Only support integer types for now. Floating point types to be implemented. if (!isa(type)) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = + UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } auto inputType = cast(input.getType()); auto scalarType = dyn_cast(type); if (!scalarType) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = + UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } @@ -691,14 +694,15 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv, // truncating to go back so we don't need to worry about the signedness. // For extension, we cannot have enough signal here to decide which op to use. if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = + UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } // Boolean values would need to use different ops than normal integer values. if (type.isInteger(1)) { Value one = spirv::ConstantOp::getOne(inputType, loc, builder); - return builder.create(loc, input, one); + return spirv::IEqualOp::create(builder, loc, input, one); } // Check that the source integer type is supported by the environment. @@ -708,7 +712,8 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv, scalarType.getCapabilities(caps); if (failed(checkCapabilityRequirements(type, targetEnv, caps)) || failed(checkExtensionRequirements(type, targetEnv, exts))) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = + UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } @@ -716,9 +721,9 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv, // care about signedness here. Still try to use a corresponding op for better // consistency though. if (type.isSignedInteger()) { - return builder.create(loc, type, input); + return spirv::SConvertOp::create(builder, loc, type, input); } - return builder.create(loc, type, input); + return spirv::UConvertOp::create(builder, loc, type, input); } //===----------------------------------------------------------------------===// @@ -770,7 +775,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin, prefix, suffix); newVarOp = - builder.create(loc, ptrType, name, builtin); + spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin); break; } case spirv::BuiltIn::SubgroupId: @@ -781,7 +786,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, spirv::PointerType::get(integerType, spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin, prefix, suffix); newVarOp = - builder.create(loc, ptrType, name, builtin); + spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin); break; } default: @@ -842,8 +847,8 @@ getOrInsertPushConstantVariable(Location loc, Block &block, auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); auto type = getPushConstantStorageType(elementCount, builder, indexType); const char *name = "__push_constant_var__"; - return builder.create(loc, type, name, - /*initializer=*/nullptr); + return spirv::GlobalVariableOp::create(builder, loc, type, name, + /*initializer=*/nullptr); } //===----------------------------------------------------------------------===// @@ -879,8 +884,8 @@ struct FuncOpConversion final : OpConversionPattern { } // Create the converted spirv.func op. - auto newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), + auto newFuncOp = spirv::FuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), resultType ? TypeRange(resultType) : TypeRange())); @@ -919,8 +924,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern { } // Create a new func op with the original type and copy the function body. - auto newFuncOp = rewriter.create(funcOp.getLoc(), - funcOp.getName(), fnType); + auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(), + funcOp.getName(), fnType); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); @@ -954,8 +959,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern { auto origVecType = dyn_cast(origType); if (!origVecType) { // We need a placeholder for the old argument that will be erased later. - Value result = rewriter.create( - loc, origType, rewriter.getZeroAttr(origType)); + Value result = arith::ConstantOp::create( + rewriter, loc, origType, rewriter.getZeroAttr(origType)); rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result); tmpOps.insert({result.getDefiningOp(), newInputNo}); oneToNTypeMapping.addInputs(origInputNo, origType); @@ -967,8 +972,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern { auto targetShape = getTargetShape(origVecType); if (!targetShape) { // We need a placeholder for the old argument that will be erased later. - Value result = rewriter.create( - loc, origType, rewriter.getZeroAttr(origType)); + Value result = arith::ConstantOp::create( + rewriter, loc, origType, rewriter.getZeroAttr(origType)); rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result); tmpOps.insert({result.getDefiningOp(), newInputNo}); oneToNTypeMapping.addInputs(origInputNo, origType); @@ -982,12 +987,12 @@ struct FuncOpVectorUnroll final : OpRewritePattern { llvm::to_vector_of(origVecType.getShape()); // Prepare the result vector. - Value result = rewriter.create( - loc, origVecType, rewriter.getZeroAttr(origVecType)); + Value result = arith::ConstantOp::create( + rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType)); ++newOpCount; // Prepare the placeholder for the new arguments that will be added later. - Value dummy = rewriter.create( - loc, unrolledType, rewriter.getZeroAttr(unrolledType)); + Value dummy = arith::ConstantOp::create( + rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType)); ++newOpCount; // Create the `vector.insert_strided_slice` ops. @@ -995,8 +1000,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern { SmallVector newTypes; for (SmallVector offsets : StaticTileOffsetRange(originalShape, *targetShape)) { - result = rewriter.create( - loc, dummy, result, offsets, strides); + result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy, + result, offsets, strides); newTypes.push_back(unrolledType); unrolledInputNums.push_back(newInputNo); ++newInputNo; @@ -1109,12 +1114,12 @@ struct ReturnOpVectorUnroll final : OpRewritePattern { Value returnValue = returnOp.getOperand(origResultNo); for (SmallVector offsets : StaticTileOffsetRange(originalShape, *targetShape)) { - Value result = rewriter.create( - loc, returnValue, offsets, extractShape, strides); + Value result = vector::ExtractStridedSliceOp::create( + rewriter, loc, returnValue, offsets, extractShape, strides); if (originalShape.size() > 1) { SmallVector extractIndices(originalShape.size() - 1, 0); result = - rewriter.create(loc, result, extractIndices); + vector::ExtractOp::create(rewriter, loc, result, extractIndices); } newOperands.push_back(result); newTypes.push_back(unrolledType); @@ -1132,7 +1137,7 @@ struct ReturnOpVectorUnroll final : OpRewritePattern { // Replace the return op using the new operands. This will automatically // update the entry block as well. rewriter.replaceOp(returnOp, - rewriter.create(loc, newOperands)); + func::ReturnOp::create(rewriter, loc, newOperands)); return success(); } @@ -1157,8 +1162,8 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(), builtin, integerType, builder, prefix, suffix); - Value ptr = builder.create(op->getLoc(), varOp); - return builder.create(op->getLoc(), ptr); + Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp); + return spirv::LoadOp::create(builder, op->getLoc(), ptr); } //===----------------------------------------------------------------------===// @@ -1179,12 +1184,12 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, loc, parent->getRegion(0).front(), elementCount, builder, integerType); Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); - Value offsetOp = builder.create( - loc, integerType, builder.getI32IntegerAttr(offset)); - auto addrOp = builder.create(loc, varOp); - auto acOp = builder.create( - loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp})); - return builder.create(loc, acOp); + Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType, + builder.getI32IntegerAttr(offset)); + auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp); + auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp, + llvm::ArrayRef({zeroOp, offsetOp})); + return spirv::LoadOp::create(builder, loc, acOp); } //===----------------------------------------------------------------------===// @@ -1244,7 +1249,7 @@ Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, linearizedIndices.push_back( linearizeIndex(indices, strides, offset, indexType, loc, builder)); } - return builder.create(loc, basePtr, linearizedIndices); + return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices); } Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, @@ -1275,11 +1280,11 @@ Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, cast(basePtr.getType()).getPointeeType(); if (isa(pointeeType)) { linearizedIndices.push_back(linearIndex); - return builder.create(loc, basePtr, - linearizedIndices); + return spirv::AccessChainOp::create(builder, loc, basePtr, + linearizedIndices); } - return builder.create(loc, basePtr, linearIndex, - linearizedIndices); + return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex, + linearizedIndices); } Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter, @@ -1465,7 +1470,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, }); addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { - auto cast = builder.create(loc, type, inputs); + auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs); return cast.getResult(0); }); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp index af1cf2a1373e3..e0900005ea1bb 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -64,16 +64,16 @@ static Value lowerExtendedMultiplication(Operation *mulOp, // and 4 additions after constant folding. // - With sign-extended arguments, we end up emitting 8 multiplications and // and 12 additions after CSE. - Value cstLowMask = rewriter.create( - loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1)); + Value cstLowMask = ConstantOp::create( + rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1)); auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) { - return rewriter.create(loc, val, cstLowMask); + return BitwiseAndOp::create(rewriter, loc, val, cstLowMask); }; - Value cst16 = rewriter.create(loc, lhs.getType(), - getScalarOrSplatAttr(argTy, 16)); + Value cst16 = ConstantOp::create(rewriter, loc, lhs.getType(), + getScalarOrSplatAttr(argTy, 16)); auto getHighDigit = [&rewriter, loc, cst16](Value val) { - return rewriter.create(loc, val, cst16); + return ShiftRightLogicalOp::create(rewriter, loc, val, cst16); }; auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) { @@ -82,11 +82,11 @@ static Value lowerExtendedMultiplication(Operation *mulOp, // fine. We do not have to introduce an extra constant since any // value in [15, 32) would do. return getHighDigit( - rewriter.create(loc, val, cst16)); + ShiftRightArithmeticOp::create(rewriter, loc, val, cst16)); }; - Value cst0 = rewriter.create(loc, lhs.getType(), - getScalarOrSplatAttr(argTy, 0)); + Value cst0 = ConstantOp::create(rewriter, loc, lhs.getType(), + getScalarOrSplatAttr(argTy, 0)); Value lhsLow = getLowDigit(lhs); Value lhsHigh = getHighDigit(lhs); @@ -108,7 +108,7 @@ static Value lowerExtendedMultiplication(Operation *mulOp, continue; Value &thisResDigit = resultDigits[i + j]; - Value mul = rewriter.create(loc, lhsDigit, rhsDigit); + Value mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit); Value current = rewriter.createOrFold(loc, thisResDigit, mul); thisResDigit = getLowDigit(current); @@ -122,14 +122,15 @@ static Value lowerExtendedMultiplication(Operation *mulOp, } auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) { - Value highBits = rewriter.create(loc, high, cst16); - return rewriter.create(loc, low, highBits); + Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16); + return BitwiseOrOp::create(rewriter, loc, low, highBits); }; Value low = combineDigits(resultDigits[0], resultDigits[1]); Value high = combineDigits(resultDigits[2], resultDigits[3]); - return rewriter.create( - loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high})); + return CompositeConstructOp::create(rewriter, loc, + mulOp->getResultTypes().front(), + llvm::ArrayRef({low, high})); } //===----------------------------------------------------------------------===// @@ -184,18 +185,19 @@ struct ExpandAddCarryPattern final : OpRewritePattern { loc, llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy)); - Value one = - rewriter.create(loc, argTy, getScalarOrSplatAttr(argTy, 1)); - Value zero = - rewriter.create(loc, argTy, getScalarOrSplatAttr(argTy, 0)); + Value one = ConstantOp::create(rewriter, loc, argTy, + getScalarOrSplatAttr(argTy, 1)); + Value zero = ConstantOp::create(rewriter, loc, argTy, + getScalarOrSplatAttr(argTy, 0)); // Calculate the carry by checking if the addition resulted in an overflow. - Value out = rewriter.create(loc, lhs, rhs); - Value cmp = rewriter.create(loc, out, lhs); - Value carry = rewriter.create(loc, cmp, one, zero); + Value out = IAddOp::create(rewriter, loc, lhs, rhs); + Value cmp = ULessThanOp::create(rewriter, loc, out, lhs); + Value carry = SelectOp::create(rewriter, loc, cmp, one, zero); - Value add = rewriter.create( - loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry})); + Value add = CompositeConstructOp::create(rewriter, loc, + op->getResultTypes().front(), + llvm::ArrayRef({out, carry})); rewriter.replaceOp(op, add); return success(); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp index 527d92634c196..692f2e7616e5a 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -380,13 +380,13 @@ struct ConvertAccessChain : public ConvertAliasResource { Type indexType = oldIndex.getType(); int ratio = dstNumBytes / srcNumBytes; - auto ratioValue = rewriter.create( - loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); + auto ratioValue = spirv::ConstantOp::create( + rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); indices.back() = - rewriter.create(loc, indexType, oldIndex, ratioValue); - indices.push_back( - rewriter.create(loc, indexType, oldIndex, ratioValue)); + spirv::SDivOp::create(rewriter, loc, indexType, oldIndex, ratioValue); + indices.push_back(spirv::SModOp::create(rewriter, loc, indexType, + oldIndex, ratioValue)); rewriter.replaceOpWithNewOp( acOp, adaptor.getBasePtr(), indices); @@ -407,11 +407,11 @@ struct ConvertAccessChain : public ConvertAliasResource { Type indexType = oldIndex.getType(); int ratio = srcNumBytes / dstNumBytes; - auto ratioValue = rewriter.create( - loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); + auto ratioValue = spirv::ConstantOp::create( + rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); indices.back() = - rewriter.create(loc, indexType, oldIndex, ratioValue); + spirv::IMulOp::create(rewriter, loc, indexType, oldIndex, ratioValue); rewriter.replaceOpWithNewOp( acOp, adaptor.getBasePtr(), indices); @@ -435,15 +435,15 @@ struct ConvertLoad : public ConvertAliasResource { auto dstElemType = cast(dstPtrType.getPointeeType()); Location loc = loadOp.getLoc(); - auto newLoadOp = rewriter.create(loc, adaptor.getPtr()); + auto newLoadOp = spirv::LoadOp::create(rewriter, loc, adaptor.getPtr()); if (srcElemType == dstElemType) { rewriter.replaceOp(loadOp, newLoadOp->getResults()); return success(); } if (areSameBitwidthScalarType(srcElemType, dstElemType)) { - auto castOp = rewriter.create(loc, srcElemType, - newLoadOp.getValue()); + auto castOp = spirv::BitcastOp::create(rewriter, loc, srcElemType, + newLoadOp.getValue()); rewriter.replaceOp(loadOp, castOp->getResults()); return success(); @@ -475,14 +475,14 @@ struct ConvertLoad : public ConvertAliasResource { auto indices = llvm::to_vector<4>(acOp.getIndices()); for (int i = 1; i < ratio; ++i) { // Load all subsequent components belonging to this element. - indices.back() = rewriter.create( - loc, i32Type, indices.back(), oneValue); - auto componentAcOp = rewriter.create( - loc, acOp.getBasePtr(), indices); + indices.back() = spirv::IAddOp::create(rewriter, loc, i32Type, + indices.back(), oneValue); + auto componentAcOp = spirv::AccessChainOp::create( + rewriter, loc, acOp.getBasePtr(), indices); // Assuming little endian, this reads lower-ordered bits of the number // to lower-numbered components of the vector. components.push_back( - rewriter.create(loc, componentAcOp)); + spirv::LoadOp::create(rewriter, loc, componentAcOp)); } // Create a vector of the components and then cast back to the larger @@ -510,15 +510,15 @@ struct ConvertLoad : public ConvertAliasResource { castType = VectorType::get({count}, castType); for (Value &c : components) - c = rewriter.create(loc, castType, c); + c = spirv::BitcastOp::create(rewriter, loc, castType, c); } } - Value vectorValue = rewriter.create( - loc, vectorType, components); + Value vectorValue = spirv::CompositeConstructOp::create( + rewriter, loc, vectorType, components); if (!isa(srcElemType)) vectorValue = - rewriter.create(loc, srcElemType, vectorValue); + spirv::BitcastOp::create(rewriter, loc, srcElemType, vectorValue); rewriter.replaceOp(loadOp, vectorValue); return success(); } @@ -546,7 +546,7 @@ struct ConvertStore : public ConvertAliasResource { Location loc = storeOp.getLoc(); Value value = adaptor.getValue(); if (srcElemType != dstElemType) - value = rewriter.create(loc, dstElemType, value); + value = spirv::BitcastOp::create(rewriter, loc, dstElemType, value); rewriter.replaceOpWithNewOp(storeOp, adaptor.getPtr(), value, storeOp->getAttrs()); return success();