Skip to content

Commit c3823af

Browse files
authored
[mlir][NFC] update mlir/Dialect create APIs (22/n) (#149929)
See #147168 for more info.
1 parent dce6679 commit c3823af

File tree

10 files changed

+161
-153
lines changed

10 files changed

+161
-153
lines changed

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
391391
builder.createBlock(&getBody());
392392

393393
// Add a spirv.mlir.merge op into the merge block.
394-
builder.create<spirv::MergeOp>(getLoc());
394+
spirv::MergeOp::create(builder, getLoc());
395395
}
396396

397397
//===----------------------------------------------------------------------===//
@@ -543,15 +543,15 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) {
543543
builder.createBlock(&getBody());
544544

545545
// Add a spirv.mlir.merge op into the merge block.
546-
builder.create<spirv::MergeOp>(getLoc());
546+
spirv::MergeOp::create(builder, getLoc());
547547
}
548548

549549
SelectionOp
550550
SelectionOp::createIfThen(Location loc, Value condition,
551551
function_ref<void(OpBuilder &builder)> thenBody,
552552
OpBuilder &builder) {
553553
auto selectionOp =
554-
builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
554+
spirv::SelectionOp::create(builder, loc, spirv::SelectionControl::None);
555555

556556
selectionOp.addMergeBlock(builder);
557557
Block *mergeBlock = selectionOp.getMergeBlock();
@@ -562,17 +562,17 @@ SelectionOp::createIfThen(Location loc, Value condition,
562562
OpBuilder::InsertionGuard guard(builder);
563563
thenBlock = builder.createBlock(mergeBlock);
564564
thenBody(builder);
565-
builder.create<spirv::BranchOp>(loc, mergeBlock);
565+
spirv::BranchOp::create(builder, loc, mergeBlock);
566566
}
567567

568568
// Build the header block.
569569
{
570570
OpBuilder::InsertionGuard guard(builder);
571571
builder.createBlock(thenBlock);
572-
builder.create<spirv::BranchConditionalOp>(
573-
loc, condition, thenBlock,
574-
/*trueArguments=*/ArrayRef<Value>(), mergeBlock,
575-
/*falseArguments=*/ArrayRef<Value>());
572+
spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock,
573+
/*trueArguments=*/ArrayRef<Value>(),
574+
mergeBlock,
575+
/*falseArguments=*/ArrayRef<Value>());
576576
}
577577

578578
return selectionOp;

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,16 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
178178
return failure();
179179

180180
Value addsVal =
181-
rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
181+
spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
182182

183183
Value carrysVal =
184-
rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
184+
spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
185185

186186
// Create empty struct
187-
Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
187+
Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
188188
// Fill in adds at id 0
189189
Value intermediate =
190-
rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
190+
spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
191191
// Fill in carrys at id 1
192192
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
193193
intermediate, 1);
@@ -260,16 +260,16 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
260260
return failure();
261261

262262
Value lowBitsVal =
263-
rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
263+
spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
264264

265265
Value highBitsVal =
266-
rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
266+
spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
267267

268268
// Create empty struct
269-
Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
269+
Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
270270
// Fill in lowBits at id 0
271271
Value intermediate =
272-
rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
272+
spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
273273
// Fill in highBits at id 1
274274
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
275275
intermediate, 1);
@@ -1309,11 +1309,11 @@ struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
13091309
auto storeOpAttributes =
13101310
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
13111311

1312-
auto selectOp = rewriter.create<spirv::SelectOp>(
1313-
selectionOp.getLoc(), trueValue.getType(),
1312+
auto selectOp = spirv::SelectOp::create(
1313+
rewriter, selectionOp.getLoc(), trueValue.getType(),
13141314
brConditionalOp.getCondition(), trueValue, falseValue);
1315-
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
1316-
selectOp.getResult(), storeOpAttributes);
1315+
spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1316+
selectOp.getResult(), storeOpAttributes);
13171317

