diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 0df91a243d07a..240491a51d2b9 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -340,7 +340,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.getStep(); - auto stepped = rewriter.create(loc, iv, step).getResult(); + auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult(); if (!stepped) return failure(); @@ -348,7 +348,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, loopCarried.push_back(stepped); loopCarried.append(terminator->operand_begin(), terminator->operand_end()); auto branchOp = - rewriter.create(loc, conditionBlock, loopCarried); + cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried); // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the // llvm.loop_annotation attribute. @@ -375,16 +375,15 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, SmallVector destOperands; destOperands.push_back(lowerBound); llvm::append_range(destOperands, forOp.getInitArgs()); - rewriter.create(loc, conditionBlock, destOperands); + cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands); // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = rewriter.create( - loc, arith::CmpIPredicate::slt, iv, upperBound); + auto comparison = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound); - rewriter.create(loc, comparison, firstBodyBlock, - ArrayRef(), endBlock, - ArrayRef()); + cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock, + ArrayRef(), endBlock, ArrayRef()); // The result of the loop operation is the values of the condition block // arguments except the induction variable on the last iteration. @@ -409,7 +408,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, continueBlock = rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), SmallVector(ifOp.getNumResults(), loc)); - rewriter.create(loc, remainingOpsBlock); + cf::BranchOp::create(rewriter, loc, remainingOpsBlock); } // Move blocks from the "then" region to the region containing 'scf.if', @@ -419,7 +418,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, Operation *thenTerminator = thenRegion.back().getTerminator(); ValueRange thenTerminatorOperands = thenTerminator->getOperands(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create(loc, continueBlock, thenTerminatorOperands); + cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands); rewriter.eraseOp(thenTerminator); rewriter.inlineRegionBefore(thenRegion, continueBlock); @@ -433,15 +432,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, Operation *elseTerminator = elseRegion.back().getTerminator(); ValueRange elseTerminatorOperands = elseTerminator->getOperands(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, continueBlock, elseTerminatorOperands); + cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands); rewriter.eraseOp(elseTerminator); rewriter.inlineRegionBefore(elseRegion, continueBlock); } rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, ifOp.getCondition(), thenBlock, - /*trueArgs=*/ArrayRef(), elseBlock, - /*falseArgs=*/ArrayRef()); + cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock, + /*trueArgs=*/ArrayRef(), elseBlock, + /*falseArgs=*/ArrayRef()); // Ok, we're done! rewriter.replaceOp(ifOp, continueBlock->getArguments()); @@ -459,13 +458,14 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, auto ®ion = op.getRegion(); rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, ®ion.front()); + cf::BranchOp::create(rewriter, loc, ®ion.front()); for (Block &block : region) { if (auto terminator = dyn_cast(block.getTerminator())) { ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(&block); - rewriter.create(loc, remainingOpsBlock, terminatorOperands); + cf::BranchOp::create(rewriter, loc, remainingOpsBlock, + terminatorOperands); rewriter.eraseOp(terminator); } } @@ -503,7 +503,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, for (auto [iv, lower, upper, step] : llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep())) { - ForOp forOp = rewriter.create(loc, lower, upper, step, iterArgs); + ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs); ivs.push_back(forOp.getInductionVar()); auto iterRange = forOp.getRegionIterArgs(); iterArgs.assign(iterRange.begin(), iterRange.end()); @@ -517,7 +517,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, // A loop is constructed with an empty "yield" terminator if there are // no results. rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.create(loc, forOp.getResults()); + scf::YieldOp::create(rewriter, loc, forOp.getResults()); } rewriter.setInsertionPointToStart(forOp.getBody()); @@ -549,7 +549,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, // has been already created in loop construction). if (!yieldOperands.empty()) { rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.create(loc, yieldOperands); + scf::YieldOp::create(rewriter, loc, yieldOperands); } rewriter.replaceOp(parallelOp, loopResults); @@ -575,7 +575,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, before, whileOp.getInits()); + cf::BranchOp::create(rewriter, loc, before, whileOp.getInits()); // Replace terminators with branches. Assuming bodies are SESE, which holds // given only the patterns from this file, we only need to look at the last @@ -625,14 +625,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp, // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(whileOp.getLoc(), before, whileOp.getInits()); + cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits()); // Loop around the "before" region based on condition. rewriter.setInsertionPointToEnd(before); auto condOp = cast(before->getTerminator()); - rewriter.create(condOp.getLoc(), condOp.getCondition(), - before, condOp.getArgs(), continuation, - ValueRange()); + cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(), + before, condOp.getArgs(), continuation, + ValueRange()); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. @@ -695,12 +695,12 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, SmallVector caseOperands(caseSuccessors.size(), {}); // Cast switch index to integer case value. - Value caseValue = rewriter.create( - op.getLoc(), rewriter.getI32Type(), op.getArg()); + Value caseValue = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg()); - rewriter.create( - op.getLoc(), caseValue, *defaultBlock, ValueRange(), - rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands); + cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock, + ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues), + caseSuccessors, caseOperands); rewriter.replaceOp(op, continueBlock->getArguments()); return success(); } diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index dcb48529a74e6..84cbd869c78ef 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -91,7 +91,7 @@ createVariablesForResults(T op, const TypeConverter *typeConverter, Type varType = emitc::LValueType::get(resultType); emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); emitc::VariableOp var = - rewriter.create(loc, varType, noInit); + emitc::VariableOp::create(rewriter, loc, varType, noInit); resultVariables.push_back(var); } @@ -103,14 +103,14 @@ createVariablesForResults(T op, const TypeConverter *typeConverter, static void assignValues(ValueRange values, ValueRange variables, ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) - rewriter.create(loc, var, value); + emitc::AssignOp::create(rewriter, loc, var, value); } SmallVector loadValues(const SmallVector &variables, PatternRewriter &rewriter, Location loc) { return llvm::map_to_vector<>(variables, [&](Value var) { Type type = cast(var.getType()).getValueType(); - return rewriter.create(loc, type, var).getResult(); + return emitc::LoadOp::create(rewriter, loc, type, var).getResult(); }); } @@ -129,7 +129,7 @@ static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, assignValues(yieldOperands, resultVariables, rewriter, loc); - rewriter.create(loc); + emitc::YieldOp::create(rewriter, loc); rewriter.eraseOp(yield); return success(); @@ -164,8 +164,9 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); - emitc::ForOp loweredFor = rewriter.create( - loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); + emitc::ForOp loweredFor = + emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(), + adaptor.getUpperBound(), adaptor.getStep()); Block *loweredBody = loweredFor.getBody(); @@ -257,7 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, bool hasElseBlock = !elseRegion.empty(); auto loweredIf = - rewriter.create(loc, adaptor.getCondition(), false, false); + emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); auto result = lowerRegion(thenRegion, loweredThenRegion); @@ -304,8 +305,9 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite( "create variables for results failed"); } - auto loweredSwitch = rewriter.create( - loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases()); + auto loweredSwitch = + emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(), + adaptor.getCases(), indexSwitchOp.getNumCases()); // Lowering all case regions. for (auto pair : diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 844e66e927c4d..f191f3502cf5a 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -84,8 +84,8 @@ static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) { // Get a Value that corresponds to the loop step. If the step is an attribute, // materialize a corresponding constant using builder. static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { - return builder.create(forOp.getLoc(), - forOp.getStepAsInt()); + return arith::ConstantIndexOp::create(builder, forOp.getLoc(), + forOp.getStepAsInt()); } // Get a Value for the loop lower bound. If the value requires computation, @@ -190,12 +190,12 @@ AffineLoopToGpuConverter::collectBounds(AffineForOp forOp, unsigned numLoops) { return std::nullopt; } - Value range = builder.create(currentLoop.getLoc(), - upperBound, lowerBound); + Value range = arith::SubIOp::create(builder, currentLoop.getLoc(), + upperBound, lowerBound); Value step = getOrCreateStep(currentLoop, builder); if (getConstantIntValue(step) != static_cast(1)) - range = - builder.create(currentLoop.getLoc(), range, step); + range = arith::CeilDivSIOp::create(builder, currentLoop.getLoc(), range, + step); dims.push_back(range); lbs.push_back(lowerBound); @@ -221,7 +221,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, // no loop mapped to a specific dimension, use constant "1" as its size. Value constOne = (numBlockDims < 3 || numThreadDims < 3) - ? builder.create(rootForOp.getLoc(), 1) + ? arith::ConstantIndexOp::create(builder, rootForOp.getLoc(), 1) : nullptr; Value gridSizeX = numBlockDims > 0 ? dims[0] : constOne; Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne; @@ -232,9 +232,9 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, // Create a launch op and move the body region of the innermost loop to the // launch op. - auto launchOp = builder.create( - rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX, - blockSizeY, blockSizeZ); + auto launchOp = + gpu::LaunchOp::create(builder, rootForOp.getLoc(), gridSizeX, gridSizeY, + gridSizeZ, blockSizeX, blockSizeY, blockSizeZ); // Replace the loop terminator (loops contain only a single block) with the // gpu terminator and move the operations from the loop body block to the gpu @@ -244,7 +244,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, Location terminatorLoc = terminator.getLoc(); terminator.erase(); builder.setInsertionPointToEnd(innermostForOp.getBody()); - builder.create(terminatorLoc, TypeRange()); + gpu::TerminatorOp::create(builder, terminatorLoc, TypeRange()); launchOp.getBody().front().getOperations().splice( launchOp.getBody().front().begin(), innermostForOp.getBody()->getOperations()); @@ -263,10 +263,10 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); Value step = steps[en.index()]; if (getConstantIntValue(step) != static_cast(1)) - id = builder.create(rootForOp.getLoc(), step, id); + id = arith::MulIOp::create(builder, rootForOp.getLoc(), step, id); Value ivReplacement = - builder.create(rootForOp.getLoc(), *lbArgumentIt, id); + arith::AddIOp::create(builder, rootForOp.getLoc(), *lbArgumentIt, id); en.value().replaceAllUsesWith(ivReplacement); std::advance(lbArgumentIt, 1); std::advance(stepArgumentIt, 1); @@ -319,8 +319,8 @@ static Value deriveStaticUpperBound(Value upperBound, if (auto minOp = upperBound.getDefiningOp()) { for (const AffineExpr &result : minOp.getMap().getResults()) { if (auto constExpr = dyn_cast(result)) { - return rewriter.create(minOp.getLoc(), - constExpr.getValue()); + return arith::ConstantIndexOp::create(rewriter, minOp.getLoc(), + constExpr.getValue()); } } } @@ -344,8 +344,8 @@ static Value deriveStaticUpperBound(Value upperBound, if ((lhs.value() < 0) != (rhs.value() < 0)) return {}; - return rewriter.create( - multiplyOp.getLoc(), lhs.value() * rhs.value()); + return arith::ConstantIndexOp::create(rewriter, multiplyOp.getLoc(), + lhs.value() * rhs.value()); } } @@ -422,8 +422,8 @@ static LogicalResult processParallelLoop( if (launchIndependent(val)) return val; if (auto constOp = val.getDefiningOp()) - return rewriter.create(constOp.getLoc(), - constOp.getValue()); + return arith::ConstantOp::create(rewriter, constOp.getLoc(), + constOp.getValue()); return {}; }; @@ -453,8 +453,8 @@ static LogicalResult processParallelLoop( 1, 2, rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1)); - newIndex = rewriter.create( - loc, annotation.getMap().compose(lowerAndStep), + newIndex = AffineApplyOp::create( + rewriter, loc, annotation.getMap().compose(lowerAndStep), ValueRange{operand, ensureLaunchIndependent(step), ensureLaunchIndependent(lowerBound)}); // If there was also a bound, insert that, too. @@ -498,8 +498,8 @@ static LogicalResult processParallelLoop( 1, 2, ((rewriter.getAffineDimExpr(0) - rewriter.getAffineSymbolExpr(0)) .ceilDiv(rewriter.getAffineSymbolExpr(1)))); - Value launchBound = rewriter.create( - loc, annotation.getBound().compose(stepMap), + Value launchBound = AffineApplyOp::create( + rewriter, loc, annotation.getBound().compose(stepMap), ValueRange{ ensureLaunchIndependent( cloningMap.lookupOrDefault(upperBound)), @@ -517,10 +517,10 @@ static LogicalResult processParallelLoop( if (!boundIsPrecise) { // We are using an approximation, create a surrounding conditional. Value originalBound = std::get<3>(config); - arith::CmpIOp pred = rewriter.create( - loc, arith::CmpIPredicate::slt, newIndex, + arith::CmpIOp pred = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, newIndex, cloningMap.lookupOrDefault(originalBound)); - scf::IfOp ifOp = rewriter.create(loc, pred, false); + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, pred, false); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); // Put a sentinel into the worklist so we know when to pop out of the // if body again. We use the launchOp here, as that cannot be part of @@ -530,10 +530,10 @@ static LogicalResult processParallelLoop( } } else { // Create a sequential for loop. - auto loopOp = rewriter.create( - loc, cloningMap.lookupOrDefault(lowerBound), - cloningMap.lookupOrDefault(upperBound), - cloningMap.lookupOrDefault(step)); + auto loopOp = scf::ForOp::create(rewriter, loc, + cloningMap.lookupOrDefault(lowerBound), + cloningMap.lookupOrDefault(upperBound), + cloningMap.lookupOrDefault(step)); newIndex = loopOp.getInductionVar(); rewriter.setInsertionPointToStart(loopOp.getBody()); // Put a sentinel into the worklist so we know when to pop out of the loop @@ -608,12 +608,12 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, // sizes. Those will be refined later as we discover them from mappings. Location loc = parallelOp.getLoc(); Value constantOne = - rewriter.create(parallelOp.getLoc(), 1); - gpu::LaunchOp launchOp = rewriter.create( - parallelOp.getLoc(), constantOne, constantOne, constantOne, constantOne, - constantOne, constantOne); + arith::ConstantIndexOp::create(rewriter, parallelOp.getLoc(), 1); + gpu::LaunchOp launchOp = gpu::LaunchOp::create( + rewriter, parallelOp.getLoc(), constantOne, constantOne, constantOne, + constantOne, constantOne, constantOne); rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); - rewriter.create(loc); + gpu::TerminatorOp::create(rewriter, loc); rewriter.setInsertionPointToStart(&launchOp.getBody().front()); IRMapping cloningMap; @@ -667,7 +667,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, if (externalValues.size()) return failure(); // Replace by gpu.all_reduce. - auto gpuRedOp = rewriter.create(loc, newValue); + auto gpuRedOp = gpu::AllReduceOp::create(rewriter, loc, newValue); cloningMap.map(parentLoop->getResult(0), gpuRedOp.getResult()); // Copy region. rewriter.inlineRegionBefore(reduceOp.getRegion(0), gpuRedOp.getRegion(), diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 584ac2f11b670..34f372af1e4b5 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -187,8 +187,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) { OpBuilder::InsertionGuard guard(builder); Type type = reduce.getOperands()[reductionIndex].getType(); - auto decl = builder.create(reduce.getLoc(), - "__scf_reduction", type); + auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(), + "__scf_reduction", type); symbolTable.insert(decl); builder.createBlock(&decl.getInitializerRegion(), @@ -196,8 +196,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, {reduce.getOperands()[reductionIndex].getLoc()}); builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); Value init = - builder.create(reduce.getLoc(), type, initValue); - builder.create(reduce.getLoc(), init); + LLVM::ConstantOp::create(builder, reduce.getLoc(), type, initValue); + omp::YieldOp::create(builder, reduce.getLoc(), init); Operation *terminator = &reduce.getReductions()[reductionIndex].front().back(); @@ -227,12 +227,12 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, {reduceOperandLoc, reduceOperandLoc}); Block *atomicBlock = &decl.getAtomicReductionRegion().back(); builder.setInsertionPointToEnd(atomicBlock); - Value loaded = builder.create(reduce.getLoc(), decl.getType(), - atomicBlock->getArgument(1)); - builder.create(reduce.getLoc(), atomicKind, - atomicBlock->getArgument(0), loaded, - LLVM::AtomicOrdering::monotonic); - builder.create(reduce.getLoc(), ArrayRef()); + Value loaded = LLVM::LoadOp::create(builder, reduce.getLoc(), decl.getType(), + atomicBlock->getArgument(1)); + LLVM::AtomicRMWOp::create(builder, reduce.getLoc(), atomicKind, + atomicBlock->getArgument(0), loaded, + LLVM::AtomicOrdering::monotonic); + omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef()); return decl; } @@ -380,8 +380,9 @@ struct ParallelOpLowering : public OpRewritePattern { // Allocate reduction variables. Make sure the we don't overflow the stack // with local `alloca`s by saving and restoring the stack pointer. Location loc = parallelOp.getLoc(); - Value one = rewriter.create( - loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); + Value one = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64), + rewriter.getI64IntegerAttr(1)); SmallVector reductionVariables; reductionVariables.reserve(parallelOp.getNumReductions()); auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext()); @@ -390,9 +391,9 @@ struct ParallelOpLowering : public OpRewritePattern { isa(init.getType())) && "cannot create a reduction variable if the type is not an LLVM " "pointer element"); - Value storage = - rewriter.create(loc, ptrType, init.getType(), one, 0); - rewriter.create(loc, init, storage); + Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType, + init.getType(), one, 0); + LLVM::StoreOp::create(rewriter, loc, init, storage); reductionVariables.push_back(storage); } @@ -411,8 +412,8 @@ struct ParallelOpLowering : public OpRewritePattern { assert(redRegion.hasOneBlock() && "expect reduction region to have one block"); Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc); - Value pvtRedVal = rewriter.create(reduce.getLoc(), - rD.getType(), pvtRedVar); + Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(), + rD.getType(), pvtRedVar); // Make a copy of the reduction combiner region in the body mlir::OpBuilder builder(rewriter.getContext()); builder.setInsertionPoint(reduce); @@ -427,7 +428,7 @@ struct ParallelOpLowering : public OpRewritePattern { assert(yieldOp && yieldOp.getResults().size() == 1 && "expect YieldOp in reduction region to return one result"); Value redVal = yieldOp.getResults()[0]; - rewriter.create(loc, redVal, pvtRedVar); + LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar); rewriter.eraseOp(yieldOp); break; } @@ -437,12 +438,12 @@ struct ParallelOpLowering : public OpRewritePattern { Value numThreadsVar; if (numThreads > 0) { - numThreadsVar = rewriter.create( - loc, rewriter.getI32IntegerAttr(numThreads)); + numThreadsVar = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(numThreads)); } // Create the parallel wrapper. - auto ompParallel = rewriter.create( - loc, + auto ompParallel = omp::ParallelOp::create( + rewriter, loc, /* allocate_vars = */ llvm::SmallVector{}, /* allocator_vars = */ llvm::SmallVector{}, /* if_expr = */ Value{}, @@ -464,7 +465,7 @@ struct ParallelOpLowering : public OpRewritePattern { { OpBuilder::InsertionGuard allocaGuard(rewriter); // Create worksharing loop wrapper. - auto wsloopOp = rewriter.create(parallelOp.getLoc()); + auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc()); if (!reductionVariables.empty()) { wsloopOp.setReductionSymsAttr( ArrayAttr::get(rewriter.getContext(), reductionSyms)); @@ -476,7 +477,7 @@ struct ParallelOpLowering : public OpRewritePattern { wsloopOp.setReductionByref( DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef)); } - rewriter.create(loc); // omp.parallel terminator. + omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator. // The wrapper's entry block arguments will define the reduction // variables. @@ -490,8 +491,8 @@ struct ParallelOpLowering : public OpRewritePattern { parallelOp.getLoc())); // Create loop nest and populate region with contents of scf.parallel. - auto loopOp = rewriter.create( - parallelOp.getLoc(), parallelOp.getLowerBound(), + auto loopOp = omp::LoopNestOp::create( + rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep()); rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(), @@ -511,13 +512,13 @@ struct ParallelOpLowering : public OpRewritePattern { rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin()); rewriter.setInsertionPointToStart(&loopOpEntryBlock); - auto scope = rewriter.create(parallelOp.getLoc(), - TypeRange()); - rewriter.create(loc, ValueRange()); + auto scope = memref::AllocaScopeOp::create( + rewriter, parallelOp.getLoc(), TypeRange()); + omp::YieldOp::create(rewriter, loc, ValueRange()); Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion()); rewriter.mergeBlocks(ops, scopeBlock); rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); - rewriter.create(loc, ValueRange()); + memref::AllocaScopeReturnOp::create(rewriter, loc, ValueRange()); } } @@ -526,7 +527,7 @@ struct ParallelOpLowering : public OpRewritePattern { results.reserve(reductionVariables.size()); for (auto [variable, type] : llvm::zip(reductionVariables, parallelOp.getResultTypes())) { - Value res = rewriter.create(loc, type, variable); + Value res = LLVM::LoadOp::create(rewriter, loc, type, variable); results.push_back(res); } rewriter.replaceOp(parallelOp, results); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 78d13278fef53..dc92367fc58cd 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -71,12 +71,12 @@ void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, auto pointerType = spirv::PointerType::get(convertedType, spirv::StorageClass::Function); rewriter.setInsertionPoint(newOp); - auto alloc = rewriter.create( - loc, pointerType, spirv::StorageClass::Function, - /*initializer=*/nullptr); + auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType, + spirv::StorageClass::Function, + /*initializer=*/nullptr); allocas.push_back(alloc); rewriter.setInsertionPointAfter(newOp); - Value loadResult = rewriter.create(loc, alloc); + Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc); resultValue.push_back(loadResult); } rewriter.replaceOp(scfOp, resultValue); @@ -135,7 +135,8 @@ struct ForOpConversion final : SCFToSPIRVPattern { // a single back edge from the continue to header block, and a single exit // from header to merge. auto loc = forOp.getLoc(); - auto loopOp = rewriter.create(loc, spirv::LoopControl::None); + auto loopOp = + spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(rewriter); OpBuilder::InsertionGuard guard(rewriter); @@ -172,16 +173,17 @@ struct ForOpConversion final : SCFToSPIRVPattern { args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); // Branch into it from the entry. rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); - rewriter.create(loc, header, args); + spirv::BranchOp::create(rewriter, loc, header, args); // Generate the rest of the loop header. rewriter.setInsertionPointToEnd(header); auto *mergeBlock = loopOp.getMergeBlock(); - auto cmpOp = rewriter.create( - loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound()); + auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), + newIndVar, adaptor.getUpperBound()); - rewriter.create( - loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); + spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body, + ArrayRef(), mergeBlock, + ArrayRef()); // Generate instructions to increment the step of the induction variable and // branch to the header. @@ -189,9 +191,9 @@ struct ForOpConversion final : SCFToSPIRVPattern { rewriter.setInsertionPointToEnd(continueBlock); // Add the step to the induction variable and branch to the header. - Value updatedIndVar = rewriter.create( - loc, newIndVar.getType(), newIndVar, adaptor.getStep()); - rewriter.create(loc, header, updatedIndVar); + Value updatedIndVar = spirv::IAddOp::create( + rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep()); + spirv::BranchOp::create(rewriter, loc, header, updatedIndVar); // Infer the return types from the init operands. Vector type may get // converted to CooperativeMatrix or to Vector type, to avoid having complex @@ -237,11 +239,11 @@ struct IfOpConversion : SCFToSPIRVPattern { // Create `spirv.selection` operation, selection header block and merge // block. - auto selectionOp = - rewriter.create(loc, spirv::SelectionControl::None); + auto selectionOp = spirv::SelectionOp::create( + rewriter, loc, spirv::SelectionControl::None); auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end()); - rewriter.create(loc); + spirv::MergeOp::create(rewriter, loc); OpBuilder::InsertionGuard guard(rewriter); auto *selectionHeaderBlock = @@ -251,7 +253,7 @@ struct IfOpConversion : SCFToSPIRVPattern { auto &thenRegion = ifOp.getThenRegion(); auto *thenBlock = &thenRegion.front(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create(loc, mergeBlock); + spirv::BranchOp::create(rewriter, loc, mergeBlock); rewriter.inlineRegionBefore(thenRegion, mergeBlock); auto *elseBlock = mergeBlock; @@ -261,15 +263,15 @@ struct IfOpConversion : SCFToSPIRVPattern { auto &elseRegion = ifOp.getElseRegion(); elseBlock = &elseRegion.front(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, mergeBlock); + spirv::BranchOp::create(rewriter, loc, mergeBlock); rewriter.inlineRegionBefore(elseRegion, mergeBlock); } // Create a `spirv.BranchConditional` operation for selection header block. rewriter.setInsertionPointToEnd(selectionHeaderBlock); - rewriter.create(loc, adaptor.getCondition(), - thenBlock, ArrayRef(), - elseBlock, ArrayRef()); + spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(), + thenBlock, ArrayRef(), elseBlock, + ArrayRef()); replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, returnTypes); @@ -310,7 +312,7 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern { auto loc = terminatorOp.getLoc(); for (unsigned i = 0, e = operands.size(); i < e; i++) - rewriter.create(loc, allocas[i], operands[i]); + spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]); if (isa(parent)) { // For loops we also need to update the branch jumping back to the // header. @@ -319,8 +321,8 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern { SmallVector args(br.getBlockArguments()); args.append(operands.begin(), operands.end()); rewriter.setInsertionPoint(br); - rewriter.create(terminatorOp.getLoc(), br.getTarget(), - args); + spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(), + args); rewriter.eraseOp(br); } } @@ -340,7 +342,8 @@ struct WhileOpConversion final : SCFToSPIRVPattern { matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = whileOp.getLoc(); - auto loopOp = rewriter.create(loc, spirv::LoopControl::None); + auto loopOp = + spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(rewriter); Region &beforeRegion = whileOp.getBefore(); @@ -382,7 +385,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern { // Jump from the loop entry block to the loop header block. rewriter.setInsertionPointToEnd(&entryBlock); - rewriter.create(loc, &beforeBlock, adaptor.getInits()); + spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits()); auto condLoc = cond.getLoc(); @@ -403,18 +406,18 @@ struct WhileOpConversion final : SCFToSPIRVPattern { // Create local variables before the scf.while op. rewriter.setInsertionPoint(loopOp); - auto alloc = rewriter.create( - condLoc, pointerType, spirv::StorageClass::Function, - /*initializer=*/nullptr); + auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType, + spirv::StorageClass::Function, + /*initializer=*/nullptr); // Load the final result values after the scf.while op. rewriter.setInsertionPointAfter(loopOp); - auto loadResult = rewriter.create(condLoc, alloc); + auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc); resultValues[i] = loadResult; // Store the current iteration's result value. rewriter.setInsertionPointToEnd(&beforeBlock); - rewriter.create(condLoc, alloc, res); + spirv::StoreOp::create(rewriter, condLoc, alloc, res); } rewriter.setInsertionPointToEnd(&beforeBlock); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index d7ae9f0e94fe8..035f197b1eac2 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -68,7 +68,7 @@ static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { /// Copies the given number of bytes from src to dst pointers. static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder) { - builder.create(loc, dst, src, size, /*isVolatile=*/false); + LLVM::MemcpyOp::create(builder, loc, dst, src, size, /*isVolatile=*/false); } /// Encodes the binding and descriptor set numbers into a new symbolic name. @@ -194,8 +194,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { if (!kernelFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - kernelFunc = rewriter.create( - rewriter.getUnknownLoc(), newKernelFuncName, + kernelFunc = LLVM::LLVMFuncOp::create( + rewriter, rewriter.getUnknownLoc(), newKernelFuncName, LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), ArrayRef())); rewriter.setInsertionPoint(launchOp); @@ -245,8 +245,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { if (!dstGlobal) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - dstGlobal = rewriter.create( - loc, dstGlobalType, + dstGlobal = LLVM::GlobalOp::create( + rewriter, loc, dstGlobalType, /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(), /*alignment=*/0); rewriter.setInsertionPoint(launchOp); @@ -255,8 +255,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { // Copy the data from src operand pointer to dst global variable. Save // src, dst and size so that we can copy data back after emulating the // kernel call. - Value dst = rewriter.create( - loc, typeConverter->convertType(spirvGlobal.getType()), + Value dst = LLVM::AddressOfOp::create( + rewriter, loc, typeConverter->convertType(spirvGlobal.getType()), dstGlobal.getSymName()); copy(loc, dst, src, sizeBytes, rewriter); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 1d92b5d5562b5..aae3271371c1f 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -94,13 +94,13 @@ static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { if (isa(srcType)) { - return rewriter.create( - loc, dstType, + return LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast(srcType), minusOneIntegerAttribute(srcType, rewriter))); } - return rewriter.create( - loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); + return LLVM::ConstantOp::create(rewriter, loc, dstType, + minusOneIntegerAttribute(srcType, rewriter)); } /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. @@ -108,14 +108,14 @@ static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value) { if (auto vecType = dyn_cast(srcType)) { auto floatType = cast(vecType.getElementType()); - return rewriter.create( - loc, dstType, + return LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(vecType, rewriter.getFloatAttr(floatType, value))); } auto floatType = cast(srcType); - return rewriter.create( - loc, dstType, rewriter.getFloatAttr(floatType, value)); + return LLVM::ConstantOp::create(rewriter, loc, dstType, + rewriter.getFloatAttr(floatType, value)); } /// Utility function for bitfield ops: @@ -134,13 +134,13 @@ static Value optionallyTruncateOrExtend(Location loc, Value value, : getBitWidth(srcType); if (valueBitWidth < targetBitWidth) - return rewriter.create(loc, llvmType, value); + return LLVM::ZExtOp::create(rewriter, loc, llvmType, value); // If the bit widths of `Count` and `Offset` are greater than the bit width // of the target type, they are truncated. Truncation is safe since `Count` // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, // both values can be expressed in 8 bits. if (valueBitWidth > targetBitWidth) - return rewriter.create(loc, llvmType, value); + return LLVM::TruncOp::create(rewriter, loc, llvmType, value); return value; } @@ -151,12 +151,12 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, auto vectorType = VectorType::get(numElements, toBroadcast.getType()); auto llvmVectorType = typeConverter.convertType(vectorType); auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); - Value broadcasted = rewriter.create(loc, llvmVectorType); + Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType); for (unsigned i = 0; i < numElements; ++i) { - auto index = rewriter.create( - loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); - broadcasted = rewriter.create( - loc, llvmVectorType, broadcasted, toBroadcast, index); + auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, + rewriter.getI32IntegerAttr(i)); + broadcasted = LLVM::InsertElementOp::create( + rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index); } return broadcasted; } @@ -217,8 +217,8 @@ static Type convertStructTypePacked(spirv::StructType type, /// Creates LLVM dialect constant with the given value. static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value) { - return rewriter.create( - loc, IntegerType::get(rewriter.getContext(), 32), + return LLVM::ConstantOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getI32Type(), value)); } @@ -322,8 +322,9 @@ class AccessChainPattern : public SPIRVToLLVMConversion { auto llvmIndexType = getTypeConverter()->convertType(indexType); if (!llvmIndexType) return rewriter.notifyMatchFailure(op, "type conversion failed"); - Value zero = rewriter.create( - op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); + Value zero = + LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType, + rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); auto elementType = getTypeConverter()->convertType( @@ -375,20 +376,20 @@ class BitFieldInsertPattern // Create a mask with bits set outside [Offset, Offset + Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = - rewriter.create(loc, dstType, minusOne, count); - Value negated = rewriter.create(loc, dstType, - maskShiftedByCount, minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count); + Value negated = LLVM::XOrOp::create(rewriter, loc, dstType, + maskShiftedByCount, minusOne); Value maskShiftedByCountAndOffset = - rewriter.create(loc, dstType, negated, offset); - Value mask = rewriter.create( - loc, dstType, maskShiftedByCountAndOffset, minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset); + Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, + maskShiftedByCountAndOffset, minusOne); // Extract unchanged bits from the `Base` that are outside of // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. Value baseAndMask = - rewriter.create(loc, dstType, op.getBase(), mask); + LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask); Value insertShiftedByOffset = - rewriter.create(loc, dstType, op.getInsert(), offset); + LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset); rewriter.replaceOpWithNewOp(op, dstType, baseAndMask, insertShiftedByOffset); return success(); @@ -470,23 +471,23 @@ class BitFieldSExtractPattern auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); Value size = isa(srcType) - ? rewriter.create( - loc, dstType, + ? LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast(srcType), baseSize)) - : rewriter.create(loc, dstType, baseSize); + : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize); // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit // at Offset + Count - 1 is the most significant bit now. Value countPlusOffset = - rewriter.create(loc, dstType, count, offset); + LLVM::AddOp::create(rewriter, loc, dstType, count, offset); Value amountToShiftLeft = - rewriter.create(loc, dstType, size, countPlusOffset); - Value baseShiftedLeft = rewriter.create( - loc, dstType, op.getBase(), amountToShiftLeft); + LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset); + Value baseShiftedLeft = LLVM::ShlOp::create( + rewriter, loc, dstType, op.getBase(), amountToShiftLeft); // Shift the result right, filling the bits with the sign bit. Value amountToShiftRight = - rewriter.create(loc, dstType, offset, amountToShiftLeft); + LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft); rewriter.replaceOpWithNewOp(op, dstType, baseShiftedLeft, amountToShiftRight); return success(); @@ -516,13 +517,13 @@ class BitFieldUExtractPattern // Create a mask with bits set at [0, Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = - rewriter.create(loc, dstType, minusOne, count); - Value mask = rewriter.create(loc, dstType, maskShiftedByCount, - minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count); + Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount, + minusOne); // Shift `Base` by `Offset` and apply the mask on it. Value shiftedBase = - rewriter.create(loc, dstType, op.getBase(), offset); + LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset); rewriter.replaceOpWithNewOp(op, dstType, shiftedBase, mask); return success(); } @@ -694,8 +695,8 @@ class ExecutionModePattern auto structType = LLVM::LLVMStructType::getLiteral(context, fields); // Create `llvm.mlir.global` with initializer region containing one block. - auto global = rewriter.create( - UnknownLoc::get(context), structType, /*isConstant=*/true, + auto global = LLVM::GlobalOp::create( + rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true, LLVM::Linkage::External, executionModeInfoName, Attribute(), /*alignment=*/0); Location loc = global.getLoc(); @@ -704,22 +705,23 @@ class ExecutionModePattern // Initialize the struct and set the execution mode value. rewriter.setInsertionPointToStart(block); - Value structValue = rewriter.create(loc, structType); - Value executionMode = rewriter.create( - loc, llvmI32Type, + Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType); + Value executionMode = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, rewriter.getI32IntegerAttr( static_cast(executionModeAttr.getValue()))); - structValue = rewriter.create(loc, structValue, - executionMode, 0); + SmallVector position{0}; + structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue, + executionMode, position); // Insert extra operands if they exist into execution mode info struct. for (unsigned i = 0, e = values.size(); i < e; ++i) { auto attr = values.getValue()[i]; - Value entry = rewriter.create(loc, llvmI32Type, attr); - structValue = rewriter.create( - loc, structValue, entry, ArrayRef({1, i})); + Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr); + structValue = LLVM::InsertValueOp::create( + rewriter, loc, structValue, entry, ArrayRef({1, i})); } - rewriter.create(loc, ArrayRef({structValue})); + LLVM::ReturnOp::create(rewriter, loc, ArrayRef({structValue})); rewriter.eraseOp(op); return success(); } @@ -913,7 +915,7 @@ class InverseSqrtPattern Location loc = op.getLoc(); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value sqrt = rewriter.create(loc, dstType, op.getOperand()); + Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand()); rewriter.replaceOpWithNewOp(op, dstType, one, sqrt); return success(); } @@ -973,10 +975,10 @@ class NotPattern : public SPIRVToLLVMConversion { IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); auto mask = isa(srcType) - ? rewriter.create( - loc, dstType, + ? LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast(srcType), minusOne)) - : rewriter.create(loc, dstType, minusOne); + : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne); rewriter.template replaceOpWithNewOp(notOp, dstType, notOp.getOperand(), mask); return success(); @@ -1034,8 +1036,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, return func; OpBuilder b(symbolTable->getRegion(0)); - func = b.create( - symbolTable->getLoc(), name, + func = LLVM::LLVMFuncOp::create( + b, symbolTable->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes)); func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); func.setConvergent(convergent); @@ -1047,7 +1049,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, LLVM::LLVMFuncOp func, ValueRange args) { - auto call = builder.create(loc, func, args); + auto call = LLVM::CallOp::create(builder, loc, func, args); call.setCConv(func.getCConv()); call.setConvergentAttr(func.getConvergentAttr()); call.setNoUnwindAttr(func.getNoUnwindAttr()); @@ -1078,12 +1080,12 @@ class ControlBarrierPattern : public SPIRVToLLVMConversion { lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy); Location loc = controlBarrierOp->getLoc(); - Value execution = rewriter.create( - loc, i32, static_cast(adaptor.getExecutionScope())); - Value memory = rewriter.create( - loc, i32, static_cast(adaptor.getMemoryScope())); - Value semantics = rewriter.create( - loc, i32, static_cast(adaptor.getMemorySemantics())); + Value execution = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast(adaptor.getExecutionScope())); + Value memory = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast(adaptor.getMemoryScope())); + Value semantics = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast(adaptor.getMemorySemantics())); auto call = createSPIRVBuiltinCall(loc, rewriter, func, {execution, memory, semantics}); @@ -1255,10 +1257,12 @@ class GroupReducePattern : public SPIRVToLLVMConversion { lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy); Location loc = op.getLoc(); - Value scope = rewriter.create( - loc, i32Ty, static_cast(adaptor.getExecutionScope())); - Value groupOp = rewriter.create( - loc, i32Ty, static_cast(adaptor.getGroupOperation())); + Value scope = LLVM::ConstantOp::create( + rewriter, loc, i32Ty, + static_cast(adaptor.getExecutionScope())); + Value groupOp = LLVM::ConstantOp::create( + rewriter, loc, i32Ty, + static_cast(adaptor.getGroupOperation())); SmallVector operands{scope, groupOp}; operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); @@ -1368,7 +1372,7 @@ class LoopPattern : public SPIRVToLLVMConversion { return failure(); Block *headerBlock = loopOp.getHeaderBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, brOp.getBlockArguments(), headerBlock); + LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock); rewriter.eraseBlock(entryBlock); // Branch from merge block to end block. @@ -1376,7 +1380,7 @@ class LoopPattern : public SPIRVToLLVMConversion { Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); - rewriter.create(loc, terminatorOperands, endBlock); + LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock); rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); rewriter.replaceOp(loopOp, endBlock->getArguments()); @@ -1434,16 +1438,15 @@ class SelectionPattern : public SPIRVToLLVMConversion { Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); - rewriter.create(loc, terminatorOperands, continueBlock); + LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock); // Link current block to `true` and `false` blocks within the selection. Block *trueBlock = condBrOp.getTrueBlock(); Block *falseBlock = condBrOp.getFalseBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, condBrOp.getCondition(), trueBlock, - condBrOp.getTrueTargetOperands(), - falseBlock, - condBrOp.getFalseTargetOperands()); + LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock, + condBrOp.getTrueTargetOperands(), falseBlock, + condBrOp.getFalseTargetOperands()); rewriter.eraseBlock(headerBlock); rewriter.inlineRegionBefore(op.getBody(), continueBlock); @@ -1521,8 +1524,8 @@ class TanPattern : public SPIRVToLLVMConversion { return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); Location loc = tanOp.getLoc(); - Value sin = rewriter.create(loc, dstType, tanOp.getOperand()); - Value cos = rewriter.create(loc, dstType, tanOp.getOperand()); + Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand()); + Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand()); rewriter.replaceOpWithNewOp(tanOp, dstType, sin, cos); return success(); } @@ -1549,13 +1552,13 @@ class TanhPattern : public SPIRVToLLVMConversion { Location loc = tanhOp.getLoc(); Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); Value multiplied = - rewriter.create(loc, dstType, two, tanhOp.getOperand()); - Value exponential = rewriter.create(loc, dstType, multiplied); + LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand()); + Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); Value numerator = - rewriter.create(loc, dstType, exponential, one); + LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one); Value denominator = - rewriter.create(loc, dstType, exponential, one); + LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one); rewriter.replaceOpWithNewOp(tanhOp, dstType, numerator, denominator); return success(); @@ -1594,8 +1597,8 @@ class VariablePattern : public SPIRVToLLVMConversion { if (!elementType) return rewriter.notifyMatchFailure(varOp, "type conversion failed"); Value allocated = - rewriter.create(loc, dstType, elementType, size); - rewriter.create(loc, adaptor.getInitializer(), allocated); + LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size); + LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated); rewriter.replaceOp(varOp, allocated); return success(); } @@ -1656,7 +1659,7 @@ class FuncConversionPattern : public SPIRVToLLVMConversion { // Create a new `LLVMFuncOp` Location loc = funcOp.getLoc(); StringRef name = funcOp.getName(); - auto newFuncOp = rewriter.create(loc, name, llvmType); + auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType); // Convert SPIR-V Function Control to equivalent LLVM function attribute MLIRContext *context = funcOp.getContext(); @@ -1710,7 +1713,7 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion { ConversionPatternRewriter &rewriter) const override { auto newModuleOp = - rewriter.create(spvModuleOp.getLoc(), spvModuleOp.getName()); + ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName()); rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); // Remove the terminator block that was automatically added by builder @@ -1751,7 +1754,7 @@ class VectorShufflePattern auto componentsArray = components.getValue(); auto *context = rewriter.getContext(); auto llvmI32Type = IntegerType::get(context, 32); - Value targetOp = rewriter.create(loc, dstType); + Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType); for (unsigned i = 0; i < componentsArray.size(); i++) { if (!isa(componentsArray[i])) return op.emitError("unable to support non-constant component"); @@ -1767,16 +1770,17 @@ class VectorShufflePattern baseVector = vector2; } - Value dstIndex = rewriter.create( - loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); - Value index = rewriter.create( - loc, llvmI32Type, + Value dstIndex = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, + rewriter.getIntegerAttr(rewriter.getI32Type(), i)); + Value index = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); - auto extractOp = rewriter.create( - loc, scalarType, baseVector, index); - targetOp = rewriter.create(loc, dstType, targetOp, - extractOp, dstIndex); + auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType, + baseVector, index); + targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp, + extractOp, dstIndex); } rewriter.replaceOp(op, targetOp); return success(); diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp index da9ad3dd67328..245e60b04ec31 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -32,7 +32,7 @@ class ConvertCstrRequireOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::CstrRequireOp op, PatternRewriter &rewriter) const override { - rewriter.create(op.getLoc(), op.getPred(), op.getMsgAttr()); + cf::AssertOp::create(rewriter, op.getLoc(), op.getPred(), op.getMsgAttr()); rewriter.replaceOpWithNewOp(op, true); return success(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index bbe1490137bf8..7025c5a7daf93 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -82,40 +82,40 @@ struct BroadcastOpConverter : public OpConversionPattern { // number of extent tensors and shifted offsets into them. Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, ValueRange rankDiffs, Value outputDimension) { - Value one = lb.create(1); + Value one = arith::ConstantIndexOp::create(lb, 1); Value broadcastedDim = one; for (auto tup : llvm::zip(extentTensors, rankDiffs)) { Value shape = std::get<0>(tup); Value rankDiff = std::get<1>(tup); - Value outOfBounds = lb.create(arith::CmpIPredicate::ult, - outputDimension, rankDiff); + Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult, + outputDimension, rankDiff); Type indexTy = lb.getIndexType(); broadcastedDim = - lb.create( - outOfBounds, - [&](OpBuilder &b, Location loc) { - b.create(loc, broadcastedDim); - }, - [&](OpBuilder &b, Location loc) { - // The broadcasting logic is: - // - if one extent (here we arbitrarily choose the - // extent from the greater-rank operand) is equal to 1, - // then take the extent from the other operand - // - otherwise, take the extent as-is. - // Note that this logic remains correct in the presence - // of dimensions of zero extent. - Value lesserRankOperandDimension = b.create( - loc, indexTy, outputDimension, rankDiff); - Value lesserRankOperandExtent = b.create( - loc, shape, ValueRange{lesserRankOperandDimension}); - - Value dimIsOne = - b.create(loc, arith::CmpIPredicate::eq, - lesserRankOperandExtent, one); - Value dim = b.create( - loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); - b.create(loc, dim); - }) + IfOp::create( + lb, outOfBounds, + [&](OpBuilder &b, Location loc) { + scf::YieldOp::create(b, loc, broadcastedDim); + }, + [&](OpBuilder &b, Location loc) { + // The broadcasting logic is: + // - if one extent (here we arbitrarily choose the + // extent from the greater-rank operand) is equal to 1, + // then take the extent from the other operand + // - otherwise, take the extent as-is. + // Note that this logic remains correct in the presence + // of dimensions of zero extent. + Value lesserRankOperandDimension = arith::SubIOp::create( + b, loc, indexTy, outputDimension, rankDiff); + Value lesserRankOperandExtent = tensor::ExtractOp::create( + b, loc, shape, ValueRange{lesserRankOperandDimension}); + + Value dimIsOne = + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + lesserRankOperandExtent, one); + Value dim = arith::SelectOp::create( + b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); + scf::YieldOp::create(b, loc, dim); + }) .getResult(0); } return broadcastedDim; @@ -133,7 +133,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); + Value zero = arith::ConstantIndexOp::create(lb, 0); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -141,31 +141,31 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( // dimension in the tensor. SmallVector ranks, rankDiffs; llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { - return lb.create(v, zero); + return tensor::DimOp::create(lb, v, zero); })); // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - maxRank = lb.create(v, maxRank); + maxRank = arith::MaxUIOp::create(lb, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return arith::SubIOp::create(lb, indexTy, maxRank, v); })); - Value replacement = lb.create( - getExtentTensorType(lb.getContext()), ValueRange{maxRank}, + Value replacement = tensor::GenerateOp::create( + lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value broadcastedDim = getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, args[0]); - b.create(loc, broadcastedDim); + tensor::YieldOp::create(b, loc, broadcastedDim); }); if (replacement.getType() != op.getType()) - replacement = lb.create(op.getType(), replacement); + replacement = tensor::CastOp::create(lb, op.getType(), replacement); rewriter.replaceOp(op, replacement); return success(); } @@ -193,13 +193,13 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite( auto loc = op.getLoc(); SmallVector extentOperands; for (auto extent : op.getShape()) { - extentOperands.push_back( - rewriter.create(loc, extent.getLimitedValue())); + extentOperands.push_back(arith::ConstantIndexOp::create( + rewriter, loc, extent.getLimitedValue())); } Type resultTy = RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType()); Value tensor = - rewriter.create(loc, resultTy, extentOperands); + tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands); rewriter.replaceOpWithNewOp(op, resultTy, tensor); return success(); } @@ -245,8 +245,8 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); - Value one = lb.create(1); + Value zero = arith::ConstantIndexOp::create(lb, 0); + Value one = arith::ConstantIndexOp::create(lb, 1); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -254,26 +254,26 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( // dimension in the tensor. SmallVector ranks, rankDiffs; llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { - return lb.create(v, zero); + return tensor::DimOp::create(lb, v, zero); })); // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - maxRank = lb.create(v, maxRank); + maxRank = arith::MaxUIOp::create(lb, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return arith::SubIOp::create(lb, indexTy, maxRank, v); })); Type i1Ty = rewriter.getI1Type(); - Value trueVal = - rewriter.create(loc, i1Ty, rewriter.getBoolAttr(true)); + Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty, + rewriter.getBoolAttr(true)); - auto reduceResult = lb.create( - loc, zero, maxRank, one, ValueRange{trueVal}, + auto reduceResult = ForOp::create( + lb, loc, zero, maxRank, one, ValueRange{trueVal}, [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { // Find a non-1 dim, if it exists. Note that the first part of this // could reuse the Broadcast lowering entirely, but we redo the work @@ -285,38 +285,38 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; - Value outOfBounds = b.create( - loc, arith::CmpIPredicate::ult, iv, rankDiff); + Value outOfBounds = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::ult, iv, rankDiff); broadcastable = - b.create( - loc, outOfBounds, - [&](OpBuilder &b, Location loc) { - // Non existent dimensions are always broadcastable - b.create(loc, broadcastable); - }, - [&](OpBuilder &b, Location loc) { - // Every value needs to be either 1, or the same non-1 - // value to be broadcastable in this dim. - Value operandDimension = - b.create(loc, indexTy, iv, rankDiff); - Value dimensionExtent = b.create( - loc, shape, ValueRange{operandDimension}); - - Value equalOne = b.create( - loc, arith::CmpIPredicate::eq, dimensionExtent, one); - Value equalBroadcasted = b.create( - loc, arith::CmpIPredicate::eq, dimensionExtent, - broadcastedDim); - Value result = b.create( - loc, broadcastable, - b.create(loc, equalOne, - equalBroadcasted)); - b.create(loc, result); - }) + IfOp::create( + b, loc, outOfBounds, + [&](OpBuilder &b, Location loc) { + // Non existent dimensions are always broadcastable + scf::YieldOp::create(b, loc, broadcastable); + }, + [&](OpBuilder &b, Location loc) { + // Every value needs to be either 1, or the same non-1 + // value to be broadcastable in this dim. + Value operandDimension = + arith::SubIOp::create(b, loc, indexTy, iv, rankDiff); + Value dimensionExtent = tensor::ExtractOp::create( + b, loc, shape, ValueRange{operandDimension}); + + Value equalOne = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, dimensionExtent, one); + Value equalBroadcasted = + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + dimensionExtent, broadcastedDim); + Value result = arith::AndIOp::create( + b, loc, broadcastable, + arith::OrIOp::create(b, loc, equalOne, + equalBroadcasted)); + scf::YieldOp::create(b, loc, result); + }) .getResult(0); } - b.create(loc, broadcastable); + scf::YieldOp::create(b, loc, broadcastable); }); rewriter.replaceOp(op, reduceResult.getResults().front()); @@ -339,7 +339,7 @@ DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor, // Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further // lowerings. This can be further optimized if needed to avoid intermediate // steps. - auto shapeOf = rewriter.create(op.getLoc(), op.getValue()); + auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue()); rewriter.replaceOpWithNewOp(op, op.getType(), shapeOf, op.getIndex()); return success(); @@ -421,16 +421,17 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, auto loc = op.getLoc(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = - rewriter.create(loc, indexTy, adaptor.getShape(), zero); + tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero); - auto loop = rewriter.create( - loc, zero, rank, one, op.getInitVals(), + auto loop = scf::ForOp::create( + rewriter, loc, zero, rank, one, op.getInitVals(), [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - Value extent = b.create(loc, adaptor.getShape(), iv); + Value extent = + tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv); SmallVector mappedValues{iv, extent}; mappedValues.append(args.begin(), args.end()); @@ -444,7 +445,7 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, SmallVector mappedResults; for (auto result : reduceBody->getTerminator()->getOperands()) mappedResults.push_back(mapping.lookup(result)); - b.create(loc, mappedResults); + scf::YieldOp::create(b, loc, mappedResults); }); rewriter.replaceOp(op, loop.getResults()); @@ -507,44 +508,44 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); - Value zero = rewriter.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); Value firstShape = adaptor.getShapes().front(); Value firstRank = - rewriter.create(loc, indexTy, firstShape, zero); + tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero); Value result = nullptr; // Generate a linear sequence of compares, all with firstShape as lhs. for (Value shape : adaptor.getShapes().drop_front(1)) { - Value rank = rewriter.create(loc, indexTy, shape, zero); - Value eqRank = rewriter.create(loc, arith::CmpIPredicate::eq, - firstRank, rank); - auto same = rewriter.create( - loc, eqRank, + Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero); + Value eqRank = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank); + auto same = IfOp::create( + rewriter, loc, eqRank, [&](OpBuilder &b, Location loc) { - Value one = b.create(loc, 1); + Value one = arith::ConstantIndexOp::create(b, loc, 1); Value init = - b.create(loc, i1Ty, b.getBoolAttr(true)); - auto loop = b.create( - loc, zero, firstRank, one, ValueRange{init}, + arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true)); + auto loop = scf::ForOp::create( + b, loc, zero, firstRank, one, ValueRange{init}, [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { Value conj = args[0]; Value lhsExtent = - b.create(loc, firstShape, iv); - Value rhsExtent = b.create(loc, shape, iv); - Value eqExtent = b.create( - loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); - Value conjNext = b.create(loc, conj, eqExtent); - b.create(loc, ValueRange({conjNext})); + tensor::ExtractOp::create(b, loc, firstShape, iv); + Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv); + Value eqExtent = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); + Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent); + scf::YieldOp::create(b, loc, ValueRange({conjNext})); }); - b.create(loc, loop.getResults()); + scf::YieldOp::create(b, loc, loop.getResults()); }, [&](OpBuilder &b, Location loc) { Value result = - b.create(loc, i1Ty, b.getBoolAttr(false)); - b.create(loc, result); + arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false)); + scf::YieldOp::create(b, loc, result); }); result = !result ? same.getResult(0) - : rewriter.create(loc, result, - same.getResult(0)); + : arith::AndIOp::create(rewriter, loc, result, + same.getResult(0)); } rewriter.replaceOp(op, result); return success(); @@ -581,18 +582,18 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( int64_t rank = rankedTensorTy.getRank(); for (int64_t i = 0; i < rank; i++) { if (rankedTensorTy.isDynamicDim(i)) { - Value extent = rewriter.create(loc, tensor, i); + Value extent = tensor::DimOp::create(rewriter, loc, tensor, i); extentValues.push_back(extent); } else { - Value extent = rewriter.create( - loc, rankedTensorTy.getDimSize(i)); + Value extent = arith::ConstantIndexOp::create( + rewriter, loc, rankedTensorTy.getDimSize(i)); extentValues.push_back(extent); } } // Materialize extent tensor. - Value staticExtentTensor = rewriter.create( - loc, RankedTensorType::get({rank}, rewriter.getIndexType()), + Value staticExtentTensor = tensor::FromElementsOp::create( + rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()), extentValues); rewriter.replaceOpWithNewOp(op, op.getType(), staticExtentTensor); @@ -601,13 +602,13 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( // Lower to `tensor.generate` otherwise. auto *ctx = rewriter.getContext(); - Value rank = rewriter.create(loc, tensor); + Value rank = tensor::RankOp::create(rewriter, loc, tensor); rewriter.replaceOpWithNewOp( op, getExtentTensorType(ctx), ValueRange{rank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value dim = args.front(); - Value extent = b.create(loc, tensor, dim); - b.create(loc, extent); + Value extent = tensor::DimOp::create(b, loc, tensor, dim); + tensor::YieldOp::create(b, loc, extent); }); return success(); @@ -634,22 +635,22 @@ LogicalResult SplitAtOpConversion::matchAndRewrite( return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value zero = b.create(0); - Value rank = b.create(adaptor.getOperand(), zero); + Value zero = arith::ConstantIndexOp::create(b, 0); + Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero); // index < 0 ? index + rank : index Value originalIndex = adaptor.getIndex(); - Value add = b.create(originalIndex, rank); + Value add = arith::AddIOp::create(b, originalIndex, rank); Value indexIsNegative = - b.create(arith::CmpIPredicate::slt, originalIndex, zero); - Value index = b.create(indexIsNegative, add, originalIndex); + arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero); + Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex); - Value one = b.create(1); + Value one = arith::ConstantIndexOp::create(b, 1); Value head = - b.create(adaptor.getOperand(), zero, index, one); - Value tailSize = b.create(rank, index); - Value tail = b.create(adaptor.getOperand(), index, - tailSize, one); + tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one); + Value tailSize = arith::SubIOp::create(b, rank, index); + Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index, + tailSize, one); rewriter.replaceOp(op, {head, tail}); return success(); } diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp index 2c4d27502a521..f24972f6b6ee1 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp @@ -68,10 +68,10 @@ class TensorExtractPattern final // We could use the initializer directly; but certain driver compilers // have bugs dealing with that. So for now, use spirv.Store for // initialization. - varOp = rewriter.create(loc, varType, - spirv::StorageClass::Function, - /*initializer=*/nullptr); - rewriter.create(loc, varOp, adaptor.getTensor()); + varOp = spirv::VariableOp::create(rewriter, loc, varType, + spirv::StorageClass::Function, + /*initializer=*/nullptr); + spirv::StoreOp::create(rewriter, loc, varOp, adaptor.getTensor()); } else { // Need to store the value to the local variable. It's questionable // whether we want to support such case though. @@ -83,7 +83,7 @@ class TensorExtractPattern final Value index = spirv::linearizeIndex(adaptor.getIndices(), strides, /*offset=*/0, indexType, loc, rewriter); - auto acOp = rewriter.create(loc, varOp, index); + auto acOp = spirv::AccessChainOp::create(rewriter, loc, varOp, index); rewriter.replaceOpWithNewOp(extractOp, acOp); diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 40ad63610e23f..044b725c7d805 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -51,8 +51,8 @@ TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { Value getConstantValue(Location loc, Type type, int64_t value, PatternRewriter &rewriter) { - return rewriter.create( - loc, getConstantAttr(type, value, rewriter)); + return arith::ConstantOp::create(rewriter, loc, + getConstantAttr(type, value, rewriter)); } // This converts the TOSA ApplyScale operator to a set of arithmetic ops, @@ -82,41 +82,41 @@ class ApplyScaleGenericOpConverter Value one64 = getConstantValue(loc, i64Ty, 1, rewriter); Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter); - Value shift32 = rewriter.create(loc, i32Ty, op.getShift()); + Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift()); // Compute the multiplication in 64-bits then select the high / low parts. Value value64 = value; if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type()) - value64 = rewriter.create(loc, i64Ty, value); + value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value); Value multiplier64 = - rewriter.create(loc, i64Ty, multiplier32); + arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32); Value multiply64 = - rewriter.create(loc, value64, multiplier64); + arith::MulIOp::create(rewriter, loc, value64, multiplier64); // Apply normal rounding. - Value shift64 = rewriter.create(loc, i64Ty, shift32); - Value round = rewriter.create(loc, one64, shift64); - round = rewriter.create(loc, round, one64); - multiply64 = rewriter.create(loc, multiply64, round); + Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32); + Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64); + round = arith::ShRUIOp::create(rewriter, loc, round, one64); + multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round); // Apply double rounding if necessary. if (op.getRoundingMode() == "DOUBLE_ROUND") { int64_t roundInt = 1 << 30; Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter); Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter); - Value positive = rewriter.create( - loc, arith::CmpIPredicate::sge, value, zero); + Value positive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, value, zero); Value dir = - rewriter.create(loc, positive, roundUp, roundDown); - Value val = rewriter.create(loc, dir, multiply64); - Value valid = rewriter.create( - loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); + arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown); + Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64); + Value valid = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); multiply64 = - rewriter.create(loc, valid, val, multiply64); + arith::SelectOp::create(rewriter, loc, valid, val, multiply64); } - Value result64 = rewriter.create(loc, multiply64, shift64); - Value result32 = rewriter.create(loc, i32Ty, result64); + Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64); + Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64); rewriter.replaceOp(op, result32); return success(); @@ -146,7 +146,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern { Value value32 = op.getValue(); Value multiplier32 = op.getMultiplier(); - Value shift32 = rewriter.create(loc, i32Ty, op.getShift()); + Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift()); // Constants used during the scaling operation. Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter); @@ -158,86 +158,87 @@ class ApplyScale32BitOpConverter : public OpRewritePattern { // Compute the multiplication in 64-bits then select the high / low parts. // Grab out the high/low of the computation auto value64 = - rewriter.create(loc, value32, multiplier32); + arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32); Value low32 = value64.getLow(); Value high32 = value64.getHigh(); // Determine the direction and amount to shift the high bits. - Value shiftOver32 = rewriter.create( - loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); - Value roundHighBits = rewriter.create( - loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); + Value shiftOver32 = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); + Value roundHighBits = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); Value shiftHighL = - rewriter.create(loc, thirtyTwo32, shift32); + arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32); Value shiftHighR = - rewriter.create(loc, shift32, thirtyTwo32); + arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32); shiftHighL = - rewriter.create(loc, shiftOver32, zero32, shiftHighL); + arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL); shiftHighR = - rewriter.create(loc, shiftOver32, shiftHighR, zero32); + arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32); // Conditionally perform our double round. if (op.getRoundingMode() == "DOUBLE_ROUND") { Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter); - Value valuePositive = rewriter.create( - loc, arith::CmpIPredicate::sge, value32, zero32); + Value valuePositive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, value32, zero32); - Value roundDir = - rewriter.create(loc, valuePositive, one32, negOne32); + Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive, + one32, negOne32); roundDir = - rewriter.create(loc, shiftOver32, roundDir, zero32); + arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32); - Value shiftLow = rewriter.create(loc, low32, thirty32); - Value rounded = rewriter.create(loc, shiftLow, roundDir); - Value carry = rewriter.create(loc, rounded, two32); + Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32); + Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir); + Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32); Value shiftRound = - rewriter.create(loc, roundDir, thirty32); + arith::ShLIOp::create(rewriter, loc, roundDir, thirty32); - low32 = rewriter.create(loc, low32, shiftRound); - high32 = rewriter.create(loc, high32, carry); + low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound); + high32 = arith::AddIOp::create(rewriter, loc, high32, carry); } // Conditionally apply rounding in the low bits. { - Value shiftSubOne = rewriter.create(loc, shift32, one32); - Value roundBit = rewriter.create(loc, one32, shiftSubOne); - roundBit = rewriter.create(loc, roundHighBits, zero32, - roundBit); - - Value newLow32 = rewriter.create(loc, low32, roundBit); - Value wasRounded = rewriter.create( - loc, arith::CmpIPredicate::ugt, low32, newLow32); + Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32); + Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne); + roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32, + roundBit); + + Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit); + Value wasRounded = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32); low32 = newLow32; - Value rounded32 = rewriter.create(loc, i32Ty, wasRounded); - high32 = rewriter.create(loc, high32, rounded32); + Value rounded32 = + arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded); + high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32); } // Conditionally apply rounding in the high bits. { Value shiftSubOne = - rewriter.create(loc, shiftHighR, one32); - Value roundBit = rewriter.create(loc, one32, shiftSubOne); - roundBit = rewriter.create(loc, roundHighBits, roundBit, - zero32); - high32 = rewriter.create(loc, high32, roundBit); + arith::SubIOp::create(rewriter, loc, shiftHighR, one32); + Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne); + roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit, + zero32); + high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit); } // Combine the correct high/low bits into the final rescale result. - high32 = rewriter.create(loc, high32, shiftHighL); - high32 = rewriter.create(loc, high32, shiftHighR); - low32 = rewriter.create(loc, low32, shift32); - low32 = rewriter.create(loc, shiftOver32, zero32, low32); + high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL); + high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR); + low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32); + low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32); // Apply the rounding behavior and shift to the final alignment. - Value result = rewriter.create(loc, low32, high32); + Value result = arith::AddIOp::create(rewriter, loc, low32, high32); // Truncate if necessary. if (!getElementTypeOrSelf(resultTy).isInteger(32)) { - result = rewriter.create(loc, resultTy, result); + result = arith::TruncIOp::create(rewriter, loc, resultTy, result); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 2f608bbd637b4..ec55091cd7eb8 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -70,14 +70,14 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, return result; // Unordered comparison of NaN against itself will always return true. - Value lhsIsNaN = rewriter.create( - op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs); - Value rhsIsNaN = rewriter.create( - op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs); + Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(), + arith::CmpFPredicate::UNO, lhs, lhs); + Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(), + arith::CmpFPredicate::UNO, rhs, rhs); Value rhsOrResult = - rewriter.create(op.getLoc(), lhsIsNaN, rhs, result); - return rewriter.create(op.getLoc(), rhsIsNaN, lhs, - rhsOrResult); + arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result); + return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs, + rhsOrResult); } static Value createLinalgBodyCalculationForElementwiseOp( @@ -89,38 +89,38 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::AbsOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return math::AbsFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) { - auto zero = rewriter.create( - loc, rewriter.getZeroAttr(elementTy)); - auto neg = rewriter.create(loc, zero, args[0]); - return rewriter.create(loc, args[0], neg); + auto zero = arith::ConstantOp::create(rewriter, loc, + rewriter.getZeroAttr(elementTy)); + auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]); + return arith::MaxSIOp::create(rewriter, loc, args[0], neg); } // tosa::AddOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::AddFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::AddIOp::create(rewriter, loc, resultTypes, args); // tosa::SubOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::SubFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::SubIOp::create(rewriter, loc, resultTypes, args); // tosa::IntDivOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::DivSIOp::create(rewriter, loc, resultTypes, args); // tosa::ReciprocalOp if (isa(op) && isa(elementTy)) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - return rewriter.create(loc, resultTypes, one, args[0]); + arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1)); + return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]); } // tosa::MulOp @@ -140,7 +140,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( "Cannot have shift value for float"); return nullptr; } - return rewriter.create(loc, resultTypes, args[0], args[1]); + return arith::MulFOp::create(rewriter, loc, resultTypes, args[0], + args[1]); } if (isa(elementTy)) { @@ -149,21 +150,21 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (shift > 0) { auto shiftConst = - rewriter.create(loc, shift, /*bitwidth=*/8); + arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8); if (!a.getType().isInteger(32)) - a = rewriter.create(loc, rewriter.getI32Type(), a); + a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) - b = rewriter.create(loc, rewriter.getI32Type(), b); + b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b); - auto result = rewriter.create( - loc, rewriter.getI32Type(), a, b, shiftConst, + auto result = tosa::ApplyScaleOp::create( + rewriter, loc, rewriter.getI32Type(), a, b, shiftConst, rewriter.getStringAttr("SINGLE_ROUND")); if (elementTy.isInteger(32)) return result; - return rewriter.create(loc, elementTy, result); + return arith::TruncIOp::create(rewriter, loc, elementTy, result); } int aWidth = a.getType().getIntOrFloatBitWidth(); @@ -171,11 +172,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( int cWidth = resultTypes[0].getIntOrFloatBitWidth(); if (aWidth < cWidth) - a = rewriter.create(loc, resultTypes[0], a); + a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a); if (bWidth < cWidth) - b = rewriter.create(loc, resultTypes[0], b); + b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b); - return rewriter.create(loc, resultTypes, a, b); + return arith::MulIOp::create(rewriter, loc, resultTypes, a, b); } } @@ -201,14 +202,14 @@ static Value createLinalgBodyCalculationForElementwiseOp( int64_t outZp = *maybeOutZp; if (isa(elementTy)) - return rewriter.create(loc, resultTypes, args[0]); + return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); if (isa(elementTy)) { if (!inZp && !outZp) { - auto constant = rewriter.create( - loc, IntegerAttr::get(elementTy, 0)); - return rewriter.create(loc, resultTypes, constant, - args[0]); + auto constant = arith::ConstantOp::create( + rewriter, loc, IntegerAttr::get(elementTy, 0)); + return arith::SubIOp::create(rewriter, loc, resultTypes, constant, + args[0]); } // Compute the maximum value that can occur in the intermediate buffer. @@ -231,214 +232,214 @@ static Value createLinalgBodyCalculationForElementwiseOp( } Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); - Value zpAddValue = rewriter.create( - loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); + Value zpAddValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue auto ext = - rewriter.create(loc, intermediateType, args[0]); - auto sub = rewriter.create(loc, zpAddValue, ext); + arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]); + auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext); // Clamp to the negation range. - Value min = rewriter.create( - loc, intermediateType, + Value min = arith::ConstantIntOp::create( + rewriter, loc, intermediateType, APInt::getSignedMinValue(inputBitWidth).getSExtValue()); - Value max = rewriter.create( - loc, intermediateType, + Value max = arith::ConstantIntOp::create( + rewriter, loc, intermediateType, APInt::getSignedMaxValue(inputBitWidth).getSExtValue()); auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false); // Truncate to the final value. - return rewriter.create(loc, elementTy, clamp); + return arith::TruncIOp::create(rewriter, loc, elementTy, clamp); } } // tosa::BitwiseAndOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::AndIOp::create(rewriter, loc, resultTypes, args); // tosa::BitwiseOrOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::OrIOp::create(rewriter, loc, resultTypes, args); // tosa::BitwiseNotOp if (isa(op) && isa(elementTy)) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); - auto allOnes = rewriter.create(loc, allOnesAttr); - return rewriter.create(loc, resultTypes, args[0], allOnes); + auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes); } // tosa::BitwiseXOrOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalLeftShiftOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::ShLIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalRightShiftOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::ShRUIOp::create(rewriter, loc, resultTypes, args); // tosa::ArithmeticRightShiftOp if (isa(op) && isa(elementTy)) { - auto result = rewriter.create(loc, resultTypes, args); + auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args); auto round = cast(op->getAttr("round")).getValue(); if (!round) { return result; } Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); - auto one = - rewriter.create(loc, IntegerAttr::get(elementTy, 1)); - auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + auto one = arith::ConstantOp::create(rewriter, loc, + IntegerAttr::get(elementTy, 1)); + auto zero = arith::ConstantOp::create(rewriter, loc, + IntegerAttr::get(elementTy, 0)); auto i1one = - rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); + arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1)); // Checking that input2 != 0 - auto shiftValueGreaterThanZero = rewriter.create( - loc, arith::CmpIPredicate::sgt, args[1], zero); + auto shiftValueGreaterThanZero = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero); // Checking for the last bit of input1 to be 1 auto subtract = - rewriter.create(loc, resultTypes, args[1], one); + arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one); auto shifted = - rewriter.create(loc, resultTypes, args[0], subtract) + arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract) ->getResults(); - auto truncated = rewriter.create( - loc, i1Ty, shifted, ArrayRef()); + auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted, + ArrayRef()); auto isInputOdd = - rewriter.create(loc, i1Ty, truncated, i1one); + arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one); - auto shouldRound = rewriter.create( - loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); + auto shouldRound = arith::AndIOp::create( + rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); auto extended = - rewriter.create(loc, resultTypes, shouldRound); - return rewriter.create(loc, resultTypes, result, extended); + arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound); + return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended); } // tosa::ClzOp if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, elementTy, args[0]); + return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]); } // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return arith::AndIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalNot if (isa(op) && elementTy.isInteger(1)) { - auto one = rewriter.create( - loc, rewriter.getIntegerAttr(elementTy, 1)); - return rewriter.create(loc, resultTypes, args[0], one); + auto one = arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(elementTy, 1)); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one); } // tosa::LogicalOr if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return arith::OrIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalXor if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args); // tosa::PowOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args); // tosa::RsqrtOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args); // tosa::LogOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::LogOp::create(rewriter, loc, resultTypes, args); // tosa::ExpOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args); // tosa::SinOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::SinOp::create(rewriter, loc, resultTypes, args); // tosa::CosOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::CosOp::create(rewriter, loc, resultTypes, args); // tosa::TanhOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args); // tosa::ErfOp if (isa(op) && llvm::isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args); // tosa::GreaterOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, arith::CmpFPredicate::OGT, - args[0], args[1]); + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, arith::CmpIPredicate::sgt, - args[0], args[1]); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, + args[0], args[1]); // tosa::GreaterEqualOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, arith::CmpFPredicate::OGE, - args[0], args[1]); + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, arith::CmpIPredicate::sge, - args[0], args[1]); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, + args[0], args[1]); // tosa::EqualOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, arith::CmpFPredicate::OEQ, - args[0], args[1]); + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, arith::CmpIPredicate::eq, - args[0], args[1]); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + args[0], args[1]); // tosa::SelectOp if (isa(op)) { elementTy = cast(op->getOperand(1).getType()).getElementType(); if (isa(elementTy) || isa(elementTy)) - return rewriter.create(loc, args[0], args[1], args[2]); + return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]); } // tosa::MaximumOp if (isa(op) && isa(elementTy)) { - auto max = rewriter.create(loc, args[0], args[1]); + auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]); return materializeBinaryNanCheckIfRequired(llvm::cast(op), rewriter, args[0], args[1], max); } if (isa(op) && elementTy.isSignlessInteger()) { - return rewriter.create(loc, args[0], args[1]); + return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]); } // tosa::MinimumOp if (isa(op) && isa(elementTy)) { - auto min = rewriter.create(loc, args[0], args[1]); + auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]); return materializeBinaryNanCheckIfRequired(llvm::cast(op), rewriter, args[0], args[1], min); } if (isa(op) && elementTy.isSignlessInteger()) { - return rewriter.create(loc, args[0], args[1]); + return arith::MinSIOp::create(rewriter, loc, args[0], args[1]); } // tosa::CeilOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return math::CeilOp::create(rewriter, loc, resultTypes, args); // tosa::FloorOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return math::FloorOp::create(rewriter, loc, resultTypes, args); // tosa::ClampOp if (isa(op) && isa(elementTy)) { @@ -449,10 +450,10 @@ static Value createLinalgBodyCalculationForElementwiseOp( APFloat::rmNearestTiesToEven, &losesInfo); maxApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); - auto min = rewriter.create( - loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); - auto max = rewriter.create( - loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf)); + auto min = arith::ConstantOp::create( + rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); + auto max = arith::ConstantOp::create( + rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf)); auto result = clampFloatHelper(loc, args[0], min, max, rewriter); auto clampOp = llvm::cast(op); @@ -478,11 +479,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( // return init if x == NaN else result // Unordered comparison of NaN against itself will always return true. - Value isNaN = rewriter.create( - op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]); + Value isNaN = arith::CmpFOp::create( + rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]); // TOSA specifies that in "ignore" NaN mode the result is "min" if the input // is NaN. - return rewriter.create(op->getLoc(), isNaN, min, result); + return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result); } if (isa(op) && isa(elementTy)) { @@ -515,10 +516,10 @@ static Value createLinalgBodyCalculationForElementwiseOp( min = std::min(min, maxRepresentable); max = std::min(max, maxRepresentable); - auto minVal = rewriter.create( - loc, min, intTy.getIntOrFloatBitWidth()); - auto maxVal = rewriter.create( - loc, max, intTy.getIntOrFloatBitWidth()); + auto minVal = arith::ConstantIntOp::create(rewriter, loc, min, + intTy.getIntOrFloatBitWidth()); + auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max, + intTy.getIntOrFloatBitWidth()); return clampIntHelper(loc, args[0], minVal, maxVal, rewriter, intTy.isUnsignedInteger()); } @@ -526,11 +527,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::SigmoidOp if (isa(op) && isa(elementTy)) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - auto negate = rewriter.create(loc, resultTypes, args[0]); - auto exp = rewriter.create(loc, resultTypes, negate); - auto added = rewriter.create(loc, resultTypes, exp, one); - return rewriter.create(loc, resultTypes, one, added); + arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1)); + auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); + auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate); + auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one); + return arith::DivFOp::create(rewriter, loc, resultTypes, one, added); } // tosa::CastOp @@ -549,21 +550,21 @@ static Value createLinalgBodyCalculationForElementwiseOp( return args.front(); if (isa(srcTy) && isa(dstTy) && bitExtend) - return rewriter.create(loc, resultTypes, args, - ArrayRef()); + return arith::ExtFOp::create(rewriter, loc, resultTypes, args, + ArrayRef()); if (isa(srcTy) && isa(dstTy) && !bitExtend) - return rewriter.create(loc, resultTypes, args, - ArrayRef()); + return arith::TruncFOp::create(rewriter, loc, resultTypes, args, + ArrayRef()); // 1-bit integers need to be treated as signless. if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, - ArrayRef()); + return arith::UIToFPOp::create(rewriter, loc, resultTypes, args, + ArrayRef()); if (srcTy.isInteger(1) && isa(dstTy) && bitExtend) - return rewriter.create(loc, resultTypes, args, - ArrayRef()); + return arith::ExtUIOp::create(rewriter, loc, resultTypes, args, + ArrayRef()); // Unsigned integers need an unrealized cast so that they can be passed // to UIToFP. @@ -574,25 +575,25 @@ static Value createLinalgBodyCalculationForElementwiseOp( loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); - return rewriter.create(loc, resultTypes[0], - unrealizedCast); + return arith::UIToFPOp::create(rewriter, loc, resultTypes[0], + unrealizedCast); } // All other si-to-fp conversions should be handled by SIToFP. if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, - ArrayRef()); + return arith::SIToFPOp::create(rewriter, loc, resultTypes, args, + ArrayRef()); // Casting to boolean, floats need to only be checked as not-equal to zero. if (isa(srcTy) && dstTy.isInteger(1)) { - Value zero = rewriter.create( - loc, rewriter.getFloatAttr(srcTy, 0.0)); - return rewriter.create(loc, arith::CmpFPredicate::UNE, - args.front(), zero); + Value zero = arith::ConstantOp::create(rewriter, loc, + rewriter.getFloatAttr(srcTy, 0.0)); + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE, + args.front(), zero); } if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto rounded = rewriter.create(loc, args[0]); + auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]); const auto &fltSemantics = cast(srcTy).getFloatSemantics(); // Check whether neither int min nor int max can be represented in the @@ -601,37 +602,42 @@ static Value createLinalgBodyCalculationForElementwiseOp( APFloat::semanticsMaxExponent(fltSemantics)) { // Use cmp + select to replace infinites by int min / int max. Other // integral values can be represented in the integer space. - auto conv = rewriter.create(loc, dstTy, rounded); - auto posInf = rewriter.create( - loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy), - APFloat::getInf(fltSemantics))); - auto negInf = rewriter.create( - loc, rewriter.getFloatAttr( - getElementTypeOrSelf(srcTy), - APFloat::getInf(fltSemantics, /*Negative=*/true))); - auto overflow = rewriter.create( - loc, arith::CmpFPredicate::UEQ, rounded, posInf); - auto underflow = rewriter.create( - loc, arith::CmpFPredicate::UEQ, rounded, negInf); - auto intMin = rewriter.create( - loc, rewriter.getIntegerAttr( - getElementTypeOrSelf(dstTy), - APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); - auto intMax = rewriter.create( - loc, rewriter.getIntegerAttr( - getElementTypeOrSelf(dstTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); + auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded); + auto posInf = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr(getElementTypeOrSelf(srcTy), + APFloat::getInf(fltSemantics))); + auto negInf = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), + APFloat::getInf(fltSemantics, /*Negative=*/true))); + auto overflow = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf); + auto underflow = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf); + auto intMin = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr( + getElementTypeOrSelf(dstTy), + APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); + auto intMax = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr( + getElementTypeOrSelf(dstTy), + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto maxClamped = - rewriter.create(loc, overflow, intMax, conv); - return rewriter.create(loc, underflow, intMin, - maxClamped); + arith::SelectOp::create(rewriter, loc, overflow, intMax, conv); + return arith::SelectOp::create(rewriter, loc, underflow, intMin, + maxClamped); } - auto intMinFP = rewriter.create( - loc, rewriter.getFloatAttr( - getElementTypeOrSelf(srcTy), - APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) - .getSExtValue())); + auto intMinFP = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), + APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) + .getSExtValue())); // Check whether the mantissa has enough bits to represent int max. if (cast(srcTy).getFPMantissaWidth() >= @@ -640,58 +646,61 @@ static Value createLinalgBodyCalculationForElementwiseOp( // consists of a single leading bit. Therefore we can clamp the input // in the floating-point domain. - auto intMaxFP = rewriter.create( - loc, rewriter.getFloatAttr( - getElementTypeOrSelf(srcTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) - .getSExtValue())); + auto intMaxFP = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) + .getSExtValue())); Value clamped = clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter); - return rewriter.create(loc, dstTy, clamped); + return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped); } // Due to earlier check we know exponant range is big enough to represent // int min. We can therefore rely on int max + 1 being representable as // well because it's just int min with a positive sign. So clamp the min // value and compare against that to select the max int value if needed. - auto intMaxPlusOneFP = rewriter.create( - loc, rewriter.getFloatAttr( - getElementTypeOrSelf(srcTy), - static_cast( - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) - .getSExtValue()) + - 1.0f)); - - auto intMax = rewriter.create( - loc, rewriter.getIntegerAttr( - getElementTypeOrSelf(dstTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); + auto intMaxPlusOneFP = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), + static_cast( + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) + .getSExtValue()) + + 1.0f)); + + auto intMax = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr( + getElementTypeOrSelf(dstTy), + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto minClampedFP = - rewriter.create(loc, rounded, intMinFP); + arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP); auto minClamped = - rewriter.create(loc, dstTy, minClampedFP); - auto overflow = rewriter.create( - loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP); - return rewriter.create(loc, overflow, intMax, - minClamped); + arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP); + auto overflow = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP); + return arith::SelectOp::create(rewriter, loc, overflow, intMax, + minClamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. if (isa(srcTy) && dstTy.isInteger(1)) { - Value zero = rewriter.create( - loc, 0, srcTy.getIntOrFloatBitWidth()); - return rewriter.create(loc, arith::CmpIPredicate::ne, - args.front(), zero); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, + srcTy.getIntOrFloatBitWidth()); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, + args.front(), zero); } if (isa(srcTy) && isa(dstTy) && bitExtend) - return rewriter.create(loc, resultTypes, args, - ArrayRef()); + return arith::ExtSIOp::create(rewriter, loc, resultTypes, args, + ArrayRef()); if (isa(srcTy) && isa(dstTy) && !bitExtend) { - return rewriter.create(loc, dstTy, args[0]); + return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]); } } @@ -710,14 +719,14 @@ static Value createIndex(PatternRewriter &rewriter, Location loc, auto [it, inserted] = indexPool.try_emplace(index); if (inserted) it->second = - rewriter.create(loc, rewriter.getIndexAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index)); return it->second; } static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index) { auto indexValue = createIndex(rewriter, loc, indexPool, index); - return rewriter.create(loc, tensor, indexValue).getResult(); + return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult(); } static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, @@ -783,7 +792,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) { auto nextSize = getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim); - targetSize = rewriter.create(loc, targetSize, nextSize); + targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize); } return {targetSize, nullptr}; } @@ -838,8 +847,8 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, // Check if broadcast is necessary auto one = createIndex(rewriter, loc, indexPool, 1); auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim); - auto broadcastNecessary = rewriter.create( - loc, arith::CmpIPredicate::eq, runtimeSize, one); + auto broadcastNecessary = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one); // Emit 'then' region of 'scf.if' auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) { @@ -855,8 +864,8 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, operand, index); outputTensorShape.push_back(size); } - Value outputTensor = opBuilder.create( - loc, outputTensorShape, rankedTensorType.getElementType()); + Value outputTensor = tensor::EmptyOp::create( + opBuilder, loc, outputTensorShape, rankedTensorType.getElementType()); // Emit 'linalg.generic' op auto resultTensor = @@ -866,7 +875,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, getNParallelLoopsAttrs(rank), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { // Emit 'linalg.yield' op - opBuilder.create(loc, blockArgs.front()); + linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); }) .getResult(0); @@ -875,17 +884,17 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, loc, operand.getType(), resultTensor); // Emit 'scf.yield' op - opBuilder.create(loc, castResultTensor); + scf::YieldOp::create(opBuilder, loc, castResultTensor); }; // Emit 'else' region of 'scf.if' auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) { - opBuilder.create(loc, operand); + scf::YieldOp::create(opBuilder, loc, operand); }; // Emit 'scf.if' op - auto ifOp = rewriter.create(loc, broadcastNecessary, - emitThenRegion, emitElseRegion); + auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary, + emitThenRegion, emitElseRegion); return ifOp.getResult(0); } @@ -930,8 +939,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, if (!resultType) { return rewriter.notifyMatchFailure(operation, "failed to convert type"); } - Value outputTensor = rewriter.create( - loc, targetShape, resultType.getElementType()); + Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape, + resultType.getElementType()); // Create affine maps. Input affine maps broadcast static dimensions of size // 1. The output affine map is an identity map. @@ -957,8 +966,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, // Emit 'linalg.generic' op bool encounteredError = false; - auto linalgOp = rewriter.create( - loc, outputTensor.getType(), operands, outputTensor, affineMaps, + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { Value opResult = createLinalgBodyCalculationForElementwiseOp( @@ -968,7 +977,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, encounteredError = true; return; } - opBuilder.create(loc, opResult); + linalg::YieldOp::create(opBuilder, loc, opResult); }); if (encounteredError) return rewriter.notifyMatchFailure( @@ -1078,42 +1087,42 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op, PatternRewriter &rewriter) { Location loc = op->getLoc(); if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args); + return arith::AddFOp::create(rewriter, loc, args); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args); + return arith::AddIOp::create(rewriter, loc, args); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args); + return arith::MulFOp::create(rewriter, loc, args); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args); + return arith::MulIOp::create(rewriter, loc, args); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args[0], args[1]); + return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args[0], args[1]); + return arith::MinSIOp::create(rewriter, loc, args[0], args[1]); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args[0], args[1]); + return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args[0], args[1]); + return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]); } if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return arith::AndIOp::create(rewriter, loc, args); if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return arith::OrIOp::create(rewriter, loc, args); return {}; } @@ -1139,7 +1148,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, if (axis != i) { reduceShape.push_back(inputTy.getDimSize(i)); if (inputTy.isDynamicDim(i)) - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } @@ -1158,7 +1167,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); - auto fillValue = rewriter.create(loc, fillValueAttr); + auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, ValueRange{emptyTensor}) @@ -1176,7 +1185,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, // Additionally we have to keep track of whether we've seen any non-NaN // values and then do a final select based on this predicate. auto trueAttr = rewriter.getBoolAttr(true); - auto trueValue = rewriter.create(loc, trueAttr); + auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr); auto emptyBoolTensor = rewriter .create(loc, reduceShape, trueValue.getType(), @@ -1202,8 +1211,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, } bool didEncounterError = false; - linalg::LinalgOp linalgOp = rewriter.create( - loc, inputs, outputs, axis, + linalg::LinalgOp linalgOp = linalg::ReduceOp::create( + rewriter, loc, inputs, outputs, axis, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { std::array binaryArgs{ blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]}; @@ -1219,21 +1228,22 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto oldAllResultsNanFlagValue = blockArgs[3]; // Unordered comparison of NaN against itself will always return true. - Value isNaN = nestedBuilder.create( - op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue); + Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(), + arith::CmpFPredicate::UNO, + inputValue, inputValue); // If we've encountered a NaN, take the non-NaN value. - auto selectOp = nestedBuilder.create( - op->getLoc(), isNaN, initialValue, result); + auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(), + isNaN, initialValue, result); // Update the flag which keeps track of whether we have seen a non-NaN // value. - auto newAllResultsNanFlagValue = nestedBuilder.create( - op->getLoc(), oldAllResultsNanFlagValue, isNaN); + auto newAllResultsNanFlagValue = arith::AndIOp::create( + nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN); resultsToYield.push_back(selectOp); resultsToYield.push_back(newAllResultsNanFlagValue); } else { resultsToYield.push_back(result); } - nestedBuilder.create(loc, resultsToYield); + linalg::YieldOp::create(nestedBuilder, loc, resultsToYield); }); if (!didEncounterError) @@ -1250,7 +1260,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto nanValueAttr = rewriter.getFloatAttr( elementTy, APFloat::getNaN(cast(elementTy).getFloatSemantics(), false)); - auto nanValue = rewriter.create(loc, nanValueAttr); + auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr); auto emptyNanTensor = rewriter .create(loc, reduceShape, @@ -1278,7 +1288,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, ins.push_back(linalgOp->getResult(0)); outs.push_back(finalEmptyTensor); auto linalgSelect = - rewriter.create(op->getLoc(), ins, outs); + linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs); linalgOp = linalgSelect; } @@ -1350,7 +1360,7 @@ class RescaleConverter : public OpRewritePattern { SmallVector dynDims; for (int i = 0; i < outputTy.getRank(); i++) { if (outputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } @@ -1401,16 +1411,17 @@ class RescaleConverter : public OpRewritePattern { Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { - multiplierConstant = rewriter.create( - loc, rewriter.getI32IntegerAttr(multiplierValues.front())); + multiplierConstant = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector multiplierExprs{ rewriter.getAffineDimExpr(rank - 1)}; auto multiplierType = RankedTensorType::get({static_cast(multiplierValues.size())}, rewriter.getI32Type()); - genericInputs.push_back(rewriter.create( - loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); + genericInputs.push_back(arith::ConstantOp::create( + rewriter, loc, + DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, multiplierExprs, @@ -1424,16 +1435,16 @@ class RescaleConverter : public OpRewritePattern { Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { - shiftConstant = rewriter.create( - loc, rewriter.getI8IntegerAttr(shiftValues.front())); + shiftConstant = arith::ConstantOp::create( + rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector shiftExprs = { rewriter.getAffineDimExpr(rank - 1)}; auto shiftType = RankedTensorType::get({static_cast(shiftValues.size())}, rewriter.getIntegerType(8)); - genericInputs.push_back(rewriter.create( - loc, DenseIntElementsAttr::get(shiftType, shiftValues))); + genericInputs.push_back(arith::ConstantOp::create( + rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, rewriter.getContext())); @@ -1444,13 +1455,13 @@ class RescaleConverter : public OpRewritePattern { indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Construct the indexing maps needed for linalg.generic ops. - Value emptyTensor = rewriter.create( - loc, outputTy.getShape(), outputTy.getElementType(), + Value emptyTensor = tensor::EmptyOp::create( + rewriter, loc, outputTy.getShape(), outputTy.getElementType(), ArrayRef({dynDims})); - auto linalgOp = rewriter.create( - loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps, - getNParallelLoopsAttrs(rank), + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor}, + indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value value = blockArgs[0]; @@ -1466,9 +1477,10 @@ class RescaleConverter : public OpRewritePattern { const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); // Extend zeropoint for sub-32bits widths. const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32; - auto inputZp = nestedBuilder.create( - loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), - *maybeIZp)); + auto inputZp = arith::ConstantOp::create( + nestedBuilder, loc, + IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), + *maybeIZp)); FailureOr maybeOZp = op.getOutputZeroPoint(); if (failed(maybeOZp)) { @@ -1482,9 +1494,10 @@ class RescaleConverter : public OpRewritePattern { unsigned outBitWidth = outIntType.getWidth(); const int32_t outAttrBitwidth = 32; assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); - auto outputZp = nestedBuilder.create( - loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), - *maybeOZp)); + auto outputZp = arith::ConstantOp::create( + nestedBuilder, loc, + IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), + *maybeOZp)); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; @@ -1501,24 +1514,24 @@ class RescaleConverter : public OpRewritePattern { } if (valueTy.getIntOrFloatBitWidth() < 32) { if (op.getInputUnsigned()) { - value = nestedBuilder.create( - nestedLoc, nestedBuilder.getI32Type(), value); + value = arith::ExtUIOp::create(nestedBuilder, nestedLoc, + nestedBuilder.getI32Type(), value); } else { - value = nestedBuilder.create( - nestedLoc, nestedBuilder.getI32Type(), value); + value = arith::ExtSIOp::create(nestedBuilder, nestedLoc, + nestedBuilder.getI32Type(), value); } } value = - nestedBuilder.create(nestedLoc, value, inputZp); + arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp); - value = nestedBuilder.create( - loc, nestedBuilder.getI32Type(), value, multiplier, shift, - roundingMode); + value = tosa::ApplyScaleOp::create(nestedBuilder, loc, + nestedBuilder.getI32Type(), value, + multiplier, shift, roundingMode); // Move to the new zero-point. value = - nestedBuilder.create(nestedLoc, value, outputZp); + arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp); // Saturate to the output size. int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue(); @@ -1530,18 +1543,18 @@ class RescaleConverter : public OpRewritePattern { intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); } - auto intMinVal = nestedBuilder.create( - loc, nestedBuilder.getI32IntegerAttr(intMin)); - auto intMaxVal = nestedBuilder.create( - loc, nestedBuilder.getI32IntegerAttr(intMax)); + auto intMinVal = arith::ConstantOp::create( + nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin)); + auto intMaxVal = arith::ConstantOp::create( + nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax)); value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal, nestedBuilder, /*isUnsigned=*/false); if (outIntType.getWidth() < 32) { - value = nestedBuilder.create( - nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), - value); + value = arith::TruncIOp::create( + nestedBuilder, nestedLoc, + rewriter.getIntegerType(outIntType.getWidth()), value); } if (outIntType.isUnsignedInteger()) { @@ -1550,7 +1563,7 @@ class RescaleConverter : public OpRewritePattern { outIntType, value) .getResult(0); } - nestedBuilder.create(loc, value); + linalg::YieldOp::create(nestedBuilder, loc, value); }); rewriter.replaceOp(op, linalgOp->getResults()); @@ -1608,48 +1621,49 @@ class ResizeUnaryConverter : public OpRewritePattern { auto collapseTy = RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)}, inputTy.getElementType()); - Value collapse = builder.create(collapseTy, input, - reassociationMap); + Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input, + reassociationMap); // Get any dynamic shapes that appear in the input format. llvm::SmallVector outputDynSize; if (inputTy.isDynamicDim(0)) - outputDynSize.push_back(builder.create(input, 0)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 0)); if (inputTy.isDynamicDim(3)) - outputDynSize.push_back(builder.create(input, 3)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 3)); // Generate the elementwise operation for casting scaling the input value. auto genericTy = collapseTy.clone(resultTy.getElementType()); - Value empty = builder.create( - genericTy.getShape(), resultTy.getElementType(), outputDynSize); + Value empty = + tensor::EmptyOp::create(builder, genericTy.getShape(), + resultTy.getElementType(), outputDynSize); auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank()); SmallVector iterators(genericTy.getRank(), utils::IteratorType::parallel); - auto generic = builder.create( - genericTy, ValueRange{collapse}, ValueRange{empty}, + auto generic = linalg::GenericOp::create( + builder, genericTy, ValueRange{collapse}, ValueRange{empty}, ArrayRef{genericMap, genericMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; // This is the quantized case. if (inputTy.getElementType() != resultTy.getElementType()) { - value = - b.create(loc, resultTy.getElementType(), value); + value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(), + value); if (isBilinear && scale[0] != 0) { - Value scaleY = b.create( - loc, b.getI32IntegerAttr(scale[0])); - value = b.create(loc, value, scaleY); + Value scaleY = arith::ConstantOp::create( + b, loc, b.getI32IntegerAttr(scale[0])); + value = arith::MulIOp::create(b, loc, value, scaleY); } if (isBilinear && scale[2] != 0) { - Value scaleX = b.create( - loc, b.getI32IntegerAttr(scale[2])); - value = b.create(loc, value, scaleX); + Value scaleX = arith::ConstantOp::create( + b, loc, b.getI32IntegerAttr(scale[2])); + value = arith::MulIOp::create(b, loc, value, scaleX); } } - b.create(loc, value); + linalg::YieldOp::create(b, loc, value); }); rewriter.replaceOpWithNewOp( @@ -1697,9 +1711,9 @@ class MaterializeResizeBroadcast : public OpRewritePattern { resizeShape.push_back(channels); auto resizeTy = resultTy.clone(resizeShape); - auto resize = builder.create(resizeTy, input, op.getScale(), - op.getOffset(), op.getBorder(), - op.getMode()); + auto resize = + tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(), + op.getOffset(), op.getBorder(), op.getMode()); // Collapse an unit result dims. SmallVector reassociationMap(2); @@ -1720,20 +1734,20 @@ class MaterializeResizeBroadcast : public OpRewritePattern { collapseShape.push_back(channels); auto collapseTy = resultTy.clone(collapseShape); - Value collapse = builder.create(collapseTy, resize, - reassociationMap); + Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, + resize, reassociationMap); // Broadcast the collapsed shape to the output result. llvm::SmallVector outputDynSize; if (inputTy.isDynamicDim(0)) - outputDynSize.push_back(builder.create(input, 0)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 0)); if (inputTy.isDynamicDim(3)) - outputDynSize.push_back(builder.create(input, 3)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 3)); SmallVector iterators(resultTy.getRank(), utils::IteratorType::parallel); - Value empty = builder.create( - resultTy.getShape(), resultTy.getElementType(), outputDynSize); + Value empty = tensor::EmptyOp::create( + builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize); SmallVector inputExprs{rewriter.getAffineDimExpr(0)}; if (inputH != 1) @@ -1751,7 +1765,7 @@ class MaterializeResizeBroadcast : public OpRewritePattern { ArrayRef{inputMap, outputMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; - b.create(loc, value); + linalg::YieldOp::create(b, loc, value); }); return success(); @@ -1789,10 +1803,10 @@ class GenericResizeConverter : public OpRewritePattern { SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto emptyTensor = b.create(resultTy.getShape(), resultETy, - *dynamicDimsOr); - auto genericOp = b.create( - resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, + auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(), + resultETy, *dynamicDimsOr); + auto genericOp = linalg::GenericOp::create( + b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); Value resize = genericOp.getResult(0); @@ -1800,19 +1814,21 @@ class GenericResizeConverter : public OpRewritePattern { OpBuilder::InsertionGuard regionGuard(b); b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), TypeRange({resultETy}), loc); - Value batch = b.create(0); - Value y = b.create(1); - Value x = b.create(2); - Value channel = b.create(3); + Value batch = linalg::IndexOp::create(b, 0); + Value y = linalg::IndexOp::create(b, 1); + Value x = linalg::IndexOp::create(b, 2); + Value channel = linalg::IndexOp::create(b, 3); Value zeroI32 = - b.create(b.getZeroAttr(b.getI32Type())); - Value zeroFp = b.create(b.getZeroAttr(floatTy)); - Value hMax = b.create(b.getI32IntegerAttr(imageH - 1)); - Value wMax = b.create(b.getI32IntegerAttr(imageW - 1)); + arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type())); + Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy)); + Value hMax = + arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1)); + Value wMax = + arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1)); - Value inY = b.create(b.getI32Type(), y); - Value inX = b.create(b.getI32Type(), x); + Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y); + Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x); SmallVector scale, offset, border; if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) || @@ -1824,16 +1840,16 @@ class GenericResizeConverter : public OpRewritePattern { } Value yScaleN, yScaleD, xScaleN, xScaleD; - yScaleN = b.create(b.getI32IntegerAttr(scale[0])); - yScaleD = b.create(b.getI32IntegerAttr(scale[1])); - xScaleN = b.create(b.getI32IntegerAttr(scale[2])); - xScaleD = b.create(b.getI32IntegerAttr(scale[3])); + yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0])); + yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1])); + xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2])); + xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3])); Value yOffset, xOffset, yBorder, xBorder; - yOffset = b.create(b.getI32IntegerAttr(offset[0])); - xOffset = b.create(b.getI32IntegerAttr(offset[1])); - yBorder = b.create(b.getI32IntegerAttr(border[0])); - xBorder = b.create(b.getI32IntegerAttr(border[1])); + yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0])); + xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1])); + yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0])); + xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1])); // Compute the ix and dx values for both the X and Y dimensions. auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in, @@ -1846,16 +1862,16 @@ class GenericResizeConverter : public OpRewritePattern { } // x = x * scale_d + offset; // ix = floor(x / scale_n) - Value val = b.create(in, scaleD); - val = b.create(val, offset); - index = b.create(val, scaleN); + Value val = arith::MulIOp::create(b, in, scaleD); + val = arith::AddIOp::create(b, val, offset); + index = arith::FloorDivSIOp::create(b, val, scaleN); // rx = x % scale_n // dx = rx / scale_n - Value r = b.create(val, scaleN); - Value rFp = b.create(floatTy, r); - Value scaleNfp = b.create(floatTy, scaleN); - delta = b.create(rFp, scaleNfp); + Value r = arith::RemSIOp::create(b, val, scaleN); + Value rFp = arith::SIToFPOp::create(b, floatTy, r); + Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN); + delta = arith::DivFOp::create(b, rFp, scaleNfp); }; // Compute the ix and dx values for the X and Y dimensions - int case. @@ -1870,11 +1886,11 @@ class GenericResizeConverter : public OpRewritePattern { // x = x * scale_d + offset; // ix = floor(x / scale_n) // dx = x - ix * scale_n; - Value val = b.create(in, scaleD); - val = b.create(val, offset); - index = b.create(val, scaleN); - delta = b.create(index, scaleN); - delta = b.create(val, delta); + Value val = arith::MulIOp::create(b, in, scaleD); + val = arith::AddIOp::create(b, val, offset); + index = arith::DivSIOp::create(b, val, scaleN); + delta = arith::MulIOp::create(b, index, scaleN); + delta = arith::SubIOp::create(b, val, delta); }; Value ix, iy, dx, dy; @@ -1887,54 +1903,55 @@ class GenericResizeConverter : public OpRewritePattern { } if (op.getMode() == "NEAREST_NEIGHBOR") { - auto one = b.create(b.getI32IntegerAttr(1)); + auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale, Value max, int size, ImplicitLocOpBuilder &b) -> Value { if (size == 1) { - return b.create(0); + return arith::ConstantIndexOp::create(b, 0); } Value pred; if (floatingPointMode) { - auto h = b.create(b.getFloatAttr(floatTy, 0.5f)); - pred = b.create(arith::CmpFPredicate::OGE, dval, h); + auto h = + arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f)); + pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h); } else { - Value dvalDouble = b.create(dval, one); - pred = b.create(arith::CmpIPredicate::sge, - dvalDouble, scale); + Value dvalDouble = arith::ShLIOp::create(b, dval, one); + pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge, + dvalDouble, scale); } - auto offset = b.create(pred, one, zeroI32); - val = b.create(val, offset); + auto offset = arith::SelectOp::create(b, pred, one, zeroI32); + val = arith::AddIOp::create(b, val, offset); val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false); - return b.create(b.getIndexType(), val); + return arith::IndexCastOp::create(b, b.getIndexType(), val); }; iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b); ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b); - Value result = b.create( - input, ValueRange{batch, iy, ix, channel}); + Value result = tensor::ExtractOp::create( + b, input, ValueRange{batch, iy, ix, channel}); - b.create(result); + linalg::YieldOp::create(b, result); } else { // The mode here must be BILINEAR. assert(op.getMode() == "BILINEAR"); - auto oneVal = b.create(b.getI32IntegerAttr(1)); + auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in, Value max, ImplicitLocOpBuilder &b) { val0 = in; - val1 = b.create(val0, oneVal); + val1 = arith::AddIOp::create(b, val0, oneVal); val0 = clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false); val1 = clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false); - val0 = b.create(b.getIndexType(), val0); - val1 = b.create(b.getIndexType(), val1); + val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0); + val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1); }; // Linalg equivalent to the section below: @@ -1946,27 +1963,27 @@ class GenericResizeConverter : public OpRewritePattern { getClampedIdxs(y0, y1, imageH, iy, hMax, b); getClampedIdxs(x0, x1, imageW, ix, wMax, b); - Value y0x0 = b.create( - input, ValueRange{batch, y0, x0, channel}); - Value y0x1 = b.create( - input, ValueRange{batch, y0, x1, channel}); - Value y1x0 = b.create( - input, ValueRange{batch, y1, x0, channel}); - Value y1x1 = b.create( - input, ValueRange{batch, y1, x1, channel}); + Value y0x0 = tensor::ExtractOp::create( + b, input, ValueRange{batch, y0, x0, channel}); + Value y0x1 = tensor::ExtractOp::create( + b, input, ValueRange{batch, y0, x1, channel}); + Value y1x0 = tensor::ExtractOp::create( + b, input, ValueRange{batch, y1, x0, channel}); + Value y1x1 = tensor::ExtractOp::create( + b, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { auto oneVal = - b.create(b.getFloatAttr(floatTy, 1.0f)); + arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f)); auto interpolate = [&](Value val0, Value val1, Value delta, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) return val0; - Value oneMinusDelta = b.create(oneVal, delta); - Value mul0 = b.create(val0, oneMinusDelta); - Value mul1 = b.create(val1, delta); - return b.create(mul0, mul1); + Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta); + Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta); + Value mul1 = arith::MulFOp::create(b, val1, delta); + return arith::AddFOp::create(b, mul0, mul1); }; // Linalg equivalent to the section below: @@ -1982,18 +1999,18 @@ class GenericResizeConverter : public OpRewritePattern { // Linalg equivalent to the section below: // result = topAcc * (unit_y - dy) + bottomAcc * dy Value result = interpolate(topAcc, bottomAcc, dy, imageH, b); - b.create(result); + linalg::YieldOp::create(b, result); } else { // Perform in quantized space. - y0x0 = b.create(resultETy, y0x0); - y0x1 = b.create(resultETy, y0x1); - y1x0 = b.create(resultETy, y1x0); - y1x1 = b.create(resultETy, y1x1); + y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0); + y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1); + y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0); + y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1); const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) { - dx = b.create(resultETy, dx); - dy = b.create(resultETy, dy); + dx = arith::ExtSIOp::create(b, resultETy, dx); + dy = arith::ExtSIOp::create(b, resultETy, dy); } Value yScaleNExt = yScaleN; @@ -2002,26 +2019,26 @@ class GenericResizeConverter : public OpRewritePattern { const int64_t scaleBitwidth = xScaleN.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) { - yScaleNExt = b.create(resultETy, yScaleN); - xScaleNExt = b.create(resultETy, xScaleN); + yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN); + xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN); } auto interpolate = [](Value val0, Value val1, Value weight1, Value scale, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) - return b.create(val0, scale); - Value weight0 = b.create(scale, weight1); - Value mul0 = b.create(val0, weight0); - Value mul1 = b.create(val1, weight1); - return b.create(mul0, mul1); + return arith::MulIOp::create(b, val0, scale); + Value weight0 = arith::SubIOp::create(b, scale, weight1); + Value mul0 = arith::MulIOp::create(b, val0, weight0); + Value mul1 = arith::MulIOp::create(b, val1, weight1); + return arith::AddIOp::create(b, mul0, mul1); }; Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b); Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b); Value result = interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b); - b.create(result); + linalg::YieldOp::create(b, result); } } } @@ -2072,11 +2089,11 @@ class ReverseConverter : public OpRewritePattern { SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } - Value axisDimSize = rewriter.create(loc, input, axis); + Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis); // First fill the output buffer with the init value. auto emptyTensor = rewriter @@ -2094,22 +2111,22 @@ class ReverseConverter : public OpRewritePattern { llvm::SmallVector indices; for (unsigned int i = 0; i < inputTy.getRank(); i++) { Value index = - rewriter.create(nestedLoc, i).getResult(); + linalg::IndexOp::create(rewriter, nestedLoc, i).getResult(); if (i == axis) { - auto one = rewriter.create(nestedLoc, 1); + auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1); auto sizeMinusOne = - rewriter.create(nestedLoc, axisDimSize, one); - index = rewriter.create(nestedLoc, sizeMinusOne, - index); + arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one); + index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne, + index); } indices.push_back(index); } - auto extract = nestedBuilder.create( - nestedLoc, input, indices); - nestedBuilder.create(op.getLoc(), - extract.getResult()); + auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc, + input, indices); + linalg::YieldOp::create(nestedBuilder, op.getLoc(), + extract.getResult()); }); return success(); } @@ -2148,12 +2165,12 @@ struct TileConverter : public OpConversionPattern { SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) || multiples[i] == -1) { - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } - auto emptyTensor = rewriter.create( - op.getLoc(), genericShape, elementTy, dynDims); + auto emptyTensor = tensor::EmptyOp::create( + rewriter, op.getLoc(), genericShape, elementTy, dynDims); // We needs to map the input shape to the non-broadcasted dimensions. SmallVector dimExprs; @@ -2168,12 +2185,12 @@ struct TileConverter : public OpConversionPattern { SmallVector affineMaps = { readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())}; - auto genericOp = rewriter.create( - loc, RankedTensorType::get(genericShape, elementTy), input, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, RankedTensorType::get(genericShape, elementTy), input, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(genericShape.size()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(op.getLoc(), *args.begin()); + linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin()); }); auto shapeValue = getTosaConstShape( @@ -2220,7 +2237,7 @@ class ArgMaxConverter : public OpRewritePattern { SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) && i != axis) { - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } @@ -2229,8 +2246,8 @@ class ArgMaxConverter : public OpRewritePattern { .create(loc, resultTy.getShape(), outElementTy, dynDims) .getResult(); - auto fillValueIdx = rewriter.create( - loc, rewriter.getIntegerAttr(outElementTy, 0)); + auto fillValueIdx = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = rewriter .create(loc, ValueRange{fillValueIdx}, @@ -2250,7 +2267,7 @@ class ArgMaxConverter : public OpRewritePattern { argmaxOp, "unsupported tosa.argmax element type"); auto fillValueMax = - rewriter.create(loc, fillValueMaxAttr); + arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr); auto filledTensorMax = rewriter .create(loc, ValueRange{fillValueMax}, @@ -2274,8 +2291,8 @@ class ArgMaxConverter : public OpRewritePattern { bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs}, rewriter.getContext()); - auto linalgOp = rewriter.create( - loc, ArrayRef({resultTy, resultMaxTy}), input, + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, ArrayRef({resultTy, resultMaxTy}), input, ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { @@ -2283,42 +2300,46 @@ class ArgMaxConverter : public OpRewritePattern { auto oldIndex = blockArgs[1]; auto oldValue = blockArgs[2]; - Value newIndex = rewriter.create( - nestedLoc, oldIndex.getType(), - rewriter.create(loc, axis)); + Value newIndex = arith::IndexCastOp::create( + rewriter, nestedLoc, oldIndex.getType(), + linalg::IndexOp::create(rewriter, loc, axis)); Value predicate; if (isa(inElementTy)) { if (argmaxOp.getNanMode() == "IGNORE") { // Only update index & max value for non NaN values. If all // values are NaNs, the initial index will be return which is 0. - predicate = rewriter.create( - nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); + predicate = arith::CmpFOp::create(rewriter, nestedLoc, + arith::CmpFPredicate::OGT, + newValue, oldValue); } else { // Update max value if either of the following is true: // - new value is bigger // - cur max is not NaN and new value is NaN - Value gt = rewriter.create( - nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue); - Value oldNonNaN = rewriter.create( - nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue); - predicate = rewriter.create( - nestedLoc, rewriter.getI1Type(), gt, oldNonNaN); + Value gt = arith::CmpFOp::create(rewriter, nestedLoc, + arith::CmpFPredicate::UGT, + newValue, oldValue); + Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc, + arith::CmpFPredicate::ORD, + oldValue, oldValue); + predicate = arith::AndIOp::create( + rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN); } } else if (isa(inElementTy)) { - predicate = rewriter.create( - nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); + predicate = arith::CmpIOp::create(rewriter, nestedLoc, + arith::CmpIPredicate::sgt, + newValue, oldValue); } else { didEncounterError = true; return; } - auto resultMax = rewriter.create( - nestedLoc, predicate, newValue, oldValue); - auto resultIndex = rewriter.create( - nestedLoc, predicate, newIndex, oldIndex); - nestedBuilder.create( - nestedLoc, ValueRange({resultIndex, resultMax})); + auto resultMax = arith::SelectOp::create( + rewriter, nestedLoc, predicate, newValue, oldValue); + auto resultIndex = arith::SelectOp::create( + rewriter, nestedLoc, predicate, newIndex, oldIndex); + linalg::YieldOp::create(nestedBuilder, nestedLoc, + ValueRange({resultIndex, resultMax})); }); if (didEncounterError) @@ -2363,19 +2384,19 @@ class GatherConverter : public OpConversionPattern { rewriter.getContext()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto genericOp = rewriter.create( - loc, ArrayRef({resultTy}), ValueRange{indices}, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, ArrayRef({resultTy}), ValueRange{indices}, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; - auto index0 = rewriter.create(loc, 0); - Value index1 = rewriter.create( - loc, rewriter.getIndexType(), indexValue); - auto index2 = rewriter.create(loc, 2); - Value extract = rewriter.create( - loc, input, ValueRange{index0, index1, index2}); - rewriter.create(loc, extract); + auto index0 = linalg::IndexOp::create(rewriter, loc, 0); + Value index1 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), indexValue); + auto index2 = linalg::IndexOp::create(rewriter, loc, 2); + Value extract = tensor::ExtractOp::create( + rewriter, loc, input, ValueRange{index0, index1, index2}); + linalg::YieldOp::create(rewriter, loc, extract); }); rewriter.replaceOp(op, genericOp.getResult(0)); return success(); @@ -2424,7 +2445,7 @@ class TableConverter : public OpRewritePattern { for (int i = 0; i < resultTy.getRank(); ++i) { if (inputTy.isDynamicDim(i)) { dynDims.push_back( - rewriter.create(loc, op.getOperand(0), i)); + tensor::DimOp::create(rewriter, loc, op.getOperand(0), i)); } } @@ -2437,9 +2458,9 @@ class TableConverter : public OpRewritePattern { rewriter.getMultiDimIdentityMap(resultTy.getRank()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto genericOp = rewriter.create( - loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps, - getNParallelLoopsAttrs(resultTy.getRank())); + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, + affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); rewriter.replaceOp(op, genericOp.getResult(0)); { @@ -2452,69 +2473,69 @@ class TableConverter : public OpRewritePattern { rewriter.setInsertionPointToStart(block); if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && resultElementTy.isInteger(8)) { - Value index = rewriter.create( - loc, rewriter.getIndexType(), inputValue); - Value offset = rewriter.create(loc, 128); - index = rewriter.create(loc, rewriter.getIndexType(), - index, offset); + Value index = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), inputValue); + Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128); + index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), + index, offset); Value extract = - rewriter.create(loc, table, ValueRange{index}); - rewriter.create(loc, extract); + tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index}); + linalg::YieldOp::create(rewriter, loc, extract); return success(); } if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && resultElementTy.isInteger(32)) { - Value extend = rewriter.create( - loc, rewriter.getI32Type(), inputValue); - - auto offset = rewriter.create( - loc, rewriter.getI32IntegerAttr(32768)); - auto seven = rewriter.create( - loc, rewriter.getI32IntegerAttr(7)); - auto one = rewriter.create( - loc, rewriter.getI32IntegerAttr(1)); - auto b1111111 = rewriter.create( - loc, rewriter.getI32IntegerAttr(127)); + Value extend = arith::ExtSIOp::create( + rewriter, loc, rewriter.getI32Type(), inputValue); + + auto offset = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(32768)); + auto seven = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(7)); + auto one = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(1)); + auto b1111111 = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(127)); // Compute the index and fractional part from the input value: // value = value + 32768 // index = value >> 7; // fraction = 0x01111111 & value - auto extendAdd = rewriter.create(loc, extend, offset); - Value index = rewriter.create(loc, extendAdd, seven); + auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset); + Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven); Value fraction = - rewriter.create(loc, extendAdd, b1111111); + arith::AndIOp::create(rewriter, loc, extendAdd, b1111111); // Extract the base and next values from the table. // base = (int32_t) table[index]; // next = (int32_t) table[index + 1]; - Value indexPlusOne = rewriter.create(loc, index, one); + Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one); - index = rewriter.create( - loc, rewriter.getIndexType(), index); - indexPlusOne = rewriter.create( - loc, rewriter.getIndexType(), indexPlusOne); + index = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), index); + indexPlusOne = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), indexPlusOne); Value base = - rewriter.create(loc, table, ValueRange{index}); - Value next = rewriter.create( - loc, table, ValueRange{indexPlusOne}); + tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index}); + Value next = tensor::ExtractOp::create(rewriter, loc, table, + ValueRange{indexPlusOne}); base = - rewriter.create(loc, rewriter.getI32Type(), base); + arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base); next = - rewriter.create(loc, rewriter.getI32Type(), next); + arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next); // Use the fractional part to interpolate between the input values: // result = (base << 7) + (next - base) * fraction - Value baseScaled = rewriter.create(loc, base, seven); - Value diff = rewriter.create(loc, next, base); - Value diffScaled = rewriter.create(loc, diff, fraction); + Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven); + Value diff = arith::SubIOp::create(rewriter, loc, next, base); + Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction); Value result = - rewriter.create(loc, baseScaled, diffScaled); + arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled); - rewriter.create(loc, result); + linalg::YieldOp::create(rewriter, loc, result); return success(); } @@ -2532,8 +2553,8 @@ struct RFFT2dConverter final : public OpRewritePattern { static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc, OpFoldResult ofr) { - auto one = builder.create(loc, 1); - auto two = builder.create(loc, 2); + auto one = arith::ConstantIndexOp::create(builder, loc, 1); + auto two = arith::ConstantIndexOp::create(builder, loc, 2); auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr); auto divBy2 = builder.createOrFold(loc, value, two); @@ -2562,9 +2583,9 @@ struct RFFT2dConverter final : public OpRewritePattern { RankedTensorType type, llvm::ArrayRef dynamicSizes) { auto emptyTensor = - rewriter.create(loc, type, dynamicSizes); + tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes); auto fillValueAttr = rewriter.getZeroAttr(type.getElementType()); - auto fillValue = rewriter.create(loc, fillValueAttr); + auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, ValueRange{emptyTensor}) @@ -2574,18 +2595,18 @@ struct RFFT2dConverter final : public OpRewritePattern { static Value castIndexToFloat(OpBuilder &builder, Location loc, FloatType type, Value value) { - auto integerVal = builder.create( - loc, + auto integerVal = arith::IndexCastUIOp::create( + builder, loc, type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type() : builder.getI32Type(), value); - return builder.create(loc, type, integerVal); + return arith::UIToFPOp::create(builder, loc, type, integerVal); } static Value createLinalgIndex(OpBuilder &builder, Location loc, FloatType type, int64_t index) { - auto indexVal = builder.create(loc, index); + auto indexVal = linalg::IndexOp::create(builder, loc, index); return castIndexToFloat(builder, loc, type, indexVal); } @@ -2640,7 +2661,7 @@ struct RFFT2dConverter final : public OpRewritePattern { // Constants and dimension sizes auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586); - auto twoPi = rewriter.create(loc, twoPiAttr); + auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr); auto constH = castIndexToFloat(rewriter, loc, elementType, dimH); auto constW = castIndexToFloat(rewriter, loc, elementType, dimW); @@ -2650,43 +2671,45 @@ struct RFFT2dConverter final : public OpRewritePattern { Value sumImag = args[2]; // Indices for angle computation - Value oy = builder.create(loc, 1); - Value ox = builder.create(loc, 2); - Value iy = builder.create(loc, 3); - Value ix = builder.create(loc, 4); + Value oy = linalg::IndexOp::create(builder, loc, 1); + Value ox = linalg::IndexOp::create(builder, loc, 2); + Value iy = linalg::IndexOp::create(builder, loc, 3); + Value ix = linalg::IndexOp::create(builder, loc, 4); // Calculating angle without integer parts of components as sin/cos are // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W ) // / W); - auto iyXoy = builder.create(loc, iy, oy); - auto ixXox = builder.create(loc, ix, ox); + auto iyXoy = index::MulOp::create(builder, loc, iy, oy); + auto ixXox = index::MulOp::create(builder, loc, ix, ox); - auto iyRem = builder.create(loc, iyXoy, dimH); - auto ixRem = builder.create(loc, ixXox, dimW); + auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH); + auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW); auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem); auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem); - auto yComponent = builder.create(loc, iyRemFloat, constH); - auto xComponent = builder.create(loc, ixRemFloat, constW); - auto sumXY = builder.create(loc, yComponent, xComponent); - auto angle = builder.create(loc, twoPi, sumXY); + auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH); + auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW); + auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent); + auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY); // realComponent = valReal * cos(angle) // imagComponent = valReal * sin(angle) - auto cosAngle = builder.create(loc, angle); - auto sinAngle = builder.create(loc, angle); + auto cosAngle = math::CosOp::create(builder, loc, angle); + auto sinAngle = math::SinOp::create(builder, loc, angle); auto realComponent = - builder.create(loc, valReal, cosAngle); + arith::MulFOp::create(builder, loc, valReal, cosAngle); auto imagComponent = - builder.create(loc, valReal, sinAngle); + arith::MulFOp::create(builder, loc, valReal, sinAngle); // outReal = sumReal + realComponent // outImag = sumImag - imagComponent - auto outReal = builder.create(loc, sumReal, realComponent); - auto outImag = builder.create(loc, sumImag, imagComponent); + auto outReal = + arith::AddFOp::create(builder, loc, sumReal, realComponent); + auto outImag = + arith::SubFOp::create(builder, loc, sumImag, imagComponent); - builder.create(loc, ValueRange{outReal, outImag}); + linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag}); }; rewriter.replaceOpWithNewOp( @@ -2760,7 +2783,7 @@ struct FFT2dConverter final : OpRewritePattern { // Constants and dimension sizes auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586); - auto twoPi = rewriter.create(loc, twoPiAttr); + auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr); Value constH = RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH); Value constW = @@ -2773,57 +2796,59 @@ struct FFT2dConverter final : OpRewritePattern { Value sumImag = args[3]; // Indices for angle computation - Value oy = builder.create(loc, 1); - Value ox = builder.create(loc, 2); - Value iy = builder.create(loc, 3); - Value ix = builder.create(loc, 4); + Value oy = linalg::IndexOp::create(builder, loc, 1); + Value ox = linalg::IndexOp::create(builder, loc, 2); + Value iy = linalg::IndexOp::create(builder, loc, 3); + Value ix = linalg::IndexOp::create(builder, loc, 4); // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * // ox) % W ) / W); - auto iyXoy = builder.create(loc, iy, oy); - auto ixXox = builder.create(loc, ix, ox); + auto iyXoy = index::MulOp::create(builder, loc, iy, oy); + auto ixXox = index::MulOp::create(builder, loc, ix, ox); - auto iyRem = builder.create(loc, iyXoy, dimH); - auto ixRem = builder.create(loc, ixXox, dimW); + auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH); + auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW); auto iyRemFloat = RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem); auto ixRemFloat = RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem); - auto yComponent = builder.create(loc, iyRemFloat, constH); - auto xComponent = builder.create(loc, ixRemFloat, constW); + auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH); + auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW); - auto sumXY = builder.create(loc, yComponent, xComponent); - auto angle = builder.create(loc, twoPi, sumXY); + auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent); + auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY); if (inverse.getValue()) { - angle = builder.create( - loc, angle, - rewriter.create( - loc, rewriter.getFloatAttr(real_el_ty, -1.0))); + angle = arith::MulFOp::create( + builder, loc, angle, + arith::ConstantOp::create(rewriter, loc, + rewriter.getFloatAttr(real_el_ty, -1.0))); } // realComponent = val_real * cos(a) + val_imag * sin(a); // imagComponent = -val_real * sin(a) + val_imag * cos(a); - auto cosAngle = builder.create(loc, angle); - auto sinAngle = builder.create(loc, angle); + auto cosAngle = math::CosOp::create(builder, loc, angle); + auto sinAngle = math::SinOp::create(builder, loc, angle); - auto rcos = builder.create(loc, valReal, cosAngle); - auto rsin = builder.create(loc, valImag, sinAngle); - auto realComponent = builder.create(loc, rcos, rsin); + auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle); + auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle); + auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin); - auto icos = builder.create(loc, valImag, cosAngle); - auto isin = builder.create(loc, valReal, sinAngle); + auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle); + auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle); - auto imagComponent = builder.create(loc, icos, isin); + auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin); // outReal = sumReal + realComponent // outImag = sumImag - imagComponent - auto outReal = builder.create(loc, sumReal, realComponent); - auto outImag = builder.create(loc, sumImag, imagComponent); + auto outReal = + arith::AddFOp::create(builder, loc, sumReal, realComponent); + auto outImag = + arith::AddFOp::create(builder, loc, sumImag, imagComponent); - builder.create(loc, ValueRange{outReal, outImag}); + linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag}); }; rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 00b9a065dfb3d..3a205246ddd9e 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -52,11 +52,11 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef pad, highIndices.push_back(rewriter.getIndexAttr(highPad)); } - Value padValue = rewriter.create(loc, padAttr); + Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr); - return rewriter.create( - loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices, - highIndices, padValue); + return tensor::PadOp::create(rewriter, loc, + RankedTensorType::get(paddedShape, inputETy), + input, lowIndices, highIndices, padValue); } static mlir::Value @@ -72,10 +72,10 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value biasVal = args[0]; Type resType = args[1].getType(); if (resType != biasVal.getType()) { - biasVal = builder.create(loc, resType, biasVal); + biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal); } - Value added = builder.create(loc, biasVal, args[1]); - builder.create(loc, added); + Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]); + linalg::YieldOp::create(builder, loc, added); }) .getResult(0); } @@ -134,19 +134,19 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter, if (resType != biasVal.getType()) { biasVal = resultTy.getElementType().isFloat() - ? builder.create(loc, resType, biasVal) + ? arith::ExtFOp::create(builder, loc, resType, biasVal) .getResult() - : builder.create(loc, resType, biasVal) + : arith::ExtSIOp::create(builder, loc, resType, biasVal) .getResult(); } - builder.create(loc, biasVal); + linalg::YieldOp::create(builder, loc, biasVal); }) .getResult(0); } static mlir::Value reifyConstantDim(int64_t attr, ImplicitLocOpBuilder &builder) { - return builder.create(attr); + return arith::ConstantIndexOp::create(builder, attr); } // Calculating the output width/height using the formula: @@ -160,22 +160,22 @@ static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim, int64_t dilationAttr, OpBuilder &rewriter) { ImplicitLocOpBuilder builder(loc, rewriter); - auto one = rewriter.create( - loc, IntegerAttr::get(inputDim.getType(), 1)); + auto one = arith::ConstantOp::create(rewriter, loc, + IntegerAttr::get(inputDim.getType(), 1)); Value padBefore = reifyConstantDim(padBeforeAttr, builder); - Value paddedBefore = builder.create(inputDim, padBefore); + Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore); Value padAfter = reifyConstantDim(padAfterAttr, builder); - Value paddedAfter = builder.create(paddedBefore, padAfter); + Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter); - Value subOne = builder.create(kernelDim, one); + Value subOne = arith::SubIOp::create(builder, kernelDim, one); Value dilation = reifyConstantDim(dilationAttr, builder); - Value dilated = builder.create(dilation, subOne); - Value addOne = builder.create(dilated, one); + Value dilated = arith::MulIOp::create(builder, dilation, subOne); + Value addOne = arith::AddIOp::create(builder, dilated, one); - Value subtract = builder.create(paddedAfter, addOne); + Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne); Value stride = reifyConstantDim(strideAttr, builder); - Value divide = builder.create(subtract, stride); - return builder.create(divide, one); + Value divide = arith::DivUIOp::create(builder, subtract, stride); + return arith::AddIOp::create(builder, divide, one); } // Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D @@ -198,9 +198,9 @@ static SmallVector inferDynamicDimsForConv( auto padBottom = padAttr[i * 2 + 1]; auto stride = strideAttr[i]; auto dilation = dilationAttr[i]; - Value initDynDim = rewriter.create(loc, input, inputDim); + Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim); Value kernelDynDim = - rewriter.create(loc, weight, kernelDim); + tensor::DimOp::create(rewriter, loc, weight, kernelDim); // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y) dynDims[inputDim] = getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom, @@ -211,7 +211,7 @@ static SmallVector inferDynamicDimsForConv( // Get the batch/channels dimensions. for (int i = 0; i < inputRank; i++) { if (resultTy.isDynamicDim(i) && !dynDims[i]) - dynDims[i] = rewriter.create(loc, input, i); + dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i); } SmallVector filteredDims = condenseValues(dynDims); @@ -350,8 +350,8 @@ class ConvConverter : public OpConversionPattern { auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); - weight = rewriter.create(loc, newWeightTy, weight, - weightPermAttr); + weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight, + weightPermAttr); } } @@ -372,8 +372,8 @@ class ConvConverter : public OpConversionPattern { auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); - weight = rewriter.create(loc, newWeightTy, weight, - weightPermAttr); + weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight, + weightPermAttr); } // Extract the attributes for convolution. @@ -384,8 +384,8 @@ class ConvConverter : public OpConversionPattern { auto strideAttr = rewriter.getI64TensorAttr(stride); auto dilationAttr = rewriter.getI64TensorAttr(dilation); - Value biasEmptyTensor = rewriter.create( - loc, resultTy.getShape(), accETy, filteredDims); + Value biasEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultTy.getShape(), accETy, filteredDims); Value broadcastBias = linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor); @@ -394,8 +394,8 @@ class ConvConverter : public OpConversionPattern { auto iZp = rewriter.getI32IntegerAttr(inputZpVal); auto kZp = rewriter.getI32IntegerAttr(weightZpVal); - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); + auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp); Value conv = rewriter @@ -417,7 +417,7 @@ class ConvConverter : public OpConversionPattern { // We may need to truncate back to the result type if the accumulator was // wider than the result. if (resultTy != accTy) - conv = rewriter.create(loc, resultTy, conv); + conv = tosa::CastOp::create(rewriter, loc, resultTy, conv); rewriter.replaceOp(op, conv); return success(); @@ -526,16 +526,16 @@ class DepthwiseConvConverter accETy); auto resultZeroAttr = rewriter.getZeroAttr(accETy); - Value emptyTensor = rewriter.create( - loc, linalgConvTy.getShape(), accETy, filteredDims); - Value zero = rewriter.create(loc, resultZeroAttr); + Value emptyTensor = tensor::EmptyOp::create( + rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims); + Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, ValueRange{emptyTensor}) .result(); - Value biasEmptyTensor = rewriter.create( - loc, resultTy.getShape(), resultETy, filteredDims); + Value biasEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultTy.getShape(), resultETy, filteredDims); // Broadcast the initial value to the output tensor before convolving. SmallVector indexingMaps; @@ -553,16 +553,16 @@ class DepthwiseConvConverter // We may need to truncate back to the result type if the accumulator was // wider than the result. if (accETy != resultETy) - conv = rewriter.create( - loc, + conv = tosa::CastOp::create( + rewriter, loc, RankedTensorType::get(cast(conv.getType()).getShape(), resultETy), conv); SmallVector reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); - Value convReshape = rewriter.create( - loc, resultTy, conv, reassociationMap); + Value convReshape = tensor::CollapseShapeOp::create( + rewriter, loc, resultTy, conv, reassociationMap); Value result = rewriter @@ -574,20 +574,20 @@ class DepthwiseConvConverter ValueRange args) { Value added; if (llvm::isa(inputETy)) - added = nestedBuilder.create(loc, args[0], - args[1]); + added = arith::AddFOp::create(nestedBuilder, loc, args[0], + args[1]); else - added = nestedBuilder.create(loc, args[0], - args[1]); - nestedBuilder.create(nestedLoc, added); + added = arith::AddIOp::create(nestedBuilder, loc, args[0], + args[1]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, added); }) .getResult(0); rewriter.replaceOp(op, result); } else { IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal); IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal); - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, wZp); + auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); + auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp); Value conv = rewriter .create( @@ -596,8 +596,8 @@ class DepthwiseConvConverter .getResult(0); SmallVector reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); - Value convReshape = rewriter.create( - loc, resultTy, conv, reassociationMap); + Value convReshape = tensor::CollapseShapeOp::create( + rewriter, loc, resultTy, conv, reassociationMap); Value result = linalgIntBroadcastExtSIAdd( rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps); rewriter.replaceOp(op, result); @@ -621,23 +621,24 @@ class MatMulConverter : public OpConversionPattern { dynDims.resize(cast(op->getResult(0).getType()).getRank()); if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) { - dynDims[0] = rewriter.create(loc, op->getOperand(0), 0); + dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0); } if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) { - dynDims[1] = rewriter.create(loc, op->getOperand(0), 1); + dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1); } if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) { - dynDims[2] = rewriter.create(loc, op->getOperand(1), 2); + dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2); } SmallVector filteredDims = condenseValues(dynDims); auto zeroAttr = rewriter.getZeroAttr(outputElementTy); - Value zero = rewriter.create(loc, zeroAttr); - auto emptyTensor = rewriter.create( - loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); + Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(), + outputTy.getElementType(), filteredDims); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, ValueRange{emptyTensor}) @@ -670,10 +671,10 @@ class MatMulConverter : public OpConversionPattern { return success(); } - auto aZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(aZpVal)); - auto bZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(bZpVal)); + auto aZp = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(aZpVal)); + auto bZp = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(bZpVal)); rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor); @@ -702,7 +703,7 @@ class MaxPool2dConverter : public OpConversionPattern { // Batch dimension if (resultTy.isDynamicDim(0)) - dynamicDims.push_back(rewriter.create(loc, input, 0)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0)); // Height/width dimensions for (int64_t dim : {1, 2}) { @@ -713,10 +714,10 @@ class MaxPool2dConverter : public OpConversionPattern { int64_t index = dim - 1; // Input height/width - Value ihw = rewriter.create(loc, input, dim); + Value ihw = tensor::DimOp::create(rewriter, loc, input, dim); // Kernel height/width - Value khw = rewriter.create(loc, kernel[index]); + Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]); // Output height/width Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2], @@ -727,7 +728,7 @@ class MaxPool2dConverter : public OpConversionPattern { // Channel dimension if (resultTy.isDynamicDim(3)) - dynamicDims.push_back(rewriter.create(loc, input, 3)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3)); return dynamicDims; } @@ -776,7 +777,7 @@ class MaxPool2dConverter : public OpConversionPattern { Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr); ArrayRef kernel = op.getKernel(); ArrayRef stride = op.getStride(); @@ -785,15 +786,16 @@ class MaxPool2dConverter : public OpConversionPattern { Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value emptyTensor = rewriter.create( - loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims); + Value emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultTy.getElementType(), dynamicDims); Value filledEmptyTensor = - rewriter.create(loc, initialValue, emptyTensor) + linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor) .result(); Value fakeWindowDims = - rewriter.create(loc, kernel, resultETy); + tensor::EmptyOp::create(rewriter, loc, kernel, resultETy); if (isUnsigned) { rewriter.replaceOpWithNewOp( @@ -802,8 +804,8 @@ class MaxPool2dConverter : public OpConversionPattern { return llvm::success(); } - auto resultOp = rewriter.create( - op->getLoc(), ArrayRef{resultTy}, + auto resultOp = linalg::PoolingNhwcMaxOp::create( + rewriter, op->getLoc(), ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr, dilationAttr); @@ -823,9 +825,10 @@ class MaxPool2dConverter : public OpConversionPattern { // it to include the appropriate checks. If the current value is NaN the // old value of pool will be taken otherwise we use the result. if (nanMode == "IGNORE") { - auto genericOp = rewriter.create( - loc, resultOp.getType(0), resultOp.getInputs(), resultOp.getOutputs(), - resultOp.getIndexingMapsArray(), resultOp.getIteratorTypesArray(), + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultOp.getType(0), resultOp.getInputs(), + resultOp.getOutputs(), resultOp.getIndexingMapsArray(), + resultOp.getIteratorTypesArray(), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { IRMapping map; auto oldBlock = resultOp.getRegion().begin(); @@ -833,12 +836,12 @@ class MaxPool2dConverter : public OpConversionPattern { auto &oldMaxOp = *resultOp.getBlock()->begin(); map.map(oldArgs, blockArgs); auto *newOp = opBuilder.clone(oldMaxOp, map); - Value isNaN = opBuilder.create( - loc, arith::CmpFPredicate::UNO, blockArgs.front(), - blockArgs.front()); - auto selectOp = opBuilder.create( - loc, isNaN, blockArgs.back(), newOp->getResult(0)); - opBuilder.create(loc, selectOp.getResult()); + Value isNaN = + arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO, + blockArgs.front(), blockArgs.front()); + auto selectOp = arith::SelectOp::create( + opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0)); + linalg::YieldOp::create(opBuilder, loc, selectOp.getResult()); }); rewriter.replaceOp(resultOp, genericOp); } @@ -894,7 +897,7 @@ class AvgPool2dConverter : public OpRewritePattern { Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter); auto initialAttr = rewriter.getZeroAttr(accETy); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr); ArrayRef kernel = op.getKernel(); ArrayRef stride = op.getStride(); @@ -903,8 +906,8 @@ class AvgPool2dConverter : public OpRewritePattern { Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value poolEmptyTensor = rewriter.create( - loc, accTy.getShape(), accETy, dynamicDims); + Value poolEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, accTy.getShape(), accETy, dynamicDims); Value filledEmptyTensor = rewriter @@ -913,7 +916,7 @@ class AvgPool2dConverter : public OpRewritePattern { .result(); Value fakeWindowDims = - rewriter.create(loc, kernel, accETy); + tensor::EmptyOp::create(rewriter, loc, kernel, accETy); // Sum across the pooled region. Value poolingOp = rewriter @@ -925,24 +928,24 @@ class AvgPool2dConverter : public OpRewritePattern { // Normalize the summed value by the number of elements grouped in each // pool. - Value iH = rewriter.create(loc, poolingOp, 1); - Value iW = rewriter.create(loc, poolingOp, 2); + Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1); + Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2); - auto one = rewriter.create(loc, 1); - iH = rewriter.create(loc, iH, one); - iW = rewriter.create(loc, iW, one); + auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); + iH = arith::SubIOp::create(rewriter, loc, iH, one); + iW = arith::SubIOp::create(rewriter, loc, iW, one); - Value genericEmptyTensor = rewriter.create( - loc, resultTy.getShape(), resultETy, dynamicDims); + Value genericEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultTy.getShape(), resultETy, dynamicDims); auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); - auto genericOp = rewriter.create( - loc, ArrayRef({resultTy}), ValueRange{poolingOp}, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, ArrayRef({resultTy}), ValueRange{poolingOp}, ValueRange{genericEmptyTensor}, ArrayRef({affineMap, affineMap}), getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { - auto zero = rewriter.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); // Determines what the portion of valid input is covered by the // kernel. @@ -950,30 +953,30 @@ class AvgPool2dConverter : public OpRewritePattern { if (pad == 0) return valid; - auto padVal = rewriter.create(loc, pad); - Value dpos = rewriter.create(loc, pos, padVal); + auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad); + Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal); - Value offset = rewriter.create(loc, dpos, zero); - return rewriter.create(loc, valid, offset) + Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero); + return arith::AddIOp::create(rewriter, loc, valid, offset) ->getResult(0); }; auto coverageFn = [&](int64_t i, Value isize) -> Value { Value strideVal = - rewriter.create(loc, stride[i - 1]); + arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]); Value val = - rewriter.create(loc, kernel[i - 1]); + arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]); // Find the position relative to the input tensor's ends. - Value left = rewriter.create(loc, i); - Value right = rewriter.create(loc, isize, left); - left = rewriter.create(loc, left, strideVal); - right = rewriter.create(loc, right, strideVal); + Value left = linalg::IndexOp::create(rewriter, loc, i); + Value right = arith::SubIOp::create(rewriter, loc, isize, left); + left = arith::MulIOp::create(rewriter, loc, left, strideVal); + right = arith::MulIOp::create(rewriter, loc, right, strideVal); // Determine how much padding was included. val = padFn(val, left, pad[i * 2]); val = padFn(val, right, pad[i * 2 + 1]); - return rewriter.create(loc, one, val); + return arith::MaxSIOp::create(rewriter, loc, one, val); }; // Compute the indices from either end. @@ -981,70 +984,70 @@ class AvgPool2dConverter : public OpRewritePattern { Value kW3 = coverageFn(2, iW); // Compute the total number of elements and normalize. - auto count = rewriter.create( - loc, rewriter.getI32Type(), - rewriter.create(loc, kH3, kW3)); + auto count = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), + arith::MulIOp::create(rewriter, loc, kH3, kW3)); // Divide by the number of summed values. For floats this is just // a div however for quantized values input normalization had // to be applied. Value poolVal = args[0]; if (isa(accETy)) { - auto countF = rewriter.create(loc, accETy, count); - poolVal = rewriter.create(loc, poolVal, countF) + auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count); + poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF) ->getResult(0); if (accETy.getIntOrFloatBitWidth() > resultETy.getIntOrFloatBitWidth()) poolVal = - rewriter.create(loc, resultETy, poolVal); + arith::TruncFOp::create(rewriter, loc, resultETy, poolVal); } else { // If we have quantization information we need to apply an offset // for the input zp value. if (inputZpVal != 0) { - auto inputZp = rewriter.create( - loc, b.getIntegerAttr(accETy, inputZpVal)); + auto inputZp = arith::ConstantOp::create( + rewriter, loc, b.getIntegerAttr(accETy, inputZpVal)); Value offset = - rewriter.create(loc, accETy, count, inputZp); + arith::MulIOp::create(rewriter, loc, accETy, count, inputZp); poolVal = - rewriter.create(loc, accETy, poolVal, offset); + arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset); } // Compute: k = 32 - count_leading_zeros(value - 1) - Value one32 = rewriter.create( - loc, rewriter.getI32IntegerAttr(1)); - Value thirtyTwo32 = rewriter.create( - loc, rewriter.getI32IntegerAttr(32)); + Value one32 = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(1)); + Value thirtyTwo32 = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(32)); Value countSubOne = - rewriter.create(loc, count, one32); + arith::SubIOp::create(rewriter, loc, count, one32); Value leadingZeros = - rewriter.create(loc, countSubOne); + math::CountLeadingZerosOp::create(rewriter, loc, countSubOne); Value k = - rewriter.create(loc, thirtyTwo32, leadingZeros); + arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros); // Compute: numerator = ((1 << 30) + 1) << k Value k64 = - rewriter.create(loc, rewriter.getI64Type(), k); - Value thirtyShiftPlusOne = rewriter.create( - loc, rewriter.getI64IntegerAttr((1 << 30) + 1)); + arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k); + Value thirtyShiftPlusOne = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1)); Value numerator = - rewriter.create(loc, thirtyShiftPlusOne, k64); + arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64); // Compute: scale.multiplier = numerator / value; - Value count64 = rewriter.create( - loc, rewriter.getI64Type(), count); + Value count64 = arith::ExtUIOp::create( + rewriter, loc, rewriter.getI64Type(), count); Value multiplier = - rewriter.create(loc, numerator, count64); - multiplier = rewriter.create( - loc, rewriter.getI32Type(), multiplier); + arith::DivUIOp::create(rewriter, loc, numerator, count64); + multiplier = arith::TruncIOp::create( + rewriter, loc, rewriter.getI32Type(), multiplier); // Compute: scale.shift = 30 + k Value k8 = - rewriter.create(loc, rewriter.getI8Type(), k); - Value thirty8 = rewriter.create( - loc, rewriter.getI8IntegerAttr(30)); - Value shift = rewriter.create(loc, k8, thirty8); + arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k); + Value thirty8 = arith::ConstantOp::create( + rewriter, loc, rewriter.getI8IntegerAttr(30)); + Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8); auto scaled = rewriter @@ -1056,20 +1059,21 @@ class AvgPool2dConverter : public OpRewritePattern { // If we have quantization information we need to apply output // zeropoint. if (outputZpVal != 0) { - auto outputZp = rewriter.create( - loc, b.getIntegerAttr(scaled.getType(), outputZpVal)); - scaled = rewriter.create(loc, scaled, outputZp) + auto outputZp = arith::ConstantOp::create( + rewriter, loc, + b.getIntegerAttr(scaled.getType(), outputZpVal)); + scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp) .getResult(); } // Apply Clip. int64_t outBitwidth = resultETy.getIntOrFloatBitWidth(); - auto min = rewriter.create( - loc, accETy, + auto min = arith::ConstantIntOp::create( + rewriter, loc, accETy, APInt::getSignedMinValue(outBitwidth).getSExtValue()); - auto max = rewriter.create( - loc, accETy, + auto max = arith::ConstantIntOp::create( + rewriter, loc, accETy, APInt::getSignedMaxValue(outBitwidth).getSExtValue()); auto clamp = clampIntHelper(loc, scaled, min, max, rewriter, /*isUnsigned=*/false); @@ -1078,11 +1082,11 @@ class AvgPool2dConverter : public OpRewritePattern { // Convert type. if (resultETy != clamp.getType()) { poolVal = - rewriter.create(loc, resultETy, poolVal); + arith::TruncIOp::create(rewriter, loc, resultETy, poolVal); } } - rewriter.create(loc, poolVal); + linalg::YieldOp::create(rewriter, loc, poolVal); }); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -1107,8 +1111,9 @@ class TransposeConverter : public OpRewritePattern { auto permutedSizes = applyTOSAPermutation(inputSizes, constantPerms); - auto permutedInit = rewriter.create( - loc, permutedSizes, op.getInput1().getType().getElementType()); + auto permutedInit = + tensor::EmptyOp::create(rewriter, loc, permutedSizes, + op.getInput1().getType().getElementType()); rewriter.replaceOpWithNewOp( op, op.getInput1(), permutedInit, llvm::to_vector(llvm::map_range( diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp index 7dbccd19a0518..b83f5ec9b0283 100644 --- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp +++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp @@ -27,8 +27,8 @@ class VariableOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::VariableOp op, PatternRewriter &rewriter) const final { auto variableType = tosa::getVariableType(op); - auto newVariable = rewriter.create( - op.getLoc(), op.getName(), variableType, /*is_mutable=*/true, + auto newVariable = mlir::ml_program::GlobalOp::create( + rewriter, op.getLoc(), op.getName(), variableType, /*is_mutable=*/true, op.getInitialValueAttr(), /*sym_visibility=*/nullptr); newVariable.setPrivate(); rewriter.replaceOp(op, newVariable); @@ -45,8 +45,8 @@ class VariableWriteOpConverter PatternRewriter &rewriter) const final { auto globalSymbolRef = SymbolRefAttr::get(rewriter.getContext(), op.getName()); - auto newVariableWrite = rewriter.create( - op.getLoc(), globalSymbolRef, op.getInput1()); + auto newVariableWrite = ml_program::GlobalStoreOp::create( + rewriter, op.getLoc(), globalSymbolRef, op.getInput1()); rewriter.replaceOp(op, newVariableWrite); return success(); } @@ -60,8 +60,8 @@ class VariableReadOpConverter : public OpRewritePattern { PatternRewriter &rewriter) const final { auto globalSymbolRef = SymbolRefAttr::get(rewriter.getContext(), op.getName()); - auto newVariableRead = rewriter.create( - op.getLoc(), op.getType(), globalSymbolRef); + auto newVariableRead = ml_program::GlobalLoadOp::create( + rewriter, op.getLoc(), op.getType(), globalSymbolRef); rewriter.replaceOp(op, newVariableRead); return success(); diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp index 03f9d20ad69de..aa6b4164e9876 100644 --- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp +++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp @@ -30,7 +30,7 @@ static void inlineIfCase(Region &srcRegion, Region &dstRegion, auto yield = cast(headBlock->getTerminator()); rewriter.setInsertionPoint(yield); - rewriter.create(yield.getLoc(), yield.getInputs()); + scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs()); rewriter.eraseOp(yield); headBlock->eraseArguments(0, headBlock->getNumArguments()); @@ -46,13 +46,13 @@ static void inlineWhileCase(Region &srcRegion, Region &dstRegion, auto yield = cast(headBlock->getTerminator()); rewriter.setInsertionPoint(yield); if (isCond) { - auto condition = - rewriter.create(yield.getLoc(), yield.getOperand(0)); - rewriter.create(yield.getLoc(), condition, - headBlock->getArguments()); + auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(), + yield.getOperand(0)); + scf::ConditionOp::create(rewriter, yield.getLoc(), condition, + headBlock->getArguments()); } else { rewriter.setInsertionPoint(yield); - rewriter.create(yield.getLoc(), yield.getInputs()); + scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs()); } rewriter.eraseOp(yield); } @@ -66,9 +66,9 @@ class IfOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::IfOp op, PatternRewriter &rewriter) const final { auto condition = - rewriter.create(op.getLoc(), op.getCondition()); - auto newIf = rewriter.create(op.getLoc(), op.getResultTypes(), - condition, true); + tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition()); + auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(), + condition, true); inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(), rewriter); @@ -88,7 +88,7 @@ class ScatterOpConverter : public OpRewritePattern { static Value createIndexConst(OpBuilder &builder, Location loc, int64_t value) { - return builder.create(loc, value); + return arith::ConstantIndexOp::create(builder, loc, value); } public: @@ -119,9 +119,9 @@ class ScatterOpConverter : public OpRewritePattern { auto n = ivs[0]; // Read the index and cast it to index type - auto index = builder.create(loc, indices, ivs); - auto castIndex = builder.create( - loc, builder.getIndexType(), index); + auto index = tensor::ExtractOp::create(builder, loc, indices, ivs); + auto castIndex = arith::IndexCastOp::create( + builder, loc, builder.getIndexType(), index); // Offset, sizes, and strides for the input tensor auto inputOffset = llvm::to_vector(ivs); @@ -130,13 +130,13 @@ class ScatterOpConverter : public OpRewritePattern { llvm::SmallVector sizes = {one, one, dimC}; llvm::SmallVector strides = {one, one, one}; - auto slice = builder.create( - loc, input, inputOffset, sizes, strides); + auto slice = tensor::ExtractSliceOp::create(builder, loc, input, + inputOffset, sizes, strides); // Insert the slice into the output accumulator tensor. llvm::SmallVector outputOffset = {n, castIndex, zero}; - auto updated = builder.create( - loc, slice, args[0], outputOffset, sizes, strides); + auto updated = tensor::InsertSliceOp::create( + builder, loc, slice, args[0], outputOffset, sizes, strides); return {updated}; }; @@ -155,8 +155,8 @@ class WhileOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::WhileOp op, PatternRewriter &rewriter) const final { - auto newWhile = rewriter.create( - op.getLoc(), op.getResultTypes(), op.getInputList()); + auto newWhile = scf::WhileOp::create( + rewriter, op.getLoc(), op.getResultTypes(), op.getInputList()); rewriter.createBlock(&newWhile.getBefore()); rewriter.createBlock(&newWhile.getAfter()); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index c6cbcb0f8ab2b..2945ae3b49f1f 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -308,15 +308,15 @@ class SliceConverter : public OpConversionPattern { if (ShapedType::isStatic(sizes.back())) continue; - auto dim = rewriter.create(loc, input, index); - auto offset = rewriter.create( - loc, rewriter.getIndexAttr(sliceStarts[index])); - dynSizes.push_back(rewriter.create(loc, dim, offset)); + auto dim = tensor::DimOp::create(rewriter, loc, input, index); + auto offset = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(sliceStarts[index])); + dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset)); } - auto newSliceOp = rewriter.create( - sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, - ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts), + auto newSliceOp = tensor::ExtractSliceOp::create( + rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), + dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts), rewriter.getDenseI64ArrayAttr(sizes), rewriter.getDenseI64ArrayAttr(strides)); @@ -361,7 +361,7 @@ class PadConverter : public OpConversionPattern { Value padConstant = rewriter.createOrFold( loc, padOp.getPadConst(), - ValueRange({rewriter.create(loc, 0)})); + ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)})); if (!padConstant) { return rewriter.notifyMatchFailure( @@ -375,16 +375,16 @@ class PadConverter : public OpConversionPattern { highValues.reserve(rank); for (int i = 0; i < rank; i++) { - Value lowVal = rewriter.create( - loc, rewriter.getIndexAttr(paddingVals[2 * i])); - Value highVal = rewriter.create( - loc, rewriter.getIndexAttr(paddingVals[2 * i + 1])); + Value lowVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i])); + Value highVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1])); lowValues.push_back(lowVal); highValues.push_back(highVal); } - auto newPadOp = rewriter.create( - loc, padOp.getType(), input, lowValues, highValues, padConstant); + auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input, + lowValues, highValues, padConstant); rewriter.replaceOp(padOp, newPadOp.getResult()); return success(); @@ -402,7 +402,7 @@ struct ConcatConverter : public OpConversionPattern { Location loc = op.getLoc(); int axis = op.getAxis(); Value axisValue = - rewriter.create(loc, rewriter.getIndexAttr(axis)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis)); int64_t rank = resultType.getRank(); SmallVector strides(rank, rewriter.getIndexAttr(1)); @@ -439,8 +439,9 @@ struct ConcatConverter : public OpConversionPattern { } } - Value result = rewriter.create( - loc, resultType.getShape(), resultType.getElementType(), dynDims); + Value result = + tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType(), dynDims); for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) { auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg); diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index d6f9495b2567c..125ea1eb60ed6 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -226,22 +226,22 @@ struct BroadcastOpToArmSMELowering (srcVectorType && (srcVectorType.getRank() == 0))) { // Broadcast scalar or 0-d vector to 1-d vector. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - broadcastOp1D = rewriter.create( - loc, tileSliceType, broadcastOp.getSource()); + broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType, + broadcastOp.getSource()); } else if (srcVectorType && (srcVectorType.getRank() == 1)) // Value to broadcast is already a 1-d vector, nothing to do. broadcastOp1D = broadcastOp.getSource(); else return failure(); - auto initTile = rewriter.create(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.insert_tile_slice' to broadcast the value // to each tile slice. - auto nextTile = b.create( - loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); + auto nextTile = arm_sme::InsertTileSliceOp::create( + b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; @@ -292,15 +292,15 @@ struct SplatOpToArmSMELowering : public OpRewritePattern { // First, broadcast the scalar to a 1-d vector. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - Value broadcastOp1D = rewriter.create( - loc, tileSliceType, splatOp.getInput()); + Value broadcastOp1D = vector::BroadcastOp::create( + rewriter, loc, tileSliceType, splatOp.getInput()); - auto initTile = rewriter.create(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { - auto nextTile = b.create( - loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); + auto nextTile = arm_sme::InsertTileSliceOp::create( + b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; @@ -370,22 +370,22 @@ struct TransposeOpToArmSMELowering // Allocate buffer to store input tile to. Value vscale = - rewriter.create(loc, rewriter.getIndexType()); - Value minTileSlices = rewriter.create( - loc, rewriter.getIndexAttr(tileType.getDimSize(0))); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + Value minTileSlices = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(tileType.getDimSize(0))); Value c0 = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); Value numTileSlices = - rewriter.create(loc, vscale, minTileSlices); + arith::MulIOp::create(rewriter, loc, vscale, minTileSlices); auto bufferType = MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, tileType.getElementType()); - auto buffer = rewriter.create( - loc, bufferType, ValueRange{numTileSlices, numTileSlices}); + auto buffer = memref::AllocaOp::create( + rewriter, loc, bufferType, ValueRange{numTileSlices, numTileSlices}); // Store input tile. - auto tileStoreOp = rewriter.create( - loc, input, buffer, ValueRange{c0, c0}); + auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input, + buffer, ValueRange{c0, c0}); // Reload input tile vertically. rewriter.replaceOpWithNewOp( @@ -488,10 +488,10 @@ struct VectorOuterProductToArmSMELowering Value rhsMaskDim = createMaskOp.getOperand(1); VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); - Value lhsMask = - rewriter.create(loc, operandMaskType, lhsMaskDim); - Value rhsMask = - rewriter.create(loc, operandMaskType, rhsMaskDim); + Value lhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType, + lhsMaskDim); + Value rhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType, + rhsMaskDim); return std::make_pair(lhsMask, rhsMask); } @@ -531,8 +531,8 @@ struct VectorExtractToArmSMELowering } Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front(); - auto extractTileSlice = rewriter.create( - loc, sourceVector, sliceIndex); + auto extractTileSlice = arm_sme::ExtractTileSliceOp::create( + rewriter, loc, sourceVector, sliceIndex); if (position.size() == 1) { // Single index case: Extracts a 1D slice. @@ -593,10 +593,10 @@ struct VectorInsertToArmSMELowering if (position.size() == 2) { // Two indices case: Insert single element into tile. // We need to first extract the existing slice and update the element. - tileSlice = rewriter.create( - loc, insertOp.getDest(), sliceIndex); - tileSlice = rewriter.create(loc, source, tileSlice, - position[1]); + tileSlice = arm_sme::ExtractTileSliceOp::create( + rewriter, loc, insertOp.getDest(), sliceIndex); + tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice, + position[1]); } // Insert the slice into the destination tile. @@ -642,23 +642,24 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern { auto loc = printOp.getLoc(); // Create a loop over the rows of the tile. - auto vscale = rewriter.create(loc); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); auto minTileRows = - rewriter.create(loc, vectorType.getDimSize(0)); - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, minTileRows, vscale); - auto step = rewriter.create(loc, 1); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + arith::ConstantIndexOp::create(rewriter, loc, vectorType.getDimSize(0)); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); { // Loop body. rewriter.setInsertionPointToStart(forOp.getBody()); // Extract the current row from the tile. Value rowIndex = forOp.getInductionVar(); - auto tileSlice = rewriter.create( - loc, printOp.getSource(), rowIndex); + auto tileSlice = arm_sme::ExtractTileSliceOp::create( + rewriter, loc, printOp.getSource(), rowIndex); // Print the row with a 1D vector.print. - rewriter.create(loc, tileSlice, - printOp.getPunctuation()); + vector::PrintOp::create(rewriter, loc, tileSlice, + printOp.getPunctuation()); } rewriter.eraseOp(printOp); @@ -707,8 +708,8 @@ struct FoldTransferWriteOfExtractTileSlice Value mask = writeOp.getMask(); if (!mask) { auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type()); - mask = rewriter.create( - writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true)); + mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType, + DenseElementsAttr::get(maskType, true)); } rewriter.replaceOpWithNewOp( @@ -776,10 +777,10 @@ struct ExtractFromCreateMaskToPselLowering // Create the two 1-D masks at the location of the 2-D create_mask (which is // usually outside a loop). This prevents the need for later hoisting. rewriter.setInsertionPoint(createMaskOp); - auto rowMask = rewriter.create( - loc, rowMaskType, createMaskOp.getOperand(0)); - auto colMask = rewriter.create( - loc, colMaskType, createMaskOp.getOperand(1)); + auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType, + createMaskOp.getOperand(0)); + auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType, + createMaskOp.getOperand(1)); rewriter.setInsertionPoint(extractOp); auto position = diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 9a8eb72d72925..77aab85483a8b 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -412,22 +412,22 @@ struct PrepareContractToGPUMMA if (maps == infer({{m, k}, {k, n}, {m, n}})) return rewriter.notifyMatchFailure(op, "contraction already prepared"); if (maps == infer({{m, k}, {n, k}, {m, n}})) { - rhs = rewriter.create(loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { - rhs = rewriter.create(loc, rhs, perm); - lhs = rewriter.create(loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { std::swap(rhs, lhs); - rhs = rewriter.create(loc, rhs, perm); - lhs = rewriter.create(loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { std::swap(rhs, lhs); - rhs = rewriter.create(loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { std::swap(lhs, rhs); - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { std::swap(lhs, rhs); } else { @@ -494,13 +494,13 @@ struct CombineTransferReadOpTranspose final // Fuse through the integer extend op. if (extOp) { if (isa(extOp)) - result = rewriter.create(loc, op.getType(), result) + result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result) .getResult(); else if (isa(extOp)) - result = rewriter.create(loc, op.getType(), result) + result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result) .getResult(); else - result = rewriter.create(loc, op.getType(), result) + result = arith::ExtFOp::create(rewriter, loc, op.getType(), result) .getResult(); } @@ -579,8 +579,8 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, } gpu::MMAMatrixType type = gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType); - Value load = rewriter.create( - op.getLoc(), type, op.getBase(), op.getIndices(), + Value load = gpu::SubgroupMmaLoadMatrixOp::create( + rewriter, op.getLoc(), type, op.getBase(), op.getIndices(), rewriter.getIndexAttr(*stride), isTranspose ? rewriter.getUnitAttr() : UnitAttr()); valueMapping[mappingResult] = load; @@ -610,8 +610,8 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, } Value matrix = it->second; - auto store = rewriter.create( - op.getLoc(), matrix, op.getBase(), op.getIndices(), + auto store = gpu::SubgroupMmaStoreMatrixOp::create( + rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(), rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); (void)store; @@ -661,8 +661,8 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, return rewriter.notifyMatchFailure(op, "not a splat"); } - Value result = rewriter.create( - op.getLoc(), vectorType, + Value result = arith::ConstantOp::create( + rewriter, op.getLoc(), vectorType, DenseElementsAttr::get(vectorType, dense.getSplatValue())); valueMapping[op.getResult()] = result; return success(); @@ -743,7 +743,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, } // Adjust the load offset. - auto laneId = rewriter.create(loc, /*upperBound=*/nullptr); + auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); FailureOr offsets = nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); if (failed(offsets)) { @@ -757,8 +757,9 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, getXferIndices(rewriter, op, *offsets, {laneId}, indices); - nvgpu::LdMatrixOp newOp = rewriter.create( - loc, vectorType, op.getBase(), indices, *transpose, params->numTiles); + nvgpu::LdMatrixOp newOp = + nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(), + indices, *transpose, params->numTiles); valueMapping[op] = newOp->getResult(0); return success(); } @@ -782,17 +783,17 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, "conversion to distributed non-ldmatrix compatible load"); } - Value laneId = rewriter.create(loc, /*upperBound=*/nullptr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); // This is the individual element type. Type loadedElType = regInfo->registerLLVMType; VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - Value fill = rewriter.create( - op.getLoc(), vectorType.getElementType(), + Value fill = arith::ConstantOp::create( + rewriter, op.getLoc(), vectorType.getElementType(), rewriter.getZeroAttr(vectorType.getElementType())); Value result = - rewriter.create(op.getLoc(), vectorType, fill); + vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill); bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); @@ -809,16 +810,16 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, if (failed(coords)) return rewriter.notifyMatchFailure(op, "no coords"); - Value logicalValueId = rewriter.create( - loc, rewriter.getIndexType(), + Value logicalValueId = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); SmallVector newIndices; getXferIndices( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - Value el = rewriter.create(loc, loadedElType, - op.getBase(), newIndices); - result = rewriter.create(loc, el, result, i); + Value el = vector::LoadOp::create(rewriter, loc, loadedElType, + op.getBase(), newIndices); + result = vector::InsertOp::create(rewriter, loc, el, result, i); } } else { if (auto vecType = dyn_cast(loadedElType)) { @@ -828,8 +829,8 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; innerIdx++) { - Value logicalValueId = rewriter.create( - loc, rewriter.getIndexType(), + Value logicalValueId = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( rewriter, op.getLoc(), *warpMatrixInfo); @@ -839,10 +840,10 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, SmallVector newIndices; getXferIndices( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - Value el = rewriter.create(op.getLoc(), loadedElType, - op.getBase(), newIndices); - result = rewriter.create( - op.getLoc(), el, result, ArrayRef{i, innerIdx}); + Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType, + op.getBase(), newIndices); + result = vector::InsertOp::create(rewriter, op.getLoc(), el, result, + ArrayRef{i, innerIdx}); } } } @@ -916,11 +917,11 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, return rewriter.notifyMatchFailure(op, "not mma sync reg info"); VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - Value laneId = rewriter.create(loc, /*upperBound=*/nullptr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { - Value logicalValueId = rewriter.create( - loc, rewriter.getIndexType(), + Value logicalValueId = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( rewriter, op.getLoc(), *warpMatrixInfo); @@ -928,11 +929,11 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, return rewriter.notifyMatchFailure(op, "no coords"); Value el = - rewriter.create(loc, matrix, ArrayRef{i}); + vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef{i}); SmallVector newIndices; getXferIndices( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - rewriter.create(loc, el, op.getBase(), newIndices); + vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices); } LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); @@ -1015,8 +1016,8 @@ convertExtractStridedSlice(RewriterBase &rewriter, else if (offsets[1]) sliceOffset[0] = (warpVectorShape[1] / offsets[1]); - Value newOp = rewriter.create( - loc, sourceVector, sliceOffset, sliceShape, strides); + Value newOp = vector::ExtractStridedSliceOp::create( + rewriter, loc, sourceVector, sliceOffset, sliceShape, strides); valueMapping[op] = newOp; return success(); @@ -1035,9 +1036,10 @@ convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, itC == valueMapping.end()) return rewriter.notifyMatchFailure(op, "no mapping"); Value opA = itA->second, opB = itB->second, opC = itC->second; - Value matmul = rewriter.create( - op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(), - /*b_transpose=*/UnitAttr()); + Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(), + opC.getType(), opA, opB, opC, + /*a_transpose=*/UnitAttr(), + /*b_transpose=*/UnitAttr()); valueMapping[op.getResult()] = matmul; return success(); } @@ -1058,8 +1060,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, int64_t m = cast(op.getLhs().getType()).getShape()[0]; int64_t n = cast(op.getRhs().getType()).getShape()[0]; int64_t k = cast(op.getLhs().getType()).getShape()[1]; - Value matmul = rewriter.create( - op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); + Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, + rewriter.getI64ArrayAttr({m, n, k})); valueMapping[op.getResult()] = matmul; return success(); } @@ -1076,13 +1078,13 @@ convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, auto splat = cast(op.getValue()).getSplatValue(); auto scalarConstant = - rewriter.create(op.getLoc(), splat.getType(), splat); + arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); auto vecType = cast(op.getType()); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); - auto matrix = rewriter.create( - op.getLoc(), type, scalarConstant); + auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(), + type, scalarConstant); valueMapping[op.getResult()] = matrix; return success(); } @@ -1100,8 +1102,8 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, auto vecType = op.getResultVectorType(); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); - auto matrix = rewriter.create( - op.getLoc(), type, op.getSource()); + auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(), + type, op.getSource()); valueMapping[op.getResult()] = matrix; return success(); } @@ -1118,9 +1120,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, rewriter.setInsertionPoint(loop); auto operands = llvm::to_vector<4>(loop.getInitArgs()); llvm::append_range(operands, newInitArgs); - scf::ForOp newLoop = rewriter.create( - loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), - operands); + scf::ForOp newLoop = + scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(), + loop.getUpperBound(), loop.getStep(), operands); rewriter.eraseBlock(newLoop.getBody()); newLoop.getRegion().getBlocks().splice( @@ -1189,7 +1191,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()]; yieldOperands.push_back(it->second); } - rewriter.create(op.getLoc(), yieldOperands); + scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands); LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); rewriter.eraseOp(op); @@ -1220,8 +1222,8 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op, resultType.getOperand()); } - Value newOp = rewriter.create( - op->getLoc(), resultType, matrixOperands, opType); + Value newOp = gpu::SubgroupMmaElementwiseOp::create( + rewriter, op->getLoc(), resultType, matrixOperands, opType); valueMapping[op->getResult(0)] = newOp; return success(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index e4ff770a807c6..9cd491caa9421 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -43,13 +43,13 @@ static Value insertOne(ConversionPatternRewriter &rewriter, assert(rank > 0 && "0-D vector corner case should have been handled already"); if (rank == 1) { auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create( - loc, typeConverter.convertType(idxType), + auto constant = LLVM::ConstantOp::create( + rewriter, loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create(loc, llvmType, val1, val2, - constant); + return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2, + constant); } - return rewriter.create(loc, val1, val2, pos); + return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos); } // Helper that picks the proper sequence for extracting. @@ -58,13 +58,13 @@ static Value extractOne(ConversionPatternRewriter &rewriter, Value val, Type llvmType, int64_t rank, int64_t pos) { if (rank <= 1) { auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create( - loc, typeConverter.convertType(idxType), + auto constant = LLVM::ConstantOp::create( + rewriter, loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create(loc, llvmType, val, - constant); + return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val, + constant); } - return rewriter.create(loc, val, pos); + return LLVM::ExtractValueOp::create(rewriter, loc, val, pos); } // Helper that returns data layout alignment of a vector. @@ -141,9 +141,9 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, auto ptrsType = LLVM::getVectorType(pType, vectorType.getDimSize(0), /*isScalable=*/vectorType.getScalableDims()[0]); - return rewriter.create( - loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), - base, index); + return LLVM::GEPOp::create( + rewriter, loc, ptrsType, + typeConverter.convertType(memRefType.getElementType()), base, index); } /// Convert `foldResult` into a Value. Integer attribute is converted to @@ -152,7 +152,7 @@ static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult) { if (auto attr = dyn_cast(foldResult)) { auto intAttr = cast(attr); - return builder.create(loc, intAttr).getResult(); + return LLVM::ConstantOp::create(builder, loc, intAttr).getResult(); } return cast(foldResult); @@ -440,32 +440,32 @@ class ReductionNeutralFPMax {}; static Value createReductionNeutralValue(ReductionNeutralZero neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create(loc, llvmType, - rewriter.getZeroAttr(llvmType)); + return LLVM::ConstantOp::create(rewriter, loc, llvmType, + rewriter.getZeroAttr(llvmType)); } /// Create the reduction neutral integer one value. static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( - loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); + return LLVM::ConstantOp::create(rewriter, loc, llvmType, + rewriter.getIntegerAttr(llvmType, 1)); } /// Create the reduction neutral fp one value. static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( - loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); + return LLVM::ConstantOp::create(rewriter, loc, llvmType, + rewriter.getFloatAttr(llvmType, 1.0)); } /// Create the reduction neutral all-ones value. static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr( llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); } @@ -474,8 +474,8 @@ static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( llvmType.getIntOrFloatBitWidth()))); } @@ -484,8 +484,8 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( llvmType.getIntOrFloatBitWidth()))); } @@ -494,8 +494,8 @@ static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( llvmType.getIntOrFloatBitWidth()))); } @@ -504,8 +504,8 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( llvmType.getIntOrFloatBitWidth()))); } @@ -515,8 +515,8 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { auto floatType = cast(llvmType); - return rewriter.create( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getFloatAttr( llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), /*Negative=*/false))); @@ -527,8 +527,8 @@ static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { auto floatType = cast(llvmType); - return rewriter.create( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getFloatAttr( llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), /*Negative=*/true))); @@ -556,19 +556,19 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, auto vShape = vType.getShape(); assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); - Value baseVecLength = rewriter.create( - loc, rewriter.getI32Type(), + Value baseVecLength = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); if (!vType.getScalableDims()[0]) return baseVecLength; // For a scalable vector type, create and return `vScale * baseVecLength`. - Value vScale = rewriter.create(loc); + Value vScale = vector::VectorScaleOp::create(rewriter, loc); vScale = - rewriter.create(loc, rewriter.getI32Type(), vScale); + arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale); Value scalableVecLength = - rewriter.create(loc, baseVecLength, vScale); + arith::MulIOp::create(rewriter, loc, baseVecLength, vScale); return scalableVecLength; } @@ -581,10 +581,11 @@ static Value createIntegerReductionArithmeticOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator) { - Value result = rewriter.create(loc, llvmType, vectorOperand); + Value result = + LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand); if (accumulator) - result = rewriter.create(loc, accumulator, result); + result = ScalarOp::create(rewriter, loc, accumulator, result); return result; } @@ -596,11 +597,12 @@ template static Value createIntegerReductionComparisonOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { - Value result = rewriter.create(loc, llvmType, vectorOperand); + Value result = + LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand); if (accumulator) { Value cmp = - rewriter.create(loc, predicate, accumulator, result); - result = rewriter.create(loc, cmp, accumulator, result); + LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result); + result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result); } return result; } @@ -631,12 +633,11 @@ static Value createFPReductionComparisonOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { Value result = - rewriter.create(loc, llvmType, vectorOperand, fmf); + LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf); if (accumulator) { - result = - rewriter.create::Type>( - loc, result, accumulator); + result = VectorToScalarMapper::Type::create( + rewriter, loc, result, accumulator); } return result; @@ -667,7 +668,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, const auto &floatSemantics = cast(llvmType).getFloatSemantics(); auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); auto denseValue = DenseElementsAttr::get(cast(vectorType), value); - return rewriter.create(loc, vectorType, denseValue); + return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue); } /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked @@ -682,8 +683,8 @@ lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, Value mask, LLVM::FastmathFlagsAttr fmf) { const Value vectorMaskNeutral = createMaskNeutralValue( rewriter, loc, llvmType, vectorOperand.getType()); - const Value selectedVectorByMask = rewriter.create( - loc, mask, vectorOperand, vectorMaskNeutral); + const Value selectedVectorByMask = LLVM::SelectOp::create( + rewriter, loc, mask, vectorOperand, vectorMaskNeutral); return createFPReductionComparisonOpLowering( rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); } @@ -695,9 +696,9 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, Value accumulator, LLVM::FastmathFlagsAttr fmf) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); - return rewriter.create(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand, fmf); + return LLVMRedIntrinOp::create(rewriter, loc, llvmType, + /*startValue=*/accumulator, vectorOperand, + fmf); } /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic @@ -710,9 +711,8 @@ lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter, Value vectorOperand, Value accumulator) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); - return rewriter.create(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand); + return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType, + /*startValue=*/accumulator, vectorOperand); } template @@ -723,9 +723,9 @@ static Value lowerPredicatedReductionWithStartValue( llvmType, accumulator); Value vectorLength = createVectorLengthValue(rewriter, loc, vectorOperand.getType()); - return rewriter.create(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand, mask, vectorLength); + return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType, + /*startValue=*/accumulator, vectorOperand, + mask, vectorLength); } template ( - loc, adaptor.getV1(), adaptor.getV2(), + Value llvmShuffleOp = LLVM::ShuffleVectorOp::create( + rewriter, loc, adaptor.getV1(), adaptor.getV2(), llvm::to_vector_of(mask)); rewriter.replaceOp(shuffleOp, llvmShuffleOp); return success(); @@ -1050,7 +1050,7 @@ class VectorShuffleOpConversion eltType = arrayType.getElementType(); else eltType = cast(llvmType).getElementType(); - Value insert = rewriter.create(loc, llvmType); + Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType); int64_t insPos = 0; for (int64_t extPos : mask) { Value value = adaptor.getV1(); @@ -1087,9 +1087,9 @@ class VectorExtractElementOpConversion if (vectorType.getRank() == 0) { Location loc = extractEltOp.getLoc(); auto idxType = rewriter.getIndexType(); - auto zero = rewriter.create( - loc, typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); + auto zero = LLVM::ConstantOp::create(rewriter, loc, + typeConverter->convertType(idxType), + rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp( extractEltOp, llvmType, adaptor.getVector(), zero); return success(); @@ -1158,13 +1158,14 @@ class VectorExtractOpConversion if (!llvm::all_of(position, llvm::IsaPred)) { return failure(); } - extracted = rewriter.create( - loc, extracted, getAsIntegers(position)); + extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted, + getAsIntegers(position)); } if (extractsScalar) { - extracted = rewriter.create( - loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back())); + extracted = LLVM::ExtractElementOp::create( + rewriter, loc, extracted, + getAsLLVMValue(rewriter, loc, positionVec.back())); } rewriter.replaceOp(extractOp, extracted); @@ -1221,9 +1222,9 @@ class VectorInsertElementOpConversion if (vectorType.getRank() == 0) { Location loc = insertEltOp.getLoc(); auto idxType = rewriter.getIndexType(); - auto zero = rewriter.create( - loc, typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); + auto zero = LLVM::ConstantOp::create(rewriter, loc, + typeConverter->convertType(idxType), + rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp( insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); return success(); @@ -1307,8 +1308,8 @@ class VectorInsertOpConversion // llvm.extractvalue does not support dynamic dimensions. return failure(); } - sourceAggregate = rewriter.create( - loc, adaptor.getDest(), + sourceAggregate = LLVM::ExtractValueOp::create( + rewriter, loc, adaptor.getDest(), getAsIntegers(positionOf1DVectorWithinAggregate)); } else { // No-aggregate case. The destination for the InsertElementOp is just @@ -1316,16 +1317,16 @@ class VectorInsertOpConversion sourceAggregate = adaptor.getDest(); } // Insert the scalar into the 1D vector. - sourceAggregate = rewriter.create( - loc, sourceAggregate.getType(), sourceAggregate, + sourceAggregate = LLVM::InsertElementOp::create( + rewriter, loc, sourceAggregate.getType(), sourceAggregate, adaptor.getValueToStore(), getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector)); } Value result = sourceAggregate; if (isNestedAggregate) { - result = rewriter.create( - loc, adaptor.getDest(), sourceAggregate, + result = LLVM::InsertValueOp::create( + rewriter, loc, adaptor.getDest(), sourceAggregate, getAsIntegers(positionOf1DVectorWithinAggregate)); } @@ -1404,15 +1405,15 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern { auto loc = op.getLoc(); auto elemType = vType.getElementType(); - Value zero = rewriter.create( - loc, elemType, rewriter.getZeroAttr(elemType)); - Value desc = rewriter.create(loc, vType, zero); + Value zero = arith::ConstantOp::create(rewriter, loc, elemType, + rewriter.getZeroAttr(elemType)); + Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { - Value extrLHS = rewriter.create(loc, op.getLhs(), i); - Value extrRHS = rewriter.create(loc, op.getRhs(), i); - Value extrACC = rewriter.create(loc, op.getAcc(), i); - Value fma = rewriter.create(loc, extrLHS, extrRHS, extrACC); - desc = rewriter.create(loc, fma, desc, i); + Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i); + Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i); + Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i); + Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC); + desc = InsertOp::create(rewriter, loc, fma, desc, i); } rewriter.replaceOp(op, desc); return success(); @@ -1502,7 +1503,7 @@ class VectorTypeCastOpConversion desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); - auto zero = rewriter.create(loc, int64Ty, attr); + auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr); desc.setOffset(rewriter, loc, zero); // Fill size and stride descriptors in memref. @@ -1511,11 +1512,12 @@ class VectorTypeCastOpConversion int64_t index = indexedSize.index(); auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); - auto size = rewriter.create(loc, int64Ty, sizeAttr); + auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr); desc.setSize(rewriter, loc, index, size); auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), (*targetStrides)[index]); - auto stride = rewriter.create(loc, int64Ty, strideAttr); + auto stride = + LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr); desc.setStride(rewriter, loc, index, stride); } @@ -1543,14 +1545,15 @@ class VectorCreateMaskOpConversion IntegerType idxType = force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); auto loc = op->getLoc(); - Value indices = rewriter.create( - loc, LLVM::getVectorType(idxType, dstType.getShape()[0], - /*isScalable=*/true)); + Value indices = LLVM::StepVectorOp::create( + rewriter, loc, + LLVM::getVectorType(idxType, dstType.getShape()[0], + /*isScalable=*/true)); auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, adaptor.getOperands()[0]); - Value bounds = rewriter.create(loc, indices.getType(), bound); - Value comp = rewriter.create(loc, arith::CmpIPredicate::slt, - indices, bounds); + Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound); + Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + indices, bounds); rewriter.replaceOp(op, comp); return success(); } @@ -1706,16 +1709,16 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { switch (conversion) { case PrintConversion::ZeroExt64: - value = rewriter.create( - loc, IntegerType::get(rewriter.getContext(), 64), value); + value = arith::ExtUIOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value); break; case PrintConversion::SignExt64: - value = rewriter.create( - loc, IntegerType::get(rewriter.getContext(), 64), value); + value = arith::ExtSIOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value); break; case PrintConversion::Bitcast16: - value = rewriter.create( - loc, IntegerType::get(rewriter.getContext(), 16), value); + value = LLVM::BitcastOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value); break; case PrintConversion::None: break; @@ -1727,8 +1730,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { // Helper to emit a call. static void emitCall(ConversionPatternRewriter &rewriter, Location loc, Operation *ref, ValueRange params = ValueRange()) { - rewriter.create(loc, TypeRange(), SymbolRefAttr::get(ref), - params); + LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref), + params); } }; @@ -1754,9 +1757,9 @@ struct VectorBroadcastScalarToLowRankLowering // First insert it into a poison vector so we can shuffle it. auto vectorType = typeConverter->convertType(broadcast.getType()); Value poison = - rewriter.create(broadcast.getLoc(), vectorType); - auto zero = rewriter.create( - broadcast.getLoc(), + LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType); + auto zero = LLVM::ConstantOp::create( + rewriter, broadcast.getLoc(), typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); @@ -1768,8 +1771,9 @@ struct VectorBroadcastScalarToLowRankLowering } // For 1-d vector, we additionally do a `vectorshuffle`. - auto v = rewriter.create( - broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero); + auto v = + LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, + poison, adaptor.getSource(), zero); int64_t width = cast(broadcast.getType()).getDimSize(0); SmallVector zeroValues(width, 0); @@ -1811,26 +1815,26 @@ struct VectorBroadcastScalarToNdLowering return failure(); // Construct returned value. - Value desc = rewriter.create(loc, llvmNDVectorTy); + Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy); // Construct a 1-D vector with the broadcasted value that we insert in all // the places within the returned descriptor. - Value vdesc = rewriter.create(loc, llvm1DVectorTy); - auto zero = rewriter.create( - loc, typeConverter->convertType(rewriter.getIntegerType(32)), + Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy); + auto zero = LLVM::ConstantOp::create( + rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); - Value v = rewriter.create(loc, llvm1DVectorTy, vdesc, - adaptor.getSource(), zero); + Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy, + vdesc, adaptor.getSource(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); SmallVector zeroValues(width, 0); - v = rewriter.create(loc, v, v, zeroValues); + v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues); // Iterate of linear index, convert to coords space and insert broadcasted // 1-D vector in each position. nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef position) { - desc = rewriter.create(loc, desc, v, position); + desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position); }); rewriter.replaceOp(broadcast, desc); return success(); @@ -1900,13 +1904,13 @@ struct VectorDeinterleaveOpLowering auto deinterleaveResults = deinterleaveOp.getResultTypes(); auto packedOpResults = llvmTypeConverter->packOperationResults(deinterleaveResults); - auto intrinsic = rewriter.create( - loc, packedOpResults, adaptor.getSource()); + auto intrinsic = LLVM::vector_deinterleave2::create( + rewriter, loc, packedOpResults, adaptor.getSource()); - auto evenResult = rewriter.create( - loc, intrinsic->getResult(0), 0); - auto oddResult = rewriter.create( - loc, intrinsic->getResult(0), 1); + auto evenResult = LLVM::ExtractValueOp::create( + rewriter, loc, intrinsic->getResult(0), 0); + auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc, + intrinsic->getResult(0), 1); rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult}); return success(); @@ -1929,11 +1933,11 @@ struct VectorDeinterleaveOpLowering oddShuffleMask.push_back(i); } - auto poison = rewriter.create(loc, sourceType); - auto evenShuffle = rewriter.create( - loc, adaptor.getSource(), poison, evenShuffleMask); - auto oddShuffle = rewriter.create( - loc, adaptor.getSource(), poison, oddShuffleMask); + auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType); + auto evenShuffle = LLVM::ShuffleVectorOp::create( + rewriter, loc, adaptor.getSource(), poison, evenShuffleMask); + auto oddShuffle = LLVM::ShuffleVectorOp::create( + rewriter, loc, adaptor.getSource(), poison, oddShuffleMask); rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle}); return success(); @@ -1956,9 +1960,9 @@ struct VectorFromElementsLowering return rewriter.notifyMatchFailure(fromElementsOp, "rank > 1 vectors are not supported"); Type llvmType = typeConverter->convertType(vectorType); - Value result = rewriter.create(loc, llvmType); + Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType); for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) - result = rewriter.create(loc, val, result, idx); + result = vector::InsertOp::create(rewriter, loc, val, result, idx); rewriter.replaceOp(fromElementsOp, result); return success(); } @@ -1982,12 +1986,12 @@ struct VectorToElementsLowering if (element.use_empty()) continue; - auto constIdx = rewriter.create( - loc, idxType, rewriter.getIntegerAttr(idxType, idx)); + auto constIdx = LLVM::ConstantOp::create( + rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx)); auto llvmType = typeConverter->convertType(element.getType()); - Value result = rewriter.create(loc, llvmType, - source, constIdx); + Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType, + source, constIdx); results[idx] = result; } @@ -2098,7 +2102,7 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( Value lhs = op.getLhs(); auto lhsMap = op.getIndexingMapsArray()[0]; if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) - lhs = rew.create(loc, lhs, ArrayRef{1, 0}); + lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef{1, 0}); else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) return failure(); @@ -2106,7 +2110,7 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( Value rhs = op.getRhs(); auto rhsMap = op.getIndexingMapsArray()[1]; if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) - rhs = rew.create(loc, rhs, ArrayRef{1, 0}); + rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef{1, 0}); else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) return failure(); @@ -2119,20 +2123,20 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( Type flattenedLHSType = VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); - lhs = rew.create(loc, flattenedLHSType, lhs); + lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs); Type flattenedRHSType = VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - rhs = rew.create(loc, flattenedRHSType, rhs); + rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs); - Value mul = rew.create( - loc, + Value mul = LLVM::MatrixMultiplyOp::create( + rew, loc, VectorType::get(lhsRows * rhsColumns, cast(lhs.getType()).getElementType()), lhs, rhs, lhsRows, lhsColumns, rhsColumns); - mul = rew.create( - loc, + mul = vector::ShapeCastOp::create( + rew, loc, VectorType::get({lhsRows, rhsColumns}, getElementTypeOrSelf(op.getAcc().getType())), mul); @@ -2140,15 +2144,15 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( // ACC must be C(m, n) or C(n, m). auto accMap = op.getIndexingMapsArray()[2]; if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) - mul = rew.create(loc, mul, ArrayRef{1, 0}); + mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef{1, 0}); else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) llvm_unreachable("invalid contraction semantics"); - Value res = - isa(elementType) - ? static_cast(rew.create(loc, op.getAcc(), mul)) - : static_cast( - rew.create(loc, op.getAcc(), mul)); + Value res = isa(elementType) + ? static_cast( + arith::AddIOp::create(rew, loc, op.getAcc(), mul)) + : static_cast( + arith::AddFOp::create(rew, loc, op.getAcc(), mul)); return res; } @@ -2181,11 +2185,11 @@ class TransposeOpToMatrixTransposeOpLowering Type flattenedType = VectorType::get(resType.getNumElements(), resType.getElementType()); auto matrix = - rewriter.create(loc, flattenedType, input); + vector::ShapeCastOp::create(rewriter, loc, flattenedType, input); auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); - Value trans = rewriter.create( - loc, flattenedType, matrix, rows, columns); + Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType, + matrix, rows, columns); rewriter.replaceOpWithNewOp(op, resType, trans); return success(); } diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 43732f58a4e0a..4c1047a8871a5 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -132,9 +132,9 @@ static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, Value value) { if (hasRetVal) { assert(value && "Expected non-empty value"); - b.create(loc, value); + scf::YieldOp::create(b, loc, value); } else { - b.create(loc); + scf::YieldOp::create(b, loc); } } @@ -154,7 +154,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { return Value(); Location loc = xferOp.getLoc(); - return b.create(loc, xferOp.getMask(), iv); + return vector::ExtractOp::create(b, loc, xferOp.getMask(), iv); } /// Helper function TransferOpConversion and TransferOp1dConversion. @@ -201,22 +201,22 @@ static Value generateInBoundsCheck( Value base = xferOp.getIndices()[*dim]; Value memrefIdx = affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); - cond = lb.create(arith::CmpIPredicate::sgt, memrefDim, - memrefIdx); + cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim, + memrefIdx); } // Condition check 2: Masked in? if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { if (cond) - cond = lb.create(cond, maskCond); + cond = arith::AndIOp::create(lb, cond, maskCond); else cond = maskCond; } // If the condition is non-empty, generate an SCF::IfOp. if (cond) { - auto check = lb.create( - cond, + auto check = scf::IfOp::create( + lb, cond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc)); @@ -226,7 +226,7 @@ static Value generateInBoundsCheck( if (outOfBoundsCase) { maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc)); } else { - b.create(loc); + scf::YieldOp::create(b, loc); } }); @@ -303,14 +303,15 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { BufferAllocs result; auto bufferType = MemRefType::get({}, xferOp.getVectorType()); - result.dataBuffer = b.create(loc, bufferType); + result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType); if (xferOp.getMask()) { auto maskType = MemRefType::get({}, xferOp.getMask().getType()); - auto maskBuffer = b.create(loc, maskType); + auto maskBuffer = memref::AllocaOp::create(b, loc, maskType); b.setInsertionPoint(xferOp); - b.create(loc, xferOp.getMask(), maskBuffer); - result.maskBuffer = b.create(loc, maskBuffer, ValueRange()); + memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer); + result.maskBuffer = + memref::LoadOp::create(b, loc, maskBuffer, ValueRange()); } return result; @@ -421,14 +422,15 @@ struct Strategy { auto bufferType = dyn_cast(buffer.getType()); auto vecType = dyn_cast(bufferType.getElementType()); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); - auto newXferOp = b.create( - loc, vecType, xferOp.getBase(), xferIndices, + auto newXferOp = vector::TransferReadOp::create( + b, loc, vecType, xferOp.getBase(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.getPadding(), Value(), inBoundsAttr); maybeApplyPassLabel(b, newXferOp, options.targetRank); - b.create(loc, newXferOp.getVector(), buffer, storeIndices); + memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer, + storeIndices); return newXferOp; } @@ -444,8 +446,9 @@ struct Strategy { Location loc = xferOp.getLoc(); auto bufferType = dyn_cast(buffer.getType()); auto vecType = dyn_cast(bufferType.getElementType()); - auto vec = b.create(loc, vecType, xferOp.getPadding()); - b.create(loc, vec, buffer, storeIndices); + auto vec = + vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding()); + memref::StoreOp::create(b, loc, vec, buffer, storeIndices); return Value(); } @@ -506,12 +509,12 @@ struct Strategy { getXferIndices(b, xferOp, iv, xferIndices); Location loc = xferOp.getLoc(); - auto vec = b.create(loc, buffer, loadIndices); + auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto source = loopState.empty() ? xferOp.getBase() : loopState[0]; Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); - auto newXferOp = b.create( - loc, type, vec, source, xferIndices, + auto newXferOp = vector::TransferWriteOp::create( + b, loc, type, vec, source, xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), inBoundsAttr); @@ -610,8 +613,8 @@ struct PrepareTransferReadConversion } Location loc = xferOp.getLoc(); - rewriter.create(loc, newXfer->getResult(0), - buffers.dataBuffer); + memref::StoreOp::create(rewriter, loc, newXfer->getResult(0), + buffers.dataBuffer); rewriter.replaceOpWithNewOp(xferOp, buffers.dataBuffer); return success(); @@ -653,9 +656,9 @@ struct PrepareTransferWriteConversion Location loc = xferOp.getLoc(); auto buffers = allocBuffers(rewriter, xferOp); - rewriter.create(loc, xferOp.getVector(), - buffers.dataBuffer); - auto loadedVec = rewriter.create(loc, buffers.dataBuffer); + memref::StoreOp::create(rewriter, loc, xferOp.getVector(), + buffers.dataBuffer); + auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer); rewriter.modifyOpInPlace(xferOp, [&]() { xferOp.getValueToStoreMutable().assign(loadedVec); xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); @@ -735,17 +738,17 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { auto signlessTargetVectorType = vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy)); auto targetVectorType = vectorType.cloneWith({}, legalIntTy); - value = rewriter.create(loc, signlessSourceVectorType, - value); + value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType, + value); if (value.getType() != signlessTargetVectorType) { if (width == 1 || intTy.isUnsigned()) - value = rewriter.create(loc, signlessTargetVectorType, - value); + value = arith::ExtUIOp::create(rewriter, loc, + signlessTargetVectorType, value); else - value = rewriter.create(loc, signlessTargetVectorType, - value); + value = arith::ExtSIOp::create(rewriter, loc, + signlessTargetVectorType, value); } - value = rewriter.create(loc, targetVectorType, value); + value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value); vectorType = targetVectorType; } @@ -762,29 +765,30 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { std::multiplies()); auto flatVectorType = VectorType::get({flatLength}, vectorType.getElementType()); - value = rewriter.create(loc, flatVectorType, value); + value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value); } vector::PrintOp firstClose; SmallVector loopIndices; for (unsigned d = 0; d < shape.size(); d++) { // Setup loop bounds and step. - Value lowerBound = rewriter.create(loc, 0); - Value upperBound = rewriter.create(loc, shape[d]); - Value step = rewriter.create(loc, 1); + Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value upperBound = + arith::ConstantIndexOp::create(rewriter, loc, shape[d]); + Value step = arith::ConstantIndexOp::create(rewriter, loc, 1); if (!scalableDimensions.empty() && scalableDimensions[d]) { - auto vscale = rewriter.create( - loc, rewriter.getIndexType()); - upperBound = rewriter.create(loc, upperBound, vscale); + auto vscale = vector::VectorScaleOp::create(rewriter, loc, + rewriter.getIndexType()); + upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale); } - auto lastIndex = rewriter.create(loc, upperBound, step); + auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step); // Create a loop to print the elements surrounded by parentheses. - rewriter.create(loc, vector::PrintPunctuation::Open); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); auto loop = - rewriter.create(loc, lowerBound, upperBound, step); - auto printClose = rewriter.create( - loc, vector::PrintPunctuation::Close); + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); + auto printClose = vector::PrintOp::create( + rewriter, loc, vector::PrintPunctuation::Close); if (!firstClose) firstClose = printClose; @@ -793,14 +797,14 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { // Print a comma after all but the last element. rewriter.setInsertionPointToStart(loop.getBody()); - auto notLastIndex = rewriter.create( - loc, arith::CmpIPredicate::ult, loopIdx, lastIndex); - rewriter.create(loc, notLastIndex, - [&](OpBuilder &builder, Location loc) { - builder.create( - loc, vector::PrintPunctuation::Comma); - builder.create(loc); - }); + auto notLastIndex = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex); + scf::IfOp::create(rewriter, loc, notLastIndex, + [&](OpBuilder &builder, Location loc) { + vector::PrintOp::create( + builder, loc, vector::PrintPunctuation::Comma); + scf::YieldOp::create(builder, loc); + }); rewriter.setInsertionPointToStart(loop.getBody()); } @@ -810,22 +814,23 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { Value flatIndex; auto currentStride = 1; for (int d = shape.size() - 1; d >= 0; d--) { - auto stride = rewriter.create(loc, currentStride); - auto index = rewriter.create(loc, stride, loopIndices[d]); + auto stride = + arith::ConstantIndexOp::create(rewriter, loc, currentStride); + auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]); if (flatIndex) - flatIndex = rewriter.create(loc, flatIndex, index); + flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index); else flatIndex = index; currentStride *= shape[d]; } // Print the scalar elements in the inner most loop. - auto element = rewriter.create(loc, value, flatIndex); - rewriter.create(loc, element, - vector::PrintPunctuation::NoPunctuation); + auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex); + vector::PrintOp::create(rewriter, loc, element, + vector::PrintPunctuation::NoPunctuation); rewriter.setInsertionPointAfter(firstClose); - rewriter.create(loc, printOp.getPunctuation()); + vector::PrintOp::create(rewriter, loc, printOp.getPunctuation()); rewriter.eraseOp(printOp); return success(); } @@ -916,7 +921,7 @@ struct TransferOpConversion : public VectorToSCFPattern { "Failed to unpack one vector dim."); auto castedDataBuffer = - locB.create(*castedDataType, dataBuffer); + vector::TypeCastOp::create(locB, *castedDataType, dataBuffer); // If the xferOp has a mask: Find and cast mask buffer. Value castedMaskBuffer; @@ -935,22 +940,22 @@ struct TransferOpConversion : public VectorToSCFPattern { auto maskBufferType = cast(maskBuffer.getType()); MemRefType castedMaskType = *unpackOneDim(maskBufferType); castedMaskBuffer = - locB.create(castedMaskType, maskBuffer); + vector::TypeCastOp::create(locB, castedMaskType, maskBuffer); } } // Loop bounds and step. - auto lb = locB.create(0); - auto ub = locB.create( - castedDataType->getDimSize(castedDataType->getRank() - 1)); - auto step = locB.create(1); + auto lb = arith::ConstantIndexOp::create(locB, 0); + auto ub = arith::ConstantIndexOp::create( + locB, castedDataType->getDimSize(castedDataType->getRank() - 1)); + auto step = arith::ConstantIndexOp::create(locB, 1); // TransferWriteOps that operate on tensors return the modified tensor and // require a loop state. auto loopState = Strategy::initialLoopState(xferOp); // Generate for loop. - auto result = locB.create( - lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), + auto result = scf::ForOp::create( + locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { Type stateType = loopState.empty() ? Type() : loopState[0].getType(); @@ -975,8 +980,8 @@ struct TransferOpConversion : public VectorToSCFPattern { SmallVector loadIndices; getMaskBufferLoadIndices(xferOp, castedMaskBuffer, loadIndices, iv); - auto mask = b.create(loc, castedMaskBuffer, - loadIndices); + auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer, + loadIndices); rewriter.modifyOpInPlace(newXfer, [&]() { newXfer.getMaskMutable().assign(mask); }); @@ -1119,30 +1124,30 @@ struct ScalableTransposeTransferWriteConversion auto transposeSource = transposeOp.getVector(); SmallVector transposeSourceSlices = llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { - return rewriter.create(loc, transposeSource, idx); + return vector::ExtractOp::create(rewriter, loc, transposeSource, idx); }); // Loop bounds and step. - auto lb = rewriter.create(loc, 0); + auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0); auto ub = maskDims->empty() ? Value(createVscaleMultiple(vectorType.getDimSize(0))) : vector::getAsValues(rewriter, loc, maskDims->front()).front(); - auto step = rewriter.create(loc, 1); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); // Generate a new mask for the slice. VectorType sliceType = VectorType::Builder(vectorType).dropDim(0); Value sliceMask = nullptr; if (!maskDims->empty()) { - sliceMask = rewriter.create( - loc, sliceType.clone(rewriter.getI1Type()), + sliceMask = vector::CreateMaskOp::create( + rewriter, loc, sliceType.clone(rewriter.getI1Type()), ArrayRef(*maskDims).drop_front()); } Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{}; ValueRange initLoopArgs = initDest ? initDest : ValueRange{}; - auto result = rewriter.create( - loc, lb, ub, step, initLoopArgs, + auto result = scf::ForOp::create( + rewriter, loc, lb, ub, step, initLoopArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) { // Indices for the new transfer op. SmallVector xferIndices; @@ -1151,25 +1156,25 @@ struct ScalableTransposeTransferWriteConversion // Extract a transposed slice from the source vector. SmallVector transposeElements = llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { - return b.create( - loc, transposeSourceSlices[idx], iv); + return vector::ExtractOp::create( + b, loc, transposeSourceSlices[idx], iv); }); - auto sliceVec = b.create(loc, sliceType, - transposeElements); + auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType, + transposeElements); // Create the transfer_write for the slice. Value dest = loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front(); - auto newWriteOp = b.create( - loc, sliceVec, dest, xferIndices, + auto newWriteOp = vector::TransferWriteOp::create( + b, loc, sliceVec, dest, xferIndices, ArrayRef(writeOp.getInBoundsValues()).drop_front()); if (sliceMask) newWriteOp.getMaskMutable().assign(sliceMask); // Yield from the loop. - b.create(loc, loopIterArgs.empty() - ? ValueRange{} - : newWriteOp.getResult()); + scf::YieldOp::create(b, loc, + loopIterArgs.empty() ? ValueRange{} + : newWriteOp.getResult()); }); if (isTensorOp(writeOp)) @@ -1207,7 +1212,7 @@ static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, llvm::SmallVector indices({i}); Location loc = xferOp.getLoc(); - auto newMask = b.create(loc, xferOp.getMask(), indices); + auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices); newXferOp.getMaskMutable().assign(newMask); } @@ -1261,8 +1266,8 @@ struct UnrollTransferReadConversion if (auto insertOp = getInsertOp(xferOp)) return insertOp.getDest(); Location loc = xferOp.getLoc(); - return rewriter.create(loc, xferOp.getVectorType(), - xferOp.getPadding()); + return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(), + xferOp.getPadding()); } /// If the result of the TransferReadOp has exactly one user, which is a @@ -1317,7 +1322,7 @@ struct UnrollTransferReadConversion // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = arith::ConstantIndexOp::create(rewriter, loc, i); // FIXME: Rename this lambda - it does much more than just // in-bounds-check generation. @@ -1336,8 +1341,8 @@ struct UnrollTransferReadConversion auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); - auto newXferOp = b.create( - loc, newXferVecType, xferOp.getBase(), xferIndices, + auto newXferOp = vector::TransferReadOp::create( + b, loc, newXferVecType, xferOp.getBase(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.getPadding(), Value(), inBoundsAttr); maybeAssignMask(b, xferOp, newXferOp, i); @@ -1346,11 +1351,11 @@ struct UnrollTransferReadConversion if (newXferVecType.getRank() == 0) { // vector.insert does not accept rank-0 as the non-indexed // argument. Extract the scalar before inserting. - valToInser = b.create(loc, valToInser, - SmallVector()); + valToInser = vector::ExtractOp::create(b, loc, valToInser, + SmallVector()); } - return b.create(loc, valToInser, vec, - insertionIndices); + return vector::InsertOp::create(b, loc, valToInser, vec, + insertionIndices); }, /*outOfBoundsCase=*/ [&](OpBuilder &b, Location loc) { @@ -1460,7 +1465,7 @@ struct UnrollTransferWriteConversion // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = arith::ConstantIndexOp::create(rewriter, loc, i); auto updatedSource = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), @@ -1477,20 +1482,20 @@ struct UnrollTransferWriteConversion extractionIndices.push_back(b.getI64IntegerAttr(i)); auto extracted = - b.create(loc, vec, extractionIndices); + vector::ExtractOp::create(b, loc, vec, extractionIndices); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); Value xferVec; if (inputVectorTy.getRank() == 1) { // When target-rank=0, unrolling would causes the vector input // argument into `transfer_write` to become a scalar. We solve // this by broadcasting the scalar to a 0D vector. - xferVec = b.create( - loc, VectorType::get({}, extracted.getType()), extracted); + xferVec = vector::BroadcastOp::create( + b, loc, VectorType::get({}, extracted.getType()), extracted); } else { xferVec = extracted; } - auto newXferOp = b.create( - loc, sourceType, xferVec, source, xferIndices, + auto newXferOp = vector::TransferWriteOp::create( + b, loc, sourceType, xferVec, source, xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), inBoundsAttr); @@ -1572,19 +1577,19 @@ struct Strategy1d { b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()), /*inBoundsCase=*/ [&](OpBuilder &b, Location loc) { - Value val = b.create(loc, xferOp.getBase(), indices); - return b.create(loc, val, vec, iv); + Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices); + return vector::InsertOp::create(b, loc, val, vec, iv); }, /*outOfBoundsCase=*/ [&](OpBuilder & /*b*/, Location loc) { return vec; }); - b.create(loc, nextVec); + scf::YieldOp::create(b, loc, nextVec); } static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) { // Inititalize vector with padding value. Location loc = xferOp.getLoc(); - return b.create(loc, xferOp.getVectorType(), - xferOp.getPadding()); + return vector::BroadcastOp::create(b, loc, xferOp.getVectorType(), + xferOp.getPadding()); } }; @@ -1601,10 +1606,10 @@ struct Strategy1d { generateInBoundsCheck( b, xferOp, iv, dim, /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { - auto val = b.create(loc, xferOp.getVector(), iv); - b.create(loc, val, xferOp.getBase(), indices); + auto val = vector::ExtractOp::create(b, loc, xferOp.getVector(), iv); + memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices); }); - b.create(loc); + scf::YieldOp::create(b, loc); } static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) { @@ -1665,15 +1670,15 @@ struct TransferOp1dConversion : public VectorToSCFPattern { // Loop bounds, step, state... Location loc = xferOp.getLoc(); auto vecType = xferOp.getVectorType(); - auto lb = rewriter.create(loc, 0); + auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0); Value ub = - rewriter.create(loc, vecType.getDimSize(0)); + arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0)); if (vecType.isScalable()) { Value vscale = - rewriter.create(loc, rewriter.getIndexType()); - ub = rewriter.create(loc, ub, vscale); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + ub = arith::MulIOp::create(rewriter, loc, ub, vscale); } - auto step = rewriter.create(loc, 1); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto loopState = Strategy1d::initialLoopState(rewriter, xferOp); // Generate for loop. diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 750ce85049409..00ee3faa908e1 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -161,19 +161,19 @@ static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter, Location loc, Value dynamicIndex, int64_t kPoisonIndex, unsigned vectorSize) { if (llvm::isPowerOf2_32(vectorSize)) { - Value inBoundsMask = rewriter.create( - loc, dynamicIndex.getType(), + Value inBoundsMask = spirv::ConstantOp::create( + rewriter, loc, dynamicIndex.getType(), rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1)); - return rewriter.create(loc, dynamicIndex, - inBoundsMask); + return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex, + inBoundsMask); } - Value poisonIndex = rewriter.create( - loc, dynamicIndex.getType(), + Value poisonIndex = spirv::ConstantOp::create( + rewriter, loc, dynamicIndex.getType(), rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex)); Value cmpResult = - rewriter.create(loc, dynamicIndex, poisonIndex); - return rewriter.create( - loc, cmpResult, + spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex); + return spirv::SelectOp::create( + rewriter, loc, cmpResult, spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter), dynamicIndex); } @@ -441,8 +441,8 @@ static SmallVector extractAllElements( Location loc = reduceOp.getLoc(); for (int i = 0; i < numElements; ++i) { - values.push_back(rewriter.create( - loc, srcVectorType.getElementType(), adaptor.getVector(), + values.push_back(spirv::CompositeExtractOp::create( + rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(), rewriter.getI32ArrayAttr({i}))); } if (Value acc = adaptor.getAcc()) @@ -495,16 +495,16 @@ struct VectorReductionPattern final : OpConversionPattern { #define INT_AND_FLOAT_CASE(kind, iop, fop) \ case vector::CombiningKind::kind: \ if (llvm::isa(resultType)) { \ - result = rewriter.create(loc, resultType, result, next); \ + result = spirv::iop::create(rewriter, loc, resultType, result, next); \ } else { \ assert(llvm::isa(resultType)); \ - result = rewriter.create(loc, resultType, result, next); \ + result = spirv::fop::create(rewriter, loc, resultType, result, next); \ } \ break #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ - result = rewriter.create(loc, resultType, result, next); \ + result = fop::create(rewriter, loc, resultType, result, next); \ break INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); @@ -551,7 +551,7 @@ struct VectorReductionFloatMinMax final #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ - result = rewriter.create(loc, resultType, result, next); \ + result = fop::create(rewriter, loc, resultType, result, next); \ break INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); @@ -632,8 +632,8 @@ struct VectorShuffleOpConvert final auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()]( Value scalarOrVec, int32_t idx) -> Value { if (auto vecTy = dyn_cast(scalarOrVec.getType())) - return rewriter.create(loc, scalarOrVec, - idx); + return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec, + idx); assert(idx == 0 && "Invalid scalar element index"); return scalarOrVec; @@ -731,11 +731,13 @@ struct VectorDeinterleaveOpConvert final // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to // use `spirv::CompositeExtractOp`. if (n == 2) { - auto elem0 = rewriter.create( - loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0})); + auto elem0 = spirv::CompositeExtractOp::create( + rewriter, loc, newResultType, sourceVector, + rewriter.getI32ArrayAttr({0})); - auto elem1 = rewriter.create( - loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1})); + auto elem1 = spirv::CompositeExtractOp::create( + rewriter, loc, newResultType, sourceVector, + rewriter.getI32ArrayAttr({1})); rewriter.replaceOp(deinterleaveOp, {elem0, elem1}); return success(); @@ -752,12 +754,12 @@ struct VectorDeinterleaveOpConvert final llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; }); // Create two SPIR-V shuffles. - auto shuffleEven = rewriter.create( - loc, newResultType, sourceVector, sourceVector, + auto shuffleEven = spirv::VectorShuffleOp::create( + rewriter, loc, newResultType, sourceVector, sourceVector, rewriter.getI32ArrayAttr(indicesEven)); - auto shuffleOdd = rewriter.create( - loc, newResultType, sourceVector, sourceVector, + auto shuffleOdd = spirv::VectorShuffleOp::create( + rewriter, loc, newResultType, sourceVector, sourceVector, rewriter.getI32ArrayAttr(indicesOdd)); rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd}); @@ -798,10 +800,11 @@ struct VectorLoadOpConverter final // For single element vectors, we don't need to bitcast the access chain to // the original vector type. Both is going to be the same, a pointer // to a scalar. - Value castedAccessChain = (vectorType.getNumElements() == 1) - ? accessChain - : rewriter.create( - loc, vectorPtrType, accessChain); + Value castedAccessChain = + (vectorType.getNumElements() == 1) + ? accessChain + : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, + accessChain); rewriter.replaceOpWithNewOp(loadOp, spirvVectorType, castedAccessChain); @@ -840,10 +843,11 @@ struct VectorStoreOpConverter final // For single element vectors, we don't need to bitcast the access chain to // the original vector type. Both is going to be the same, a pointer // to a scalar. - Value castedAccessChain = (vectorType.getNumElements() == 1) - ? accessChain - : rewriter.create( - loc, vectorPtrType, accessChain); + Value castedAccessChain = + (vectorType.getNumElements() == 1) + ? accessChain + : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, + accessChain); rewriter.replaceOpWithNewOp(storeOp, castedAccessChain, adaptor.getValueToStore()); @@ -924,10 +928,10 @@ struct VectorReductionToIntDotProd final auto v4i8Type = VectorType::get({4}, i8Type); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter); - lhsIn = rewriter.create( - loc, v4i8Type, ValueRange{lhsIn, zero}); - rhsIn = rewriter.create( - loc, v4i8Type, ValueRange{rhsIn, zero}); + lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type, + ValueRange{lhsIn, zero}); + rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type, + ValueRange{rhsIn, zero}); } // There's no variant of dot prod ops for unsigned LHS and signed RHS, so @@ -990,14 +994,14 @@ struct VectorReductionToFPDotProd final Attribute oneAttr = rewriter.getFloatAttr(vectorType.getElementType(), 1.0); oneAttr = SplatElementsAttr::get(vectorType, oneAttr); - rhs = rewriter.create(loc, vectorType, oneAttr); + rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr); } assert(lhs); assert(rhs); - Value res = rewriter.create(loc, resultType, lhs, rhs); + Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs); if (acc) - res = rewriter.create(loc, acc, res); + res = spirv::FAddOp::create(rewriter, loc, acc, res); rewriter.replaceOp(op, res); return success(); @@ -1032,7 +1036,8 @@ struct VectorStepOpConvert final : OpConversionPattern { source.reserve(numElements); for (int64_t i = 0; i < numElements; ++i) { Attribute intAttr = rewriter.getIntegerAttr(intType, i); - Value constOp = rewriter.create(loc, intType, intAttr); + Value constOp = + spirv::ConstantOp::create(rewriter, loc, intType, intAttr); source.push_back(constOp); } rewriter.replaceOpWithNewOp(stepOp, dstType, @@ -1075,8 +1080,8 @@ struct VectorToElementOpConvert final if (element.use_empty()) continue; - Value result = rewriter.create( - loc, elementType, adaptor.getSource(), + Value result = spirv::CompositeExtractOp::create( + rewriter, loc, elementType, adaptor.getSource(), rewriter.getI32ArrayAttr({static_cast(idx)})); results[idx] = result; } diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 2e6a16ddbfdaa..80107554144cf 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -108,15 +108,15 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, xegpu::CreateNdDescOp ndDesc; if (srcTy.hasStaticShape()) { - ndDesc = rewriter.create(loc, descType, src, - getAsOpFoldResult(offsets)); + ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src, + getAsOpFoldResult(offsets)); } else { // In case of any dynamic shapes, source's shape and strides have to be // explicitly provided. SmallVector sourceDims; unsigned srcRank = srcTy.getRank(); for (unsigned i = 0; i < srcRank; ++i) - sourceDims.push_back(rewriter.create(loc, src, i)); + sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i)); SmallVector constOffsets; SmallVector dynOffsets; @@ -135,18 +135,18 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, // Compute strides in reverse order. SmallVector dynStrides; - Value accStride = rewriter.create(loc, 1); + Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1); // Last stride is guaranteed to be static and unit. for (int i = static_cast(strides.size()) - 2; i >= 0; --i) { accStride = - rewriter.create(loc, accStride, sourceDims[i + 1]); + arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]); if (strides[i] == ShapedType::kDynamic) dynStrides.push_back(accStride); } std::reverse(dynStrides.begin(), dynStrides.end()); - ndDesc = rewriter.create( - loc, descType, src, dynOffsets, dynShapes, dynStrides, + ndDesc = xegpu::CreateNdDescOp::create( + rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides, DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets), DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()), DenseI64ArrayAttr::get(rewriter.getContext(), strides)); @@ -200,10 +200,10 @@ struct TransferReadLowering : public OpRewritePattern { ArrayRef{1, 0}); // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; - auto loadOp = rewriter.create( - loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr, - /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, + /*packed=*/nullptr, transposeAttr, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(readOp, loadOp); return success(); @@ -238,9 +238,9 @@ struct TransferWriteLowering // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; auto storeOp = - rewriter.create(loc, writeOp.getVector(), ndDesc, - /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(writeOp, storeOp); return success(); @@ -269,8 +269,8 @@ struct LoadLowering : public OpRewritePattern { // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; - auto loadNdOp = rewriter.create( - loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr, + auto loadNdOp = xegpu::LoadNdOp::create( + rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(loadOp, loadNdOp); @@ -303,9 +303,9 @@ struct StoreLowering : public OpRewritePattern { // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; auto storeNdOp = - rewriter.create(loc, vector, ndDesc, - /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(storeOp, storeNdOp); return success(); @@ -339,8 +339,9 @@ struct ContractionLowering : public OpRewritePattern { if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr())) return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps"); - auto dpasOp = rewriter.create( - loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc}); + auto dpasOp = xegpu::DpasOp::create(rewriter, loc, + TypeRange{contractOp.getResultType()}, + ValueRange{lhs, rhs, acc}); rewriter.replaceOp(contractOp, dpasOp); return success(); diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index a8380b9669f0f..2411af043f3f7 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -251,7 +251,7 @@ static LLVM::CallOp createDeviceFunctionCall( for (auto [idx, attrName] : paramAttrs) funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr()); - auto callOp = rewriter.create(loc, funcOp, args); + auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args); callOp->setAttrs(funcOp->getAttrs()); return callOp; @@ -299,7 +299,7 @@ class MMAToOCLPattern : public OpConversionPattern { VectorType newTy = VectorType::get( vecBitSize / packedType.getIntOrFloatBitWidth(), packedType); if (origTy != newTy) - val = rewriter.create(loc, newTy, val); + val = LLVM::BitcastOp::create(rewriter, loc, newTy, val); return val; }; @@ -326,7 +326,7 @@ class MMAToOCLPattern : public OpConversionPattern { : cOrigTy; VectorType resTy = cTy; if (cOrigTy != cTy) - c = rewriter.create(loc, cTy, c); + c = LLVM::BitcastOp::create(rewriter, loc, cTy, c); constexpr int32_t systolicDepth{8}; std::string fnName = @@ -352,7 +352,7 @@ class MMAToOCLPattern : public OpConversionPattern { ->getResult(0); if (resOrigTy != resTy) - result = rewriter.create(loc, resOrigTy, result); + result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result); rewriter.replaceOp(op, result); return success(); @@ -383,7 +383,7 @@ class PrefetchToOCLPattern : public OpConversionPattern { auto loc = op.getLoc(); const std::string fnName{"_Z8prefetchPU3AS1Kcm"}; Value one = - rewriter.create(loc, rewriter.getI64Type(), 1); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1); SmallVector args{op.getPtr(), one}; SmallVector argTypes; for (auto arg : args) @@ -439,11 +439,11 @@ class MemfenceToOCLPattern : public OpConversionPattern { op, "Fence only supports workgroup and device memory scopes."); } Type i32Type = rewriter.getI32Type(); - Value acqRel = rewriter.create(loc, i32Type, 4); + Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4); Value memScopeConst = - rewriter.create(loc, i32Type, memScope); + LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope); Value addrSpaceConst = - rewriter.create(loc, i32Type, addrSpace); + LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace); SmallVector args{addrSpaceConst, acqRel, memScopeConst}; SmallVector argTypes{3, i32Type}; createDeviceFunctionCall(rewriter, mangle(fnName, argTypes), @@ -477,13 +477,13 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern { auto i32Type = rewriter.getI32Type(); Value byteCoord = - rewriter.create(loc, VectorType::get(2, i32Type)); - Value zero = rewriter.create(loc, i32Type, 0); - Value one = rewriter.create(loc, i32Type, 1); - byteCoord = rewriter.create( - loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); - byteCoord = rewriter.create( - loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); + LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type)); + Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0); + Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1); + byteCoord = LLVM::InsertElementOp::create( + rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); + byteCoord = LLVM::InsertElementOp::create( + rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); SmallVector args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(), op.getBasePitch(), byteCoord}; SmallVector retTypes; @@ -504,11 +504,11 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern { } else { auto vecElemType = vecType.getElementType(); auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth(); - Value numElems = rewriter.create( - loc, i32Type, vecType.getNumElements()); - auto dstOrSrcPtr = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType, - numElems); + Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type, + vecType.getNumElements()); + auto dstOrSrcPtr = LLVM::AllocaOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), + vecElemType, numElems); args.push_back(dstOrSrcPtr); if constexpr (isLoad) { // Load funcName += "read"; @@ -530,7 +530,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern { bitWidthId = (vecElemBitWidth == 32) ? "j" : ((vecElemBitWidth == 16) ? "t" : "h"); - rewriter.create(loc, op.getStoredVal(), dstOrSrcPtr); + LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr); paramAttrs = { std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()), @@ -563,7 +563,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern { } if constexpr (isLoad) rewriter.replaceOp( - op, rewriter.create(loc, vecType, spvLoadDstPtr)); + op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr)); else rewriter.eraseOp(op); return success();