13181318
// `spirv.mlir.selection` is not needed anymore.
13191319
rewriter.eraseOp(op);

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -940,12 +940,12 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
940940
Attribute value, Type type,
941941
Location loc) {
942942
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
943-
return builder.create<ub::PoisonOp>(loc, type, poison);
943+
return ub::PoisonOp::create(builder, loc, type, poison);
944944

945945
if (!spirv::ConstantOp::isBuildableWith(type))
946946
return nullptr;
947947

948-
return builder.create<spirv::ConstantOp>(loc, type, value);
948+
return spirv::ConstantOp::create(builder, loc, type, value);
949949
}
950950

951951
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -651,26 +651,26 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
651651
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
652652
unsigned width = intType.getWidth();
653653
if (width == 1)
654-
return builder.create<spirv::ConstantOp>(loc, type,
655-
builder.getBoolAttr(false));
656-
return builder.create<spirv::ConstantOp>(
657-
loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
654+
return spirv::ConstantOp::create(builder, loc, type,
655+
builder.getBoolAttr(false));
656+
return spirv::ConstantOp::create(
657+
builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
658658
}
659659
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
660-
return builder.create<spirv::ConstantOp>(
661-
loc, type, builder.getFloatAttr(floatType, 0.0));
660+
return spirv::ConstantOp::create(builder, loc, type,
661+
builder.getFloatAttr(floatType, 0.0));
662662
}
663663
if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
664664
Type elemType = vectorType.getElementType();
665665
if (llvm::isa<IntegerType>(elemType)) {
666-
return builder.create<spirv::ConstantOp>(
667-
loc, type,
666+
return spirv::ConstantOp::create(
667+
builder, loc, type,
668668
DenseElementsAttr::get(vectorType,
669669
IntegerAttr::get(elemType, 0).getValue()));
670670
}
671671
if (llvm::isa<FloatType>(elemType)) {
672-
return builder.create<spirv::ConstantOp>(
673-
loc, type,
672+
return spirv::ConstantOp::create(
673+
builder, loc, type,
674674
DenseFPElementsAttr::get(vectorType,
675675
FloatAttr::get(elemType, 0.0).getValue()));
676676
}
@@ -684,26 +684,26 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
684684
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
685685
unsigned width = intType.getWidth();
686686
if (width == 1)
687-
return builder.create<spirv::ConstantOp>(loc, type,
688-
builder.getBoolAttr(true));
689-
return builder.create<spirv::ConstantOp>(
690-
loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
687+
return spirv::ConstantOp::create(builder, loc, type,
688+
builder.getBoolAttr(true));
689+
return spirv::ConstantOp::create(
690+
builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
691691
}
692692
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
693-
return builder.create<spirv::ConstantOp>(
694-
loc, type, builder.getFloatAttr(floatType, 1.0));
693+
return spirv::ConstantOp::create(builder, loc, type,
694+
builder.getFloatAttr(floatType, 1.0));
695695
}
696696
if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
697697
Type elemType = vectorType.getElementType();
698698
if (llvm::isa<IntegerType>(elemType)) {
699-
return builder.create<spirv::ConstantOp>(
700-
loc, type,
699+
return spirv::ConstantOp::create(
700+
builder, loc, type,
701701
DenseElementsAttr::get(vectorType,
702702
IntegerAttr::get(elemType, 1).getValue()));
703703
}
704704
if (llvm::isa<FloatType>(elemType)) {
705-
return builder.create<spirv::ConstantOp>(
706-
loc, type,
705+
return spirv::ConstantOp::create(
706+
builder, loc, type,
707707
DenseFPElementsAttr::get(vectorType,
708708
FloatAttr::get(elemType, 1.0).getValue()));
709709
}
@@ -1985,7 +1985,7 @@ ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
19851985

19861986
OpBuilder builder(parser.getContext());
19871987
builder.setInsertionPointToEnd(&block);
1988-
builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1988+
spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
19891989
result.location = wrappedOp->getLoc();
19901990

19911991
result.addTypes(wrappedOp->getResult(0).getType());

mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
105105
}
106106
}
107107

108-
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
109-
firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
108+
auto combinedModule =
109+
spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(),
110+
addressingModel, memoryModel, vceTriple);
110111
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
111112

112113
// In some cases, a symbol in the (current state of the) combined module is

mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
7070
varType =
7171
spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
7272

73-
return builder.create<spirv::GlobalVariableOp>(
74-
funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
75-
abiInfo.getBinding());
73+
return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
74+
varName, abiInfo.getDescriptorSet(),
75+
abiInfo.getBinding());
7676
}
7777

7878
/// Gets the global variables that need to be specified as interface variable
@@ -146,17 +146,17 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
146146
return funcOp.emitRemark("lower entry point failure: could not select "
147147
"execution model based on 'spirv.target_env'");
148148

149-
builder.create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
150-
interfaceVars);
149+
spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
150+
interfaceVars);
151151

152152
// Specifies the spirv.ExecutionModeOp.
153153
if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
154154
std::optional<ArrayRef<spirv::Capability>> caps =
155155
spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
156156
if (!caps || targetEnv.allows(*caps)) {
157-
builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
158-
spirv::ExecutionMode::LocalSize,
159-
workgroupSizeAttr.asArrayRef());
157+
spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
158+
spirv::ExecutionMode::LocalSize,
159+
workgroupSizeAttr.asArrayRef());
160160
// Erase workgroup size.
161161
entryPointAttr = spirv::EntryPointABIAttr::get(
162162
entryPointAttr.getContext(), DenseI32ArrayAttr(),
@@ -167,9 +167,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
167167
std::optional<ArrayRef<spirv::Capability>> caps =
168168
spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
169169
if (!caps || targetEnv.allows(*caps)) {
170-
builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
171-
spirv::ExecutionMode::SubgroupSize,
172-
*subgroupSize);
170+
spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
171+
spirv::ExecutionMode::SubgroupSize,
172+
*subgroupSize);
173173
// Erase subgroup size.
174174
entryPointAttr = spirv::EntryPointABIAttr::get(
175175
entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
@@ -180,8 +180,8 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
180180
std::optional<ArrayRef<spirv::Capability>> caps =
181181
spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
182182
if (!caps || targetEnv.allows(*caps)) {
183-
builder.create<spirv::ExecutionModeOp>(
184-
funcOp.getLoc(), funcOp,
183+
spirv::ExecutionModeOp::create(
184+
builder, funcOp.getLoc(), funcOp,
185185
spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
186186
// Erase target width.
187187
entryPointAttr = spirv::EntryPointABIAttr::get(
@@ -259,7 +259,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
259259

260260
// Insert spirv::AddressOf and spirv::AccessChain operations.
261261
Value replacement =
262-
rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
262+
spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
263263
// Check if the arg is a scalar or vector type. In that case, the value
264264
// needs to be loaded into registers.
265265
// TODO: This is loading value of the scalar into registers
@@ -269,9 +269,9 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
269269
if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
270270
auto zero =
271271
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
272-
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
273-
funcOp.getLoc(), replacement, zero.getConstant());
274-
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
272+
auto loadPtr = spirv::AccessChainOp::create(
273+
rewriter, funcOp.getLoc(), replacement, zero.getConstant());
274+
replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
275275
}
276276
signatureConverter.remapInput(argType.index(), replacement);
277277
}
@@ -308,7 +308,7 @@ void LowerABIAttributesPass::runOnOperation() {
308308
ValueRange inputs, Location loc) {
309309
if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
310310
return Value();
311-
return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
311+
return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
312312
});
313313

314314
RewritePatternSet patterns(context);

mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ void RewriteInsertsPass::runOnOperation() {
6565
operands.push_back(insertionOp.getObject());
6666

6767
OpBuilder builder(lastCompositeInsertOp);
68-
auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
69-
location, compositeType, operands);
68+
auto compositeConstructOp = spirv::CompositeConstructOp::create(
69+
builder, location, compositeType, operands);
7070

7171
lastCompositeInsertOp.replaceAllUsesWith(
7272
compositeConstructOp->getResult(0));

0 commit comments

Comments
 (0)