diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp index 0473bb59fa6aa..99d2f6ca78c38 100644 --- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp +++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp @@ -36,34 +36,34 @@ struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create(loc, n.getType(), 0); - Value posOne = rewriter.create(loc, n.getType(), 1); - Value negOne = rewriter.create(loc, n.getType(), -1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); + Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1); // Compute `x`. Value mPos = - rewriter.create(loc, LLVM::ICmpPredicate::sgt, m, zero); - Value x = rewriter.create(loc, mPos, negOne, posOne); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, m, zero); + Value x = LLVM::SelectOp::create(rewriter, loc, mPos, negOne, posOne); // Compute the positive result. - Value nPlusX = rewriter.create(loc, n, x); - Value nPlusXDivM = rewriter.create(loc, nPlusX, m); - Value posRes = rewriter.create(loc, nPlusXDivM, posOne); + Value nPlusX = LLVM::AddOp::create(rewriter, loc, n, x); + Value nPlusXDivM = LLVM::SDivOp::create(rewriter, loc, nPlusX, m); + Value posRes = LLVM::AddOp::create(rewriter, loc, nPlusXDivM, posOne); // Compute the negative result. - Value negN = rewriter.create(loc, zero, n); - Value negNDivM = rewriter.create(loc, negN, m); - Value negRes = rewriter.create(loc, zero, negNDivM); + Value negN = LLVM::SubOp::create(rewriter, loc, zero, n); + Value negNDivM = LLVM::SDivOp::create(rewriter, loc, negN, m); + Value negRes = LLVM::SubOp::create(rewriter, loc, zero, negNDivM); // Pick the positive result if `n` and `m` have the same sign and `n` is // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. Value nPos = - rewriter.create(loc, LLVM::ICmpPredicate::sgt, n, zero); - Value sameSign = - rewriter.create(loc, LLVM::ICmpPredicate::eq, nPos, mPos); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, n, zero); + Value sameSign = LLVM::ICmpOp::create(rewriter, loc, + LLVM::ICmpPredicate::eq, nPos, mPos); Value nNonZero = - rewriter.create(loc, LLVM::ICmpPredicate::ne, n, zero); - Value cmp = rewriter.create(loc, sameSign, nNonZero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero); + Value cmp = LLVM::AndOp::create(rewriter, loc, sameSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } @@ -83,17 +83,17 @@ struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create(loc, n.getType(), 0); - Value one = rewriter.create(loc, n.getType(), 1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value one = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); // Compute the non-zero result. - Value minusOne = rewriter.create(loc, n, one); - Value quotient = rewriter.create(loc, minusOne, m); - Value plusOne = rewriter.create(loc, quotient, one); + Value minusOne = LLVM::SubOp::create(rewriter, loc, n, one); + Value quotient = LLVM::UDivOp::create(rewriter, loc, minusOne, m); + Value plusOne = LLVM::AddOp::create(rewriter, loc, quotient, one); // Pick the result. Value cmp = - rewriter.create(loc, LLVM::ICmpPredicate::eq, n, zero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, n, zero); rewriter.replaceOpWithNewOp(op, cmp, zero, plusOne); return success(); } @@ -114,32 +114,32 @@ struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create(loc, n.getType(), 0); - Value posOne = rewriter.create(loc, n.getType(), 1); - Value negOne = rewriter.create(loc, n.getType(), -1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); + Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1); // Compute `x`. Value mNeg = - rewriter.create(loc, LLVM::ICmpPredicate::slt, m, zero); - Value x = rewriter.create(loc, mNeg, posOne, negOne); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, m, zero); + Value x = LLVM::SelectOp::create(rewriter, loc, mNeg, posOne, negOne); // Compute the negative result. - Value xMinusN = rewriter.create(loc, x, n); - Value xMinusNDivM = rewriter.create(loc, xMinusN, m); - Value negRes = rewriter.create(loc, negOne, xMinusNDivM); + Value xMinusN = LLVM::SubOp::create(rewriter, loc, x, n); + Value xMinusNDivM = LLVM::SDivOp::create(rewriter, loc, xMinusN, m); + Value negRes = LLVM::SubOp::create(rewriter, loc, negOne, xMinusNDivM); // Compute the positive result. - Value posRes = rewriter.create(loc, n, m); + Value posRes = LLVM::SDivOp::create(rewriter, loc, n, m); // Pick the negative result if `n` and `m` have different signs and `n` is // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. Value nNeg = - rewriter.create(loc, LLVM::ICmpPredicate::slt, n, zero); - Value diffSign = - rewriter.create(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, n, zero); + Value diffSign = LLVM::ICmpOp::create(rewriter, loc, + LLVM::ICmpPredicate::ne, nNeg, mNeg); Value nNonZero = - rewriter.create(loc, LLVM::ICmpPredicate::ne, n, zero); - Value cmp = rewriter.create(loc, diffSign, nNonZero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero); + Value cmp = LLVM::AndOp::create(rewriter, loc, diffSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, negRes, posRes); return success(); } diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp index 4821962f989e6..36cfe9dd6e2db 100644 --- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp +++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp @@ -111,33 +111,33 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create( - loc, n_type, IntegerAttr::get(n_type, 0)); - Value posOne = rewriter.create( - loc, n_type, IntegerAttr::get(n_type, 1)); - Value negOne = rewriter.create( - loc, n_type, IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, -1)); // Compute `x`. - Value mPos = rewriter.create(loc, m, zero); - Value x = rewriter.create(loc, mPos, negOne, posOne); + Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero); + Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne); // Compute the positive result. - Value nPlusX = rewriter.create(loc, n, x); - Value nPlusXDivM = rewriter.create(loc, nPlusX, m); - Value posRes = rewriter.create(loc, nPlusXDivM, posOne); + Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x); + Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m); + Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne); // Compute the negative result. - Value negN = rewriter.create(loc, zero, n); - Value negNDivM = rewriter.create(loc, negN, m); - Value negRes = rewriter.create(loc, zero, negNDivM); + Value negN = spirv::ISubOp::create(rewriter, loc, zero, n); + Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m); + Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM); // Pick the positive result if `n` and `m` have the same sign and `n` is // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. - Value nPos = rewriter.create(loc, n, zero); - Value sameSign = rewriter.create(loc, nPos, mPos); - Value nNonZero = rewriter.create(loc, n, zero); - Value cmp = rewriter.create(loc, sameSign, nNonZero); + Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero); + Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos); + Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero); + Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } @@ -161,18 +161,18 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create( - loc, n_type, IntegerAttr::get(n_type, 0)); - Value one = rewriter.create(loc, n_type, - IntegerAttr::get(n_type, 1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 0)); + Value one = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 1)); // Compute the non-zero result. - Value minusOne = rewriter.create(loc, n, one); - Value quotient = rewriter.create(loc, minusOne, m); - Value plusOne = rewriter.create(loc, quotient, one); + Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one); + Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m); + Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one); // Pick the result - Value cmp = rewriter.create(loc, n, zero); + Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero); rewriter.replaceOpWithNewOp(op, cmp, zero, plusOne); return success(); } @@ -197,32 +197,33 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create( - loc, n_type, IntegerAttr::get(n_type, 0)); - Value posOne = rewriter.create( - loc, n_type, IntegerAttr::get(n_type, 1)); - Value negOne = rewriter.create( - loc, n_type, IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, -1)); // Compute `x`. - Value mNeg = rewriter.create(loc, m, zero); - Value x = rewriter.create(loc, mNeg, posOne, negOne); + Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero); + Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne); // Compute the negative result - Value xMinusN = rewriter.create(loc, x, n); - Value xMinusNDivM = rewriter.create(loc, xMinusN, m); - Value negRes = rewriter.create(loc, negOne, xMinusNDivM); + Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n); + Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m); + Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM); // Compute the positive result. - Value posRes = rewriter.create(loc, n, m); + Value posRes = spirv::SDivOp::create(rewriter, loc, n, m); // Pick the negative result if `n` and `m` have different signs and `n` is // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. - Value nNeg = rewriter.create(loc, n, zero); - Value diffSign = rewriter.create(loc, nNeg, mNeg); - Value nNonZero = rewriter.create(loc, n, zero); + Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero); + Value diffSign = + spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg); + Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero); - Value cmp = rewriter.create(loc, diffSign, nNonZero); + Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index e34d5f74d232f..fce7a3f324b86 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -32,7 +32,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor) MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc, Type descriptorType) { - Value descriptor = builder.create(loc, descriptorType); + Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType); return MemRefDescriptor(descriptor); } @@ -99,21 +99,21 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, // integer attribute. static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { - return builder.create(loc, resultType, - builder.getIndexAttr(value)); + return LLVM::ConstantOp::create(builder, loc, resultType, + builder.getIndexAttr(value)); } /// Builds IR extracting the offset from the descriptor. Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { - return builder.create(loc, value, - kOffsetPosInMemRefDescriptor); + return LLVM::ExtractValueOp::create(builder, loc, value, + kOffsetPosInMemRefDescriptor); } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { - value = builder.create(loc, value, offset, - kOffsetPosInMemRefDescriptor); + value = LLVM::InsertValueOp::create(builder, loc, value, offset, + kOffsetPosInMemRefDescriptor); } /// Builds IR inserting the offset into the descriptor. @@ -125,8 +125,9 @@ void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, /// Builds IR extracting the pos-th size from the descriptor. Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create( - loc, value, ArrayRef({kSizePosInMemRefDescriptor, pos})); + return LLVM::ExtractValueOp::create( + builder, loc, value, + ArrayRef({kSizePosInMemRefDescriptor, pos})); } Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, @@ -137,23 +138,25 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, // Copy size values to stack-allocated memory. auto one = createIndexAttrConstant(builder, loc, indexType, 1); - auto sizes = builder.create( - loc, value, llvm::ArrayRef({kSizePosInMemRefDescriptor})); - auto sizesPtr = builder.create(loc, ptrTy, arrayTy, one, - /*alignment=*/0); - builder.create(loc, sizes, sizesPtr); + auto sizes = LLVM::ExtractValueOp::create( + builder, loc, value, + llvm::ArrayRef({kSizePosInMemRefDescriptor})); + auto sizesPtr = LLVM::AllocaOp::create(builder, loc, ptrTy, arrayTy, one, + /*alignment=*/0); + LLVM::StoreOp::create(builder, loc, sizes, sizesPtr); // Load an return size value of interest. - auto resultPtr = builder.create(loc, ptrTy, arrayTy, sizesPtr, - ArrayRef{0, pos}); - return builder.create(loc, indexType, resultPtr); + auto resultPtr = LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, sizesPtr, + ArrayRef{0, pos}); + return LLVM::LoadOp::create(builder, loc, indexType, resultPtr); } /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, Value size) { - value = builder.create( - loc, value, size, ArrayRef({kSizePosInMemRefDescriptor, pos})); + value = LLVM::InsertValueOp::create( + builder, loc, value, size, + ArrayRef({kSizePosInMemRefDescriptor, pos})); } void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, @@ -164,15 +167,16 @@ void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, /// Builds IR extracting the pos-th stride from the descriptor. Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create( - loc, value, ArrayRef({kStridePosInMemRefDescriptor, pos})); + return LLVM::ExtractValueOp::create( + builder, loc, value, + ArrayRef({kStridePosInMemRefDescriptor, pos})); } /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride) { - value = builder.create( - loc, value, stride, + value = LLVM::InsertValueOp::create( + builder, loc, value, stride, ArrayRef({kStridePosInMemRefDescriptor, pos})); } @@ -207,8 +211,8 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, ? offset(builder, loc) : createIndexAttrConstant(builder, loc, indexType, offsetCst); Type elementType = converter.convertType(type.getElementType()); - ptr = builder.create(loc, ptr.getType(), elementType, ptr, - offsetVal); + ptr = LLVM::GEPOp::create(builder, loc, ptr.getType(), elementType, ptr, + offsetVal); return ptr; } @@ -303,7 +307,7 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder, Location loc, Type descriptorType) { - Value descriptor = builder.create(loc, descriptorType); + Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType); return UnrankedMemRefDescriptor(descriptor); } Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const { @@ -380,19 +384,19 @@ void UnrankedMemRefDescriptor::computeSizes( builder, loc, indexType, llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); Value doublePointerSize = - builder.create(loc, indexType, two, pointerSize); + LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); // (1 + 2 * rank) * sizeof(index) Value rank = desc.rank(builder, loc); - Value doubleRank = builder.create(loc, indexType, two, rank); + Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); Value doubleRankIncremented = - builder.create(loc, indexType, doubleRank, one); - Value rankIndexSize = builder.create( - loc, indexType, doubleRankIncremented, indexSize); + LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); + Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, + doubleRankIncremented, indexSize); // Total allocation size. - Value allocationSize = builder.create( - loc, indexType, doublePointerSize, rankIndexSize); + Value allocationSize = LLVM::AddOp::create( + builder, loc, indexType, doublePointerSize, rankIndexSize); sizes.push_back(allocationSize); } } @@ -400,13 +404,13 @@ void UnrankedMemRefDescriptor::computeSizes( Value UnrankedMemRefDescriptor::allocatedPtr( OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { - return builder.create(loc, elemPtrType, memRefDescPtr); + return LLVM::LoadOp::create(builder, loc, elemPtrType, memRefDescPtr); } void UnrankedMemRefDescriptor::setAllocatedPtr( OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) { - builder.create(loc, allocatedPtr, memRefDescPtr); + LLVM::StoreOp::create(builder, loc, allocatedPtr, memRefDescPtr); } static std::pair @@ -423,9 +427,9 @@ Value UnrankedMemRefDescriptor::alignedPtr( castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); Value alignedGep = - builder.create(loc, elemPtrPtrType, elemPtrType, - elementPtrPtr, ArrayRef{1}); - return builder.create(loc, elemPtrType, alignedGep); + LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, + elementPtrPtr, ArrayRef{1}); + return LLVM::LoadOp::create(builder, loc, elemPtrType, alignedGep); } void UnrankedMemRefDescriptor::setAlignedPtr( @@ -435,9 +439,9 @@ void UnrankedMemRefDescriptor::setAlignedPtr( castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); Value alignedGep = - builder.create(loc, elemPtrPtrType, elemPtrType, - elementPtrPtr, ArrayRef{1}); - builder.create(loc, alignedPtr, alignedGep); + LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, + elementPtrPtr, ArrayRef{1}); + LLVM::StoreOp::create(builder, loc, alignedPtr, alignedGep); } Value UnrankedMemRefDescriptor::offsetBasePtr( @@ -446,8 +450,8 @@ Value UnrankedMemRefDescriptor::offsetBasePtr( auto [elementPtrPtr, elemPtrPtrType] = castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); - return builder.create(loc, elemPtrPtrType, elemPtrType, - elementPtrPtr, ArrayRef{2}); + return LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, + elementPtrPtr, ArrayRef{2}); } Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, @@ -456,8 +460,8 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, LLVM::LLVMPointerType elemPtrType) { Value offsetPtr = offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType); - return builder.create(loc, typeConverter.getIndexType(), - offsetPtr); + return LLVM::LoadOp::create(builder, loc, typeConverter.getIndexType(), + offsetPtr); } void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, @@ -467,7 +471,7 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { Value offsetPtr = offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType); - builder.create(loc, offset, offsetPtr); + LLVM::StoreOp::create(builder, loc, offset, offsetPtr); } Value UnrankedMemRefDescriptor::sizeBasePtr( @@ -477,8 +481,8 @@ Value UnrankedMemRefDescriptor::sizeBasePtr( Type structTy = LLVM::LLVMStructType::getLiteral( indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy}); auto resultType = LLVM::LLVMPointerType::get(builder.getContext()); - return builder.create(loc, resultType, structTy, memRefDescPtr, - ArrayRef{0, 3}); + return LLVM::GEPOp::create(builder, loc, resultType, structTy, memRefDescPtr, + ArrayRef{0, 3}); } Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, @@ -489,8 +493,8 @@ Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value sizeStoreGep = - builder.create(loc, ptrType, indexTy, sizeBasePtr, index); - return builder.create(loc, indexTy, sizeStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index); + return LLVM::LoadOp::create(builder, loc, indexTy, sizeStoreGep); } void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, @@ -501,8 +505,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value sizeStoreGep = - builder.create(loc, ptrType, indexTy, sizeBasePtr, index); - builder.create(loc, size, sizeStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index); + LLVM::StoreOp::create(builder, loc, size, sizeStoreGep); } Value UnrankedMemRefDescriptor::strideBasePtr( @@ -511,7 +515,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr( Type indexTy = typeConverter.getIndexType(); auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); - return builder.create(loc, ptrType, indexTy, sizeBasePtr, rank); + return LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, rank); } Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, @@ -522,8 +526,8 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value strideStoreGep = - builder.create(loc, ptrType, indexTy, strideBasePtr, index); - return builder.create(loc, indexTy, strideStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index); + return LLVM::LoadOp::create(builder, loc, indexTy, strideStoreGep); } void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, @@ -534,6 +538,6 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value strideStoreGep = - builder.create(loc, ptrType, indexTy, strideBasePtr, index); - builder.create(loc, stride, strideStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index); + LLVM::StoreOp::create(builder, loc, stride, strideStoreGep); } diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index c5f72f7e10b8c..ecd5b6367fba4 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -57,8 +57,8 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { - return builder.create(loc, resultType, - builder.getIndexAttr(value)); + return LLVM::ConstantOp::create(builder, loc, resultType, + builder.getIndexAttr(value)); } Value ConvertToLLVMPattern::getStridedElementPtr( @@ -123,7 +123,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes( runningStride = sizes[i]; else if (stride == ShapedType::kDynamic) runningStride = - rewriter.create(loc, runningStride, sizes[i]); + LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]); else runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride); } @@ -131,10 +131,10 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes( // Buffer size in bytes. Type elementType = typeConverter->convertType(memRefType.getElementType()); auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value nullPtr = rewriter.create(loc, elementPtrType); - Value gepPtr = rewriter.create( - loc, elementPtrType, elementType, nullPtr, runningStride); - size = rewriter.create(loc, getIndexType(), gepPtr); + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType); + Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType, + elementType, nullPtr, runningStride); + size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr); } else { size = runningStride; } @@ -149,10 +149,10 @@ Value ConvertToLLVMPattern::getSizeInBytes( // which is a common pattern of getting the size of a type in bytes. Type llvmType = typeConverter->convertType(type); auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto nullPtr = rewriter.create(loc, convertedPtrType); - auto gep = rewriter.create(loc, convertedPtrType, llvmType, - nullPtr, ArrayRef{1}); - return rewriter.create(loc, getIndexType(), gep); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType); + auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType, + nullPtr, ArrayRef{1}); + return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep); } Value ConvertToLLVMPattern::getNumElements( @@ -175,7 +175,7 @@ Value ConvertToLLVMPattern::getNumElements( staticSize == ShapedType::kDynamic ? dynamicSizes[dynamicIndex++] : createIndexAttrConstant(rewriter, loc, indexType, staticSize); - numElements = rewriter.create(loc, numElements, size); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, size); } else { numElements = staticSize == ShapedType::kDynamic @@ -276,14 +276,14 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( ? builder .create(loc, mallocFunc.value(), allocationSize) .getResult() - : builder.create(loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); - builder.create(loc, memory, source, allocationSize, false); + LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); if (!toDynamic) - builder.create(loc, freeFunc.value(), source); + LLVM::CallOp::create(builder, loc, freeFunc.value(), source); // Create a new descriptor. The same descriptor can be returned multiple // times, attempting to modify its pointer can lead to memory leaks @@ -349,8 +349,8 @@ LogicalResult LLVM::detail::oneToOneRewrite( SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create( - op->getLoc(), newOp->getResult(0), i)); + results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(), + newOp->getResult(0), i)); } rewriter.replaceOp(op, results); return success(); @@ -371,8 +371,8 @@ LogicalResult LLVM::detail::intrinsicRewrite( if (numResults != 0) resType = typeConverter.packOperationResults(op->getResultTypes()); - auto callIntrOp = rewriter.create( - loc, resType, rewriter.getStringAttr(intrinsic), operands); + auto callIntrOp = LLVM::CallIntrinsicOp::create( + rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands); // Propagate attributes. callIntrOp->setAttrs(op->getAttrDictionary()); @@ -388,7 +388,7 @@ LogicalResult LLVM::detail::intrinsicRewrite( results.reserve(numResults); Value intrRes = callIntrOp.getResults(); for (unsigned i = 0; i < numResults; ++i) - results.push_back(rewriter.create(loc, intrRes, i)); + results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i)); rewriter.replaceOp(op, results); return success(); @@ -406,7 +406,7 @@ static unsigned getBitWidth(Type type) { static Value createI32Constant(OpBuilder &builder, Location loc, int32_t value) { Type i32 = builder.getI32Type(); - return builder.create(loc, i32, value); + return LLVM::ConstantOp::create(builder, loc, i32, value); } SmallVector mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, @@ -418,17 +418,17 @@ SmallVector mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, unsigned srcBitWidth = getBitWidth(srcType); unsigned dstBitWidth = getBitWidth(dstType); if (srcBitWidth == dstBitWidth) { - Value cast = builder.create(loc, dstType, src); + Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src); return {cast}; } if (dstBitWidth > srcBitWidth) { auto smallerInt = builder.getIntegerType(srcBitWidth); if (srcType != smallerInt) - src = builder.create(loc, smallerInt, src); + src = LLVM::BitcastOp::create(builder, loc, smallerInt, src); auto largerInt = builder.getIntegerType(dstBitWidth); - Value res = builder.create(loc, largerInt, src); + Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src); return {res}; } assert(srcBitWidth % dstBitWidth == 0 && @@ -436,12 +436,12 @@ SmallVector mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, int64_t numElements = srcBitWidth / dstBitWidth; auto vecType = VectorType::get(numElements, dstType); - src = builder.create(loc, vecType, src); + src = LLVM::BitcastOp::create(builder, loc, vecType, src); SmallVector res; for (auto i : llvm::seq(numElements)) { Value idx = createI32Constant(builder, loc, i); - Value elem = builder.create(loc, src, idx); + Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx); res.emplace_back(elem); } @@ -461,28 +461,28 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, if (dstBitWidth < srcBitWidth) { auto largerInt = builder.getIntegerType(srcBitWidth); if (res.getType() != largerInt) - res = builder.create(loc, largerInt, res); + res = LLVM::BitcastOp::create(builder, loc, largerInt, res); auto smallerInt = builder.getIntegerType(dstBitWidth); - res = builder.create(loc, smallerInt, res); + res = LLVM::TruncOp::create(builder, loc, smallerInt, res); } if (res.getType() != dstType) - res = builder.create(loc, dstType, res); + res = LLVM::BitcastOp::create(builder, loc, dstType, res); return res; } int64_t numElements = src.size(); auto srcType = VectorType::get(numElements, src.front().getType()); - Value res = builder.create(loc, srcType); + Value res = LLVM::PoisonOp::create(builder, loc, srcType); for (auto &&[i, elem] : llvm::enumerate(src)) { Value idx = createI32Constant(builder, loc, i); - res = builder.create(loc, srcType, res, elem, idx); + res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx); } if (res.getType() != dstType) - res = builder.create(loc, dstType, res); + res = LLVM::BitcastOp::create(builder, loc, dstType, res); return res; } @@ -518,20 +518,20 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc, Value stride = ShapedType::isDynamic(strides[i]) ? memRefDescriptor.stride(builder, loc, i) - : builder.create( - loc, indexType, builder.getIndexAttr(strides[i])); - increment = - builder.create(loc, increment, stride, intOverflowFlags); + : LLVM::ConstantOp::create(builder, loc, indexType, + builder.getIndexAttr(strides[i])); + increment = LLVM::MulOp::create(builder, loc, increment, stride, + intOverflowFlags); } - index = index ? builder.create(loc, index, increment, - intOverflowFlags) + index = index ? LLVM::AddOp::create(builder, loc, index, increment, + intOverflowFlags) : increment; } Type elementPtrType = memRefDescriptor.getElementPtrType(); - return index ? builder.create( - loc, elementPtrType, - converter.convertType(type.getElementType()), base, index, - noWrapFlags) - : base; + return index + ? LLVM::GEPOp::create(builder, loc, elementPtrType, + converter.convertType(type.getElementType()), + base, index, noWrapFlags) + : base; } diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 49c73fbc9dd79..d95aeba8a4488 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -66,23 +66,23 @@ LogicalResult mlir::LLVM::createPrintStrCall( DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); auto arrayTy = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); - auto globalOp = builder.create( - loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, + auto globalOp = LLVM::GlobalOp::create( + builder, loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr); auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); // Emit call to `printStr` in runtime library. builder.restoreInsertionPoint(ip); auto msgAddr = - builder.create(loc, ptrTy, globalOp.getName()); + LLVM::AddressOfOp::create(builder, loc, ptrTy, globalOp.getName()); SmallVector indices(1, 0); Value gep = - builder.create(loc, ptrTy, arrayTy, msgAddr, indices); + LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, msgAddr, indices); FailureOr printer = LLVM::lookupOrCreatePrintStringFn(builder, moduleOp, runtimeFunctionName); if (failed(printer)) return failure(); - builder.create(loc, TypeRange(), - SymbolRefAttr::get(printer.value()), gep); + LLVM::CallOp::create(builder, loc, TypeRange(), + SymbolRefAttr::get(printer.value()), gep); return success(); } diff --git a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp index 1cd0bd85f9894..13ed4628c3c9e 100644 --- a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp @@ -24,10 +24,10 @@ StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) { Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, unsigned pos) const { - return builder.create(loc, value, pos); + return LLVM::ExtractValueOp::create(builder, loc, value, pos); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr) { - value = builder.create(loc, value, ptr, pos); + value = LLVM::InsertValueOp::create(builder, loc, value, ptr, pos); } diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 7312594c761f7..1a9bf569086da 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -91,7 +91,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder, packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter); if (!packed) return Value(); - return builder.create(loc, resultType, packed) + return UnrealizedConversionCastOp::create(builder, loc, resultType, packed) .getResult(0); } @@ -107,7 +107,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder, packRankedMemRefDesc(builder, resultType, inputs, loc, converter); if (!packed) return Value(); - return builder.create(loc, resultType, packed) + return UnrealizedConversionCastOp::create(builder, loc, resultType, packed) .getResult(0); } @@ -224,12 +224,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, // non-LLVM types persist after an LLVM conversion. addSourceMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }); @@ -731,12 +731,12 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); - Value one = builder.create(loc, builder.getI64Type(), - builder.getIndexAttr(1)); + Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getIndexAttr(1)); Value allocated = - builder.create(loc, ptrType, operand.getType(), one); + LLVM::AllocaOp::create(builder, loc, ptrType, operand.getType(), one); // Store into the alloca'ed descriptor. - builder.create(loc, operand, allocated); + LLVM::StoreOp::create(builder, loc, operand, allocated); return allocated; } diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index bf3f31729c3da..e7dd0b506e12d 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -87,17 +87,17 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; auto loc = op->getLoc(); - Value desc = rewriter.create(loc, resultNDVectoryTy); + Value desc = LLVM::PoisonOp::create(rewriter, loc, resultNDVectoryTy); nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors SmallVector extractedOperands; for (const auto &operand : llvm::enumerate(operands)) { - extractedOperands.push_back(rewriter.create( - loc, operand.value(), position)); + extractedOperands.push_back(LLVM::ExtractValueOp::create( + rewriter, loc, operand.value(), position)); } Value newVal = createOperand(result1DVectorTy, extractedOperands); - desc = rewriter.create(loc, desc, newVal, position); + desc = LLVM::InsertValueOp::create(rewriter, loc, desc, newVal, position); }); rewriter.replaceOp(op, desc); return success(); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp index c3f213147b7a7..3f4b4d6cbc8ab 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -78,8 +78,8 @@ getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) { // Insert before module terminator. rewriter.setInsertionPoint(module.getBody(), std::prev(module.getBody()->end())); - func::FuncOp funcOp = rewriter.create( - op->getLoc(), fnNameAttr.getValue(), libFnType); + func::FuncOp funcOp = func::FuncOp::create(rewriter, op->getLoc(), + fnNameAttr.getValue(), libFnType); // Insert a function attribute that will trigger the emission of the // corresponding `_mlir_ciface_xxx` interface so that external libraries see // a normalized ABI. This interface is added during std to llvm conversion. @@ -100,8 +100,8 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, res.push_back(op); continue; } - Value cast = - b.create(loc, makeStridedLayoutDynamic(memrefType), op); + Value cast = memref::CastOp::create( + b, loc, makeStridedLayoutDynamic(memrefType), op); res.push_back(cast); } return res; diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index d4deff5b88070..5b68eb8188996 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -54,18 +54,18 @@ std::pair getRawPtrAndSize(const Location loc, Value memRef, Type elType) { Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); Value dataPtr = - rewriter.create(loc, ptrType, memRef, 1); - Value offset = rewriter.create( - loc, rewriter.getI64Type(), memRef, 2); + LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1); + Value offset = LLVM::ExtractValueOp::create(rewriter, loc, + rewriter.getI64Type(), memRef, 2); Value resPtr = - rewriter.create(loc, ptrType, elType, dataPtr, offset); + LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset); Value size; if (cast(memRef.getType()).getBody().size() > 3) { - size = rewriter.create(loc, memRef, - ArrayRef{3, 0}); - size = rewriter.create(loc, rewriter.getI32Type(), size); + size = LLVM::ExtractValueOp::create(rewriter, loc, memRef, + ArrayRef{3, 0}); + size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size); } else { - size = rewriter.create(loc, 1, 32); + size = arith::ConstantIntOp::create(rewriter, loc, 1, 32); } return {resPtr, size}; } @@ -157,13 +157,13 @@ class MPICHImplTraits : public MPIImplTraits { Value getCommWorld(const Location loc, ConversionPatternRewriter &rewriter) override { static constexpr int MPI_COMM_WORLD = 0x44000000; - return rewriter.create(loc, rewriter.getI64Type(), - MPI_COMM_WORLD); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + MPI_COMM_WORLD); } Value castComm(const Location loc, ConversionPatternRewriter &rewriter, Value comm) override { - return rewriter.create(loc, rewriter.getI32Type(), comm); + return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm); } intptr_t getStatusIgnore() override { return 1; } @@ -195,7 +195,8 @@ class MPICHImplTraits : public MPIImplTraits { mtype = MPI_UINT8_T; else assert(false && "unsupported type"); - return rewriter.create(loc, rewriter.getI32Type(), mtype); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), + mtype); } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, @@ -245,7 +246,7 @@ class MPICHImplTraits : public MPIImplTraits { op = MPI_REPLACE; break; } - return rewriter.create(loc, rewriter.getI32Type(), op); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op); } }; @@ -281,16 +282,16 @@ class OMPIImplTraits : public MPIImplTraits { getOrDefineExternalStruct(loc, rewriter, name, commStructT); // get address of symbol - auto comm = rewriter.create( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, name)); - return rewriter.create(loc, rewriter.getI64Type(), comm); + auto comm = LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, name)); + return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm); } Value castComm(const Location loc, ConversionPatternRewriter &rewriter, Value comm) override { - return rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); + return LLVM::IntToPtrOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); } intptr_t getStatusIgnore() override { return 0; } @@ -330,9 +331,9 @@ class OMPIImplTraits : public MPIImplTraits { // make sure global op definition exists getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT); // get address of symbol - return rewriter.create( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, mtype)); + return LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, mtype)); } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, @@ -389,9 +390,9 @@ class OMPIImplTraits : public MPIImplTraits { // make sure global op definition exists getOrDefineExternalStruct(loc, rewriter, op, opStructT); // get address of symbol - return rewriter.create( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, op)); + return LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, op)); } }; @@ -424,7 +425,7 @@ struct InitOpLowering : public ConvertOpToLLVMPattern { Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr` - auto nullPtrOp = rewriter.create(loc, ptrType); + auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType); Value llvmnull = nullPtrOp.getRes(); // grab a reference to the global module op: @@ -513,9 +514,9 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern { // get communicator Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); - auto one = rewriter.create(loc, i32, 1); + auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1); auto outPtr = - rewriter.create(loc, ptrType, comm.getType(), one); + LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one); // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm) auto funcType = @@ -524,14 +525,14 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern { LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Comm_split", funcType); - auto callOp = rewriter.create( - loc, funcDecl, - ValueRange{comm, adaptor.getColor(), adaptor.getKey(), - outPtr.getRes()}); + auto callOp = + LLVM::CallOp::create(rewriter, loc, funcDecl, + ValueRange{comm, adaptor.getColor(), + adaptor.getKey(), outPtr.getRes()}); // load the communicator into a register - Value res = rewriter.create(loc, i32, outPtr.getResult()); - res = rewriter.create(loc, rewriter.getI64Type(), res); + Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult()); + res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res); // if retval is checked, replace uses of retval with the results from the // call op @@ -580,14 +581,14 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern { moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType); // replace with function call - auto one = rewriter.create(loc, i32, 1); - auto rankptr = rewriter.create(loc, ptrType, i32, one); - auto callOp = rewriter.create( - loc, initDecl, ValueRange{comm, rankptr.getRes()}); + auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1); + auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one); + auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl, + ValueRange{comm, rankptr.getRes()}); // load the rank into a register auto loadedRank = - rewriter.create(loc, i32, rankptr.getResult()); + LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult()); // if retval is checked, replace uses of retval with the results from the // call op @@ -641,10 +642,10 @@ struct SendOpLowering : public ConvertOpToLLVMPattern { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType); // replace op with function call - auto funcCall = rewriter.create( - loc, funcDecl, - ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(), - comm}); + auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl, + ValueRange{dataPtr, size, dataType, + adaptor.getDest(), + adaptor.getTag(), comm}); if (op.getRetval()) rewriter.replaceOp(op, funcCall.getResult()); else @@ -683,10 +684,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern { auto mpiTraits = MPIImplTraits::get(moduleOp); Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); - Value statusIgnore = rewriter.create( - loc, i64, mpiTraits->getStatusIgnore()); + Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64, + mpiTraits->getStatusIgnore()); statusIgnore = - rewriter.create(loc, ptrType, statusIgnore); + LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore); // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst, // tag, comm)` @@ -698,8 +699,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType); // replace op with function call - auto funcCall = rewriter.create( - loc, funcDecl, + auto funcCall = LLVM::CallOp::create( + rewriter, loc, funcDecl, ValueRange{dataPtr, size, dataType, adaptor.getSource(), adaptor.getTag(), comm, statusIgnore}); if (op.getRetval()) @@ -738,9 +739,10 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern { // If input and output are the same, request in-place operation. if (adaptor.getSendbuf() == adaptor.getRecvbuf()) { - sendPtr = rewriter.create( - loc, i64, reinterpret_cast(mpiTraits->getInPlace())); - sendPtr = rewriter.create(loc, ptrType, sendPtr); + sendPtr = LLVM::ConstantOp::create( + rewriter, loc, i64, + reinterpret_cast(mpiTraits->getInPlace())); + sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr); } Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); @@ -757,8 +759,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType); // replace op with function call - auto funcCall = rewriter.create( - loc, funcDecl, + auto funcCall = LLVM::CallOp::create( + rewriter, loc, funcDecl, ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld}); if (op.getRetval()) diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 7f4655e53609e..08a456691880c 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -121,19 +121,19 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { initValueAttr = FloatAttr::get(resultElementType, 0.0); else initValueAttr = IntegerAttr::get(resultElementType, 0); - Value result = rewriter.create( - loc, DenseElementsAttr::get(vecType, initValueAttr)); + Value result = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(vecType, initValueAttr)); SmallVector strides = computeStrides(shape); for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector positions = delinearize(linearIndex, strides); SmallVector operands; for (Value input : op->getOperands()) operands.push_back( - rewriter.create(loc, input, positions)); + vector::ExtractOp::create(rewriter, loc, input, positions)); Value scalarOp = - rewriter.create(loc, vecType.getElementType(), operands); + Op::create(rewriter, loc, vecType.getElementType(), operands); result = - rewriter.create(loc, scalarOp, result, positions); + vector::InsertOp::create(rewriter, loc, scalarOp, result, positions); } rewriter.replaceOp(op, result); return success(); @@ -195,7 +195,7 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { FunctionType funcType = FunctionType::get( builder.getContext(), {elementType, elementType}, elementType); - auto funcOp = builder.create(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; Attribute linkage = LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); @@ -208,12 +208,12 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { Value bArg = funcOp.getArgument(0); Value pArg = funcOp.getArgument(1); builder.setInsertionPointToEnd(entryBlock); - Value zeroValue = builder.create( - elementType, builder.getIntegerAttr(elementType, 0)); - Value oneValue = builder.create( - elementType, builder.getIntegerAttr(elementType, 1)); - Value minusOneValue = builder.create( - elementType, + Value zeroValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, 0)); + Value oneValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, 1)); + Value minusOneValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, APInt(elementType.getIntOrFloatBitWidth(), -1ULL, /*isSigned=*/true))); @@ -221,82 +221,83 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { // if (p == T(0)) // return T(1); auto pIsZero = - builder.create(arith::CmpIPredicate::eq, pArg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroValue); Block *thenBlock = builder.createBlock(funcBody); - builder.create(oneValue); + func::ReturnOp::create(builder, oneValue); Block *fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == T(0)). builder.setInsertionPointToEnd(pIsZero->getBlock()); - builder.create(pIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock); // if (p < T(0)) { builder.setInsertionPointToEnd(fallthroughBlock); - auto pIsNeg = - builder.create(arith::CmpIPredicate::sle, pArg, zeroValue); + auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg, + zeroValue); // if (b == T(0)) builder.createBlock(funcBody); auto bIsZero = - builder.create(arith::CmpIPredicate::eq, bArg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue); // return T(1) / T(0); thenBlock = builder.createBlock(funcBody); - builder.create( - builder.create(oneValue, zeroValue).getResult()); + func::ReturnOp::create( + builder, + arith::DivSIOp::create(builder, oneValue, zeroValue).getResult()); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(0)). builder.setInsertionPointToEnd(bIsZero->getBlock()); - builder.create(bIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock); // if (b == T(1)) builder.setInsertionPointToEnd(fallthroughBlock); auto bIsOne = - builder.create(arith::CmpIPredicate::eq, bArg, oneValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue); // return T(1); thenBlock = builder.createBlock(funcBody); - builder.create(oneValue); + func::ReturnOp::create(builder, oneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(1)). builder.setInsertionPointToEnd(bIsOne->getBlock()); - builder.create(bIsOne, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock); // if (b == T(-1)) { builder.setInsertionPointToEnd(fallthroughBlock); - auto bIsMinusOne = builder.create(arith::CmpIPredicate::eq, - bArg, minusOneValue); + auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, + bArg, minusOneValue); // if (p & T(1)) builder.createBlock(funcBody); - auto pIsOdd = builder.create( - arith::CmpIPredicate::ne, builder.create(pArg, oneValue), - zeroValue); + auto pIsOdd = arith::CmpIOp::create( + builder, arith::CmpIPredicate::ne, + arith::AndIOp::create(builder, pArg, oneValue), zeroValue); // return T(-1); thenBlock = builder.createBlock(funcBody); - builder.create(minusOneValue); + func::ReturnOp::create(builder, minusOneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p & T(1)). builder.setInsertionPointToEnd(pIsOdd->getBlock()); - builder.create(pIsOdd, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock); // return T(1); // } // b == T(-1) builder.setInsertionPointToEnd(fallthroughBlock); - builder.create(oneValue); + func::ReturnOp::create(builder, oneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(-1)). builder.setInsertionPointToEnd(bIsMinusOne->getBlock()); - builder.create(bIsMinusOne, pIsOdd->getBlock(), - fallthroughBlock); + cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->getBlock(), + fallthroughBlock); // return T(0); // } // (p < T(0)) builder.setInsertionPointToEnd(fallthroughBlock); - builder.create(zeroValue); + func::ReturnOp::create(builder, zeroValue); Block *loopHeader = builder.createBlock( funcBody, funcBody->end(), {elementType, elementType, elementType}, {builder.getLoc(), builder.getLoc(), builder.getLoc()}); // Set up conditional branch for (p < T(0)). builder.setInsertionPointToEnd(pIsNeg->getBlock()); // Set initial values of 'result', 'b' and 'p' for the loop. - builder.create(pIsNeg, bIsZero->getBlock(), loopHeader, - ValueRange{oneValue, bArg, pArg}); + cf::CondBranchOp::create(builder, pIsNeg, bIsZero->getBlock(), loopHeader, + ValueRange{oneValue, bArg, pArg}); // T result = T(1); // while (true) { @@ -313,45 +314,46 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { builder.setInsertionPointToEnd(loopHeader); // if (p & T(1)) - auto powerTmpIsOdd = builder.create( - arith::CmpIPredicate::ne, - builder.create(powerTmp, oneValue), zeroValue); + auto powerTmpIsOdd = arith::CmpIOp::create( + builder, arith::CmpIPredicate::ne, + arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue); thenBlock = builder.createBlock(funcBody); // result *= b; - Value newResultTmp = builder.create(resultTmp, baseTmp); + Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp); fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType, builder.getLoc()); builder.setInsertionPointToEnd(thenBlock); - builder.create(newResultTmp, fallthroughBlock); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); // Set up conditional branch for (p & T(1)). builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); - builder.create(powerTmpIsOdd, thenBlock, fallthroughBlock, - resultTmp); + cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock, + resultTmp); // Merged 'result'. newResultTmp = fallthroughBlock->getArgument(0); // p >>= T(1); builder.setInsertionPointToEnd(fallthroughBlock); - Value newPowerTmp = builder.create(powerTmp, oneValue); + Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue); // if (p == T(0)) - auto newPowerIsZero = builder.create(arith::CmpIPredicate::eq, - newPowerTmp, zeroValue); + auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, + newPowerTmp, zeroValue); // return result; thenBlock = builder.createBlock(funcBody); - builder.create(newResultTmp); + func::ReturnOp::create(builder, newResultTmp); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == T(0)). builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); - builder.create(newPowerIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock, + fallthroughBlock); // b *= b; // } builder.setInsertionPointToEnd(fallthroughBlock); - Value newBaseTmp = builder.create(baseTmp, baseTmp); + Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp); // Pass new values for 'result', 'b' and 'p' to the loop header. - builder.create( - ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); + cf::BranchOp::create( + builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); return funcOp; } @@ -420,7 +422,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, llvm::raw_string_ostream nameOS(funcName); nameOS << '_' << baseType; nameOS << '_' << powType; - auto funcOp = builder.create(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; Attribute linkage = LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); @@ -433,46 +435,48 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Value bArg = funcOp.getArgument(0); Value pArg = funcOp.getArgument(1); builder.setInsertionPointToEnd(entryBlock); - Value oneBValue = builder.create( - baseType, builder.getFloatAttr(baseType, 1.0)); - Value zeroPValue = builder.create( - powType, builder.getIntegerAttr(powType, 0)); - Value onePValue = builder.create( - powType, builder.getIntegerAttr(powType, 1)); - Value minPValue = builder.create( - powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue( - powType.getWidth()))); - Value maxPValue = builder.create( - powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue( - powType.getWidth()))); + Value oneBValue = arith::ConstantOp::create( + builder, baseType, builder.getFloatAttr(baseType, 1.0)); + Value zeroPValue = arith::ConstantOp::create( + builder, powType, builder.getIntegerAttr(powType, 0)); + Value onePValue = arith::ConstantOp::create( + builder, powType, builder.getIntegerAttr(powType, 1)); + Value minPValue = arith::ConstantOp::create( + builder, powType, + builder.getIntegerAttr( + powType, llvm::APInt::getSignedMinValue(powType.getWidth()))); + Value maxPValue = arith::ConstantOp::create( + builder, powType, + builder.getIntegerAttr( + powType, llvm::APInt::getSignedMaxValue(powType.getWidth()))); // if (p == Tp{0}) // return Tb{1}; - auto pIsZero = - builder.create(arith::CmpIPredicate::eq, pArg, zeroPValue); + auto pIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, + zeroPValue); Block *thenBlock = builder.createBlock(funcBody); - builder.create(oneBValue); + func::ReturnOp::create(builder, oneBValue); Block *fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == Tp{0}). builder.setInsertionPointToEnd(pIsZero->getBlock()); - builder.create(pIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock); builder.setInsertionPointToEnd(fallthroughBlock); // bool isNegativePower{p < Tp{0}} - auto pIsNeg = builder.create(arith::CmpIPredicate::sle, pArg, - zeroPValue); + auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg, + zeroPValue); // bool isMin{p == std::numeric_limits::min()}; auto pIsMin = - builder.create(arith::CmpIPredicate::eq, pArg, minPValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue); // if (isMin) { // p = std::numeric_limits::max(); // } else if (isNegativePower) { // p = -p; // } - Value negP = builder.create(zeroPValue, pArg); - auto pInit = builder.create(pIsNeg, negP, pArg); - pInit = builder.create(pIsMin, maxPValue, pInit); + Value negP = arith::SubIOp::create(builder, zeroPValue, pArg); + auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg); + pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit); // Tb result = Tb{1}; // Tb origBase = Tb{b}; @@ -489,7 +493,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, {builder.getLoc(), builder.getLoc(), builder.getLoc()}); // Set initial values of 'result', 'b' and 'p' for the loop. builder.setInsertionPointToEnd(pInit->getBlock()); - builder.create(loopHeader, ValueRange{oneBValue, bArg, pInit}); + cf::BranchOp::create(builder, loopHeader, ValueRange{oneBValue, bArg, pInit}); // Create loop body. Value resultTmp = loopHeader->getArgument(0); @@ -498,30 +502,30 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, builder.setInsertionPointToEnd(loopHeader); // if (p & Tp{1}) - auto powerTmpIsOdd = builder.create( - arith::CmpIPredicate::ne, - builder.create(powerTmp, onePValue), zeroPValue); + auto powerTmpIsOdd = arith::CmpIOp::create( + builder, arith::CmpIPredicate::ne, + arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue); thenBlock = builder.createBlock(funcBody); // result *= b; - Value newResultTmp = builder.create(resultTmp, baseTmp); + Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp); fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(thenBlock); - builder.create(newResultTmp, fallthroughBlock); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); // Set up conditional branch for (p & Tp{1}). builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); - builder.create(powerTmpIsOdd, thenBlock, fallthroughBlock, - resultTmp); + cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock, + resultTmp); // Merged 'result'. newResultTmp = fallthroughBlock->getArgument(0); // p >>= Tp{1}; builder.setInsertionPointToEnd(fallthroughBlock); - Value newPowerTmp = builder.create(powerTmp, onePValue); + Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue); // if (p == Tp{0}) - auto newPowerIsZero = builder.create(arith::CmpIPredicate::eq, - newPowerTmp, zeroPValue); + auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, + newPowerTmp, zeroPValue); // break; // // The conditional branch is finalized below with a jump to @@ -531,10 +535,10 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, // b *= b; // } builder.setInsertionPointToEnd(fallthroughBlock); - Value newBaseTmp = builder.create(baseTmp, baseTmp); + Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp); // Pass new values for 'result', 'b' and 'p' to the loop header. - builder.create( - ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); + cf::BranchOp::create( + builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); // Set up conditional branch for early loop exit: // if (p == Tp{0}) @@ -542,8 +546,8 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); - builder.create(newPowerIsZero, loopExit, newResultTmp, - fallthroughBlock, ValueRange{}); + cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp, + fallthroughBlock, ValueRange{}); // if (isMin) { // result *= origBase; @@ -553,11 +557,11 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(loopExit); - builder.create(pIsMin, thenBlock, fallthroughBlock, - newResultTmp); + cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock, + newResultTmp); builder.setInsertionPointToEnd(thenBlock); - newResultTmp = builder.create(newResultTmp, bArg); - builder.create(newResultTmp, fallthroughBlock); + newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); /// if (isNegativePower) { /// result = Tb{1} / result; @@ -567,15 +571,15 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(fallthroughBlock); - builder.create(pIsNeg, thenBlock, returnBlock, - newResultTmp); + cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock, + newResultTmp); builder.setInsertionPointToEnd(thenBlock); - newResultTmp = builder.create(oneBValue, newResultTmp); - builder.create(newResultTmp, returnBlock); + newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp); + cf::BranchOp::create(builder, newResultTmp, returnBlock); // return result; builder.setInsertionPointToEnd(returnBlock); - builder.create(returnBlock->getArgument(0)); + func::ReturnOp::create(builder, returnBlock->getArgument(0)); return funcOp; } @@ -667,7 +671,7 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { nameOS << '_' << elementType; FunctionType funcType = FunctionType::get(builder.getContext(), {elementType}, elementType); - auto funcOp = builder.create(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); // LinkonceODR ensures that there is only one implementation of this function // across all math.ctlz functions that are lowered in this way. @@ -683,33 +687,34 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { Value arg = funcOp.getArgument(0); Type indexType = builder.getIndexType(); - Value bitWidthValue = builder.create( - elementType, builder.getIntegerAttr(elementType, bitWidth)); - Value zeroValue = builder.create( - elementType, builder.getIntegerAttr(elementType, 0)); + Value bitWidthValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, bitWidth)); + Value zeroValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, 0)); Value inputEqZero = - builder.create(arith::CmpIPredicate::eq, arg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue); // if input == 0, return bit width, else enter loop. - scf::IfOp ifOp = builder.create( - elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); + scf::IfOp ifOp = + scf::IfOp::create(builder, elementType, inputEqZero, + /*addThenBlock=*/true, /*addElseBlock=*/true); ifOp.getThenBodyBuilder().create(loc, bitWidthValue); auto elseBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); - Value oneIndex = elseBuilder.create( - indexType, elseBuilder.getIndexAttr(1)); - Value oneValue = elseBuilder.create( - elementType, elseBuilder.getIntegerAttr(elementType, 1)); - Value bitWidthIndex = elseBuilder.create( - indexType, elseBuilder.getIndexAttr(bitWidth)); - Value nValue = elseBuilder.create( - elementType, elseBuilder.getIntegerAttr(elementType, 0)); - - auto loop = elseBuilder.create( - oneIndex, bitWidthIndex, oneIndex, + Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType, + elseBuilder.getIndexAttr(1)); + Value oneValue = arith::ConstantOp::create( + elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1)); + Value bitWidthIndex = arith::ConstantOp::create( + elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth)); + Value nValue = arith::ConstantOp::create( + elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0)); + + auto loop = scf::ForOp::create( + elseBuilder, oneIndex, bitWidthIndex, oneIndex, // Initial values for two loop induction variables, the arg which is being // shifted left in each iteration, and the n value which tracks the count // of leading zeros. @@ -725,25 +730,25 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { Value argIter = args[0]; Value nIter = args[1]; - Value argIsNonNegative = b.create( - loc, arith::CmpIPredicate::slt, argIter, zeroValue); - scf::IfOp ifOp = b.create( - loc, argIsNonNegative, + Value argIsNonNegative = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::slt, argIter, zeroValue); + scf::IfOp ifOp = scf::IfOp::create( + b, loc, argIsNonNegative, [&](OpBuilder &b, Location loc) { // If arg is negative, continue (effectively, break) - b.create(loc, ValueRange{argIter, nIter}); + scf::YieldOp::create(b, loc, ValueRange{argIter, nIter}); }, [&](OpBuilder &b, Location loc) { // Otherwise, increment n and shift arg left. - Value nNext = b.create(loc, nIter, oneValue); - Value argNext = b.create(loc, argIter, oneValue); - b.create(loc, ValueRange{argNext, nNext}); + Value nNext = arith::AddIOp::create(b, loc, nIter, oneValue); + Value argNext = arith::ShLIOp::create(b, loc, argIter, oneValue); + scf::YieldOp::create(b, loc, ValueRange{argNext, nNext}); }); - b.create(loc, ifOp.getResults()); + scf::YieldOp::create(b, loc, ifOp.getResults()); }); - elseBuilder.create(loop.getResult(1)); + scf::YieldOp::create(elseBuilder, loop.getResult(1)); - builder.create(ifOp.getResult(0)); + func::ReturnOp::create(builder, ifOp.getResult(0)); return funcOp; } diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index f4d69ce8235bb..853f45498ac52 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -107,8 +107,8 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern { return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), typeConverter, [&](Type llvm1DVectorTy, ValueRange operands) { - return rewriter.create(loc, llvm1DVectorTy, operands[0], - false); + return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0], + false); }, rewriter); } @@ -145,15 +145,16 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern { if (!isa(llvmOperandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(llvmOperandType)) { - one = rewriter.create( - loc, llvmOperandType, + one = LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast(llvmOperandType), floatOne)); } else { - one = rewriter.create(loc, llvmOperandType, floatOne); + one = + LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); } - auto exp = rewriter.create(loc, adaptor.getOperand(), - expAttrs.getAttrs()); + auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(), + expAttrs.getAttrs()); rewriter.replaceOpWithNewOp( op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs()); return success(); @@ -170,12 +171,13 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto exp = rewriter.create( - loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); - return rewriter.create( - loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy, + operands[0], expAttrs.getAttrs()); + return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{exp, one}, + subAttrs.getAttrs()); }, rewriter); } @@ -205,16 +207,16 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern { if (!isa(llvmOperandType)) { LLVM::ConstantOp one = isa(llvmOperandType) - ? rewriter.create( - loc, llvmOperandType, + ? LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast(llvmOperandType), floatOne)) - : rewriter.create(loc, llvmOperandType, - floatOne); + : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, + floatOne); - auto add = rewriter.create( - loc, llvmOperandType, ValueRange{one, adaptor.getOperand()}, - addAttrs.getAttrs()); + auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType, + ValueRange{one, adaptor.getOperand()}, + addAttrs.getAttrs()); rewriter.replaceOpWithNewOp( op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs()); return success(); @@ -231,13 +233,13 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto add = rewriter.create(loc, llvm1DVectorTy, - ValueRange{one, operands[0]}, - addAttrs.getAttrs()); - return rewriter.create( - loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{one, operands[0]}, + addAttrs.getAttrs()); + return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{add}, logAttrs.getAttrs()); }, rewriter); } @@ -267,15 +269,16 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { if (!isa(llvmOperandType)) { LLVM::ConstantOp one; if (isa(llvmOperandType)) { - one = rewriter.create( - loc, llvmOperandType, + one = LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast(llvmOperandType), floatOne)); } else { - one = rewriter.create(loc, llvmOperandType, floatOne); + one = + LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); } - auto sqrt = rewriter.create(loc, adaptor.getOperand(), - sqrtAttrs.getAttrs()); + auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(), + sqrtAttrs.getAttrs()); rewriter.replaceOpWithNewOp( op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); return success(); @@ -292,12 +295,13 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto sqrt = rewriter.create( - loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); - return rewriter.create( - loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy, + operands[0], sqrtAttrs.getAttrs()); + return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{one, sqrt}, + divAttrs.getAttrs()); }, rewriter); } diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index a0ce7d3b75fc2..f7c0d4fe3a799 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -84,20 +84,21 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto shape = vecType.getShape(); int64_t numElements = vecType.getNumElements(); - Value result = rewriter.create( - loc, DenseElementsAttr::get( - vecType, FloatAttr::get(vecType.getElementType(), 0.0))); + Value result = arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(vecType, + FloatAttr::get(vecType.getElementType(), 0.0))); SmallVector strides = computeStrides(shape); for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector positions = delinearize(linearIndex, strides); SmallVector operands; for (auto input : op->getOperands()) operands.push_back( - rewriter.create(loc, input, positions)); + vector::ExtractOp::create(rewriter, loc, input, positions)); Value scalarOp = - rewriter.create(loc, vecType.getElementType(), operands); + Op::create(rewriter, loc, vecType.getElementType(), operands); result = - rewriter.create(loc, scalarOp, result, positions); + vector::InsertOp::create(rewriter, loc, scalarOp, result, positions); } rewriter.replaceOp(op, {result}); return success(); @@ -114,9 +115,9 @@ PromoteOpToF32::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto f32 = rewriter.getF32Type(); auto extendedOperands = llvm::to_vector( llvm::map_range(op->getOperands(), [&](Value operand) -> Value { - return rewriter.create(loc, f32, operand); + return arith::ExtFOp::create(rewriter, loc, f32, operand); })); - auto newOp = rewriter.create(loc, f32, extendedOperands); + auto newOp = Op::create(rewriter, loc, f32, extendedOperands); rewriter.replaceOpWithNewOp(op, opType, newOp); return success(); } @@ -139,8 +140,8 @@ ScalarOpToLibmCall::matchAndRewrite(Op op, rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create(rewriter.getUnknownLoc(), name, - opFunctionTy); + opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name, + opFunctionTy); opFunc.setPrivate(); // By definition Math dialect operations imply LLVM's "readnone" diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 59db14ed816be..a877ad21734a2 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -36,12 +36,12 @@ static Value getScalarOrVectorI32Constant(Type type, int value, if (!vectorType.getElementType().isInteger(32)) return nullptr; SmallVector values(vectorType.getNumElements(), value); - return builder.create(loc, type, - builder.getI32VectorAttr(values)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getI32VectorAttr(values)); } if (type.isInteger(32)) - return builder.create(loc, type, - builder.getI32IntegerAttr(value)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getI32IntegerAttr(value)); return nullptr; } @@ -144,10 +144,11 @@ struct CopySignPattern final : public OpConversionPattern { Type intType = rewriter.getIntegerType(bitwidth); uint64_t intValue = uint64_t(1) << (bitwidth - 1); - Value signMask = rewriter.create( - loc, intType, rewriter.getIntegerAttr(intType, intValue)); - Value valueMask = rewriter.create( - loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); + Value signMask = spirv::ConstantOp::create( + rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue)); + Value valueMask = spirv::ConstantOp::create( + rewriter, loc, intType, + rewriter.getIntegerAttr(intType, intValue - 1u)); if (auto vectorType = dyn_cast(type)) { assert(vectorType.getRank() == 1); @@ -155,26 +156,26 @@ struct CopySignPattern final : public OpConversionPattern { intType = VectorType::get(count, intType); SmallVector signSplat(count, signMask); - signMask = - rewriter.create(loc, intType, signSplat); + signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, + signSplat); SmallVector valueSplat(count, valueMask); - valueMask = rewriter.create(loc, intType, - valueSplat); + valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, + valueSplat); } Value lhsCast = - rewriter.create(loc, intType, adaptor.getLhs()); + spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs()); Value rhsCast = - rewriter.create(loc, intType, adaptor.getRhs()); + spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs()); - Value value = rewriter.create( - loc, intType, ValueRange{lhsCast, valueMask}); - Value sign = rewriter.create( - loc, intType, ValueRange{rhsCast, signMask}); + Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType, + ValueRange{lhsCast, valueMask}); + Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType, + ValueRange{rhsCast, signMask}); - Value result = rewriter.create(loc, intType, - ValueRange{value, sign}); + Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType, + ValueRange{value, sign}); rewriter.replaceOpWithNewOp(copySignOp, type, result); return success(); } @@ -214,18 +215,18 @@ struct CountLeadingZerosPattern final Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); - Value msb = rewriter.create(loc, input); + Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input); // We need to subtract from 31 given that the index returned by GLSL // FindUMsb is counted from the least significant bit. Theoretically this // also gives the correct result even if the integer has all zero bits, in // which case GL FindUMsb would return -1. - Value subMsb = rewriter.create(loc, val31, msb); + Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb); // However, certain Vulkan implementations have driver bugs for the corner // case where the input is zero. And.. it can be smart to optimize a select // only involving the corner case. So separately compute the result when the // input is either zero or one. - Value subInput = rewriter.create(loc, val32, input); - Value cmp = rewriter.create(loc, input, val1); + Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input); + Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1); rewriter.replaceOpWithNewOp(countOp, cmp, subInput, subMsb); return success(); @@ -253,7 +254,7 @@ struct ExpM1OpPattern final : public OpConversionPattern { if (!type) return failure(); - Value exp = rewriter.create(loc, type, adaptor.getOperand()); + Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand()); auto one = spirv::ConstantOp::getOne(type, loc, rewriter); rewriter.replaceOpWithNewOp(operation, exp, one); return success(); @@ -283,7 +284,7 @@ struct Log1pOpPattern final : public OpConversionPattern { auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); Value onePlus = - rewriter.create(loc, one, adaptor.getOperand()); + spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } @@ -321,15 +322,15 @@ struct Log2Log10OpPattern final : public OpConversionPattern { auto getConstantValue = [&](double value) { if (auto floatType = dyn_cast(type)) { - return rewriter.create( - loc, type, rewriter.getFloatAttr(floatType, value)); + return spirv::ConstantOp::create( + rewriter, loc, type, rewriter.getFloatAttr(floatType, value)); } if (auto vectorType = dyn_cast(type)) { Type elemType = vectorType.getElementType(); if (isa(elemType)) { - return rewriter.create( - loc, type, + return spirv::ConstantOp::create( + rewriter, loc, type, DenseFPElementsAttr::get( vectorType, FloatAttr::get(elemType, value).getValue())); } @@ -341,7 +342,7 @@ struct Log2Log10OpPattern final : public OpConversionPattern { Value constantValue = getConstantValue( std::is_same() ? log2Reciprocal : log10Reciprocal); - Value log = rewriter.create(loc, adaptor.getOperand()); + Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand()); rewriter.replaceOpWithNewOp(operation, type, log, constantValue); return success(); @@ -386,7 +387,7 @@ struct PowFOpPattern final : public OpConversionPattern { Location loc = powfOp.getLoc(); Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter); Value lessThan = - rewriter.create(loc, adaptor.getLhs(), zero); + spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero); // Per C/C++ spec: // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is @@ -394,11 +395,11 @@ struct PowFOpPattern final : public OpConversionPattern { // Calculate the reminder from the exponent and check whether it is zero. Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter); Value expRem = - rewriter.create(loc, adaptor.getRhs(), floatOne); + spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne); Value expRemNonZero = - rewriter.create(loc, expRem, zero); + spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero); Value cmpNegativeWithFractionalExp = - rewriter.create(loc, expRemNonZero, lessThan); + spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan); // Create NaN result and replace base value if conditions are met. const auto &floatSemantics = scalarFloatType.getFloatSemantics(); const auto nan = APFloat::getNaN(floatSemantics); @@ -407,10 +408,11 @@ struct PowFOpPattern final : public OpConversionPattern { nanAttr = DenseElementsAttr::get(vectorType, nan); Value NanValue = - rewriter.create(loc, operandType, nanAttr); - Value lhs = rewriter.create( - loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs()); - Value abs = rewriter.create(loc, lhs); + spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr); + Value lhs = + spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp, + NanValue, adaptor.getLhs()); + Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs); // TODO: The following just forcefully casts y into an integer value in // order to properly propagate the sign, assuming integer y cases. It @@ -418,18 +420,18 @@ struct PowFOpPattern final : public OpConversionPattern { // Cast exponent to integer and calculate exponent % 2 != 0. Value intRhs = - rewriter.create(loc, intType, adaptor.getRhs()); + spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs()); Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter); Value bitwiseAndOne = - rewriter.create(loc, intRhs, intOne); - Value isOdd = rewriter.create(loc, bitwiseAndOne, intOne); + spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne); + Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne); // calculate pow based on abs(lhs)^rhs. - Value pow = rewriter.create(loc, abs, adaptor.getRhs()); - Value negate = rewriter.create(loc, pow); + Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs()); + Value negate = spirv::FNegateOp::create(rewriter, loc, pow); // if the exponent is odd and lhs < 0, negate the result. Value shouldNegate = - rewriter.create(loc, lessThan, isOdd); + spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd); rewriter.replaceOpWithNewOp(powfOp, shouldNegate, negate, pow); return success(); @@ -455,22 +457,22 @@ struct RoundOpPattern final : public OpConversionPattern { auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); Value half; if (VectorType vty = dyn_cast(ty)) { - half = rewriter.create( - loc, vty, + half = spirv::ConstantOp::create( + rewriter, loc, vty, DenseElementsAttr::get(vty, rewriter.getFloatAttr(ety, 0.5).getValue())); } else { - half = rewriter.create( - loc, ty, rewriter.getFloatAttr(ety, 0.5)); + half = spirv::ConstantOp::create(rewriter, loc, ty, + rewriter.getFloatAttr(ety, 0.5)); } - auto abs = rewriter.create(loc, operand); - auto floor = rewriter.create(loc, abs); - auto sub = rewriter.create(loc, abs, floor); + auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand); + auto floor = spirv::GLFloorOp::create(rewriter, loc, abs); + auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor); auto greater = - rewriter.create(loc, sub, half); - auto select = rewriter.create(loc, greater, one, zero); - auto add = rewriter.create(loc, floor, select); + spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half); + auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero); + auto add = spirv::FAddOp::create(rewriter, loc, floor, select); rewriter.replaceOpWithNewOp(roundOp, add, operand); return success(); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 0b7ffa40ec09d..e882845d9d99a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -160,8 +160,8 @@ struct ConvertGetGlobal final if (opTy.getRank() == 0) { emitc::LValueType lvalueType = emitc::LValueType::get(resultTy); - emitc::GetGlobalOp globalLValue = rewriter.create( - op.getLoc(), lvalueType, operands.getNameAttr()); + emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create( + rewriter, op.getLoc(), lvalueType, operands.getNameAttr()); emitc::PointerType pointerType = emitc::PointerType::get(resultTy); rewriter.replaceOpWithNewOp( op, pointerType, rewriter.getStringAttr("&"), globalLValue); @@ -191,8 +191,8 @@ struct ConvertLoad final : public OpConversionPattern { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } - auto subscript = rewriter.create( - op.getLoc(), arrayValue, operands.getIndices()); + auto subscript = emitc::SubscriptOp::create( + rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp(op, resultTy, subscript); return success(); @@ -211,8 +211,8 @@ struct ConvertStore final : public OpConversionPattern { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } - auto subscript = rewriter.create( - op.getLoc(), arrayValue, operands.getIndices()); + auto subscript = emitc::SubscriptOp::create( + rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp(op, subscript, operands.getValue()); return success(); @@ -242,7 +242,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { if (inputs.size() != 1) return Value(); - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }; diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 83681b2d5fd87..53a19129103a3 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -87,12 +87,12 @@ getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, /// aligned = bumped - bumped % alignment static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment) { - Value one = rewriter.create(loc, alignment.getType(), - rewriter.getIndexAttr(1)); - Value bump = rewriter.create(loc, alignment, one); - Value bumped = rewriter.create(loc, input, bump); - Value mod = rewriter.create(loc, bumped, alignment); - return rewriter.create(loc, bumped, mod); + Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(), + rewriter.getIndexAttr(1)); + Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one); + Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump); + Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment); + return LLVM::SubOp::create(rewriter, loc, bumped, mod); } /// Computes the byte size for the MemRef element type. @@ -123,8 +123,9 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space"); unsigned memrefAddrSpace = *maybeMemrefAddrSpace; if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) - allocatedPtr = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), + allocatedPtr = LLVM::AddrSpaceCastOp::create( + rewriter, loc, + LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), allocatedPtr); return allocatedPtr; } @@ -168,14 +169,14 @@ class AllocOpLowering : public ConvertOpToLLVMPattern { Value alignment = getAlignment(rewriter, loc, op); if (alignment) { // Adjust the allocation size to consider alignment. - sizeBytes = rewriter.create(loc, sizeBytes, alignment); + sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment); } // Allocate the underlying buffer. Type elementPtrType = this->getElementPtrType(memRefType); assert(elementPtrType && "could not compute element ptr type"); auto results = - rewriter.create(loc, allocFuncOp.value(), sizeBytes); + LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes); Value allocatedPtr = castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, @@ -184,11 +185,11 @@ class AllocOpLowering : public ConvertOpToLLVMPattern { if (alignment) { // Compute the aligned pointer. Value allocatedInt = - rewriter.create(loc, getIndexType(), allocatedPtr); + LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr); Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); alignedPtr = - rewriter.create(loc, elementPtrType, alignmentInt); + LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt); } // Create the MemRef descriptor. @@ -268,8 +269,9 @@ class AlignedAllocOpLowering : public ConvertOpToLLVMPattern { sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); Type elementPtrType = this->getElementPtrType(memRefType); - auto results = rewriter.create( - loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); + auto results = + LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), + ValueRange({allocAlignment, sizeBytes})); Value ptr = castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, @@ -360,8 +362,9 @@ struct AllocaOpLowering : public ConvertOpToLLVMPattern { auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace); - auto allocatedElementPtr = rewriter.create( - loc, elementPtrType, elementType, size, op.getAlignment().value_or(0)); + auto allocatedElementPtr = + LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size, + op.getAlignment().value_or(0)); // Create the MemRef descriptor. auto memRefDescriptor = this->createMemRefDescriptor( @@ -397,7 +400,7 @@ struct AllocaScopeOpLowering remainingOpsBlock, allocaScopeOp.getResultTypes(), SmallVector(allocaScopeOp->getNumResults(), allocaScopeOp.getLoc())); - rewriter.create(loc, ValueRange(), remainingOpsBlock); + LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock); } // Inline body region. @@ -407,8 +410,8 @@ struct AllocaScopeOpLowering // Save stack and then branch into the body of the region. rewriter.setInsertionPointToEnd(currentBlock); - auto stackSaveOp = rewriter.create(loc, getPtrType()); - rewriter.create(loc, ValueRange(), beforeBody); + auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType()); + LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody); // Replace the alloca_scope return with a branch that jumps out of the body. // Stack restore before leaving the body region. @@ -420,7 +423,7 @@ struct AllocaScopeOpLowering // Insert stack restore before jumping out the body of the region. rewriter.setInsertionPoint(branchOp); - rewriter.create(loc, stackSaveOp); + LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp); // Replace the op with values return from the body region. rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); @@ -451,11 +454,11 @@ struct AssumeAlignmentOpLowering // This is more direct than ptrtoint-based checks, is explicitly supported, // and works with non-integral address spaces. Value trueCond = - rewriter.create(loc, rewriter.getBoolAttr(true)); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true)); Value alignmentConst = createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); - rewriter.create(loc, trueCond, LLVM::AssumeAlignTag(), ptr, - alignmentConst); + LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr, + alignmentConst); rewriter.replaceOp(op, memref); return success(); } @@ -559,18 +562,19 @@ struct DimOpLowering : public ConvertOpToLLVMPattern { // Get pointer to offset field of memref descriptor. auto indexPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); - Value offsetPtr = rewriter.create( - loc, indexPtrTy, elementType, underlyingRankedDesc, - ArrayRef{0, 2}); + Value offsetPtr = + LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType, + underlyingRankedDesc, ArrayRef{0, 2}); // The size value that we have to extract can be obtained using GEPop with // `dimOp.index() + 1` index argument. - Value idxPlusOne = rewriter.create( - loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1), + Value idxPlusOne = LLVM::AddOp::create( + rewriter, loc, + createIndexAttrConstant(rewriter, loc, getIndexType(), 1), adaptor.getIndex()); - Value sizePtr = rewriter.create( - loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, - idxPlusOne); + Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy, + getTypeConverter()->getIndexType(), + offsetPtr, idxPlusOne); return rewriter .create(loc, getTypeConverter()->getIndexType(), sizePtr) .getResult(); @@ -674,9 +678,10 @@ struct GenericAtomicRMWOpLowering auto memRefType = cast(atomicOp.getMemref().getType()); auto dataPtr = getStridedElementPtr( rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices()); - Value init = rewriter.create( - loc, typeConverter->convertType(memRefType.getElementType()), dataPtr); - rewriter.create(loc, init, loopBlock); + Value init = LLVM::LoadOp::create( + rewriter, loc, typeConverter->convertType(memRefType.getElementType()), + dataPtr); + LLVM::BrOp::create(rewriter, loc, init, loopBlock); // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); @@ -696,15 +701,16 @@ struct GenericAtomicRMWOpLowering // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; - auto cmpxchg = rewriter.create( - loc, dataPtr, loopArgument, result, successOrdering, failureOrdering); + auto cmpxchg = + LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument, + result, successOrdering, failureOrdering); // Extract the %new_loaded and %ok values from the pair. - Value newLoaded = rewriter.create(loc, cmpxchg, 0); - Value ok = rewriter.create(loc, cmpxchg, 1); + Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0); + Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1); // Conditionally branch to the end or back to the loop depending on %ok. - rewriter.create(loc, ok, endBlock, ArrayRef(), - loopBlock, newLoaded); + LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef(), + loopBlock, newLoaded); rewriter.setInsertionPointToEnd(endBlock); @@ -796,8 +802,8 @@ class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern { if (!isExternal && isUninitialized) { rewriter.createBlock(&newGlobal.getInitializerRegion()); Value undef[] = { - rewriter.create(newGlobal.getLoc(), arrayTy)}; - rewriter.create(newGlobal.getLoc(), undef); + LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)}; + LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef); } return success(); } @@ -842,13 +848,13 @@ struct GetGlobalMemrefOpLowering Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace); auto addressOf = - rewriter.create(loc, ptrTy, op.getName()); + LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. - auto gep = rewriter.create( - loc, ptrTy, arrayTy, addressOf, - SmallVector(type.getRank() + 1, 0)); + auto gep = + LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf, + SmallVector(type.getRank() + 1, 0)); // We do not expect the memref obtained using `memref.get_global` to be // ever deallocated. Set the allocated pointer to be known bad value to @@ -857,7 +863,7 @@ struct GetGlobalMemrefOpLowering Value deadBeefConst = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef); auto deadBeefPtr = - rewriter.create(loc, ptrTy, deadBeefConst); + LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst); // Both allocated and aligned pointers are same. We could potentially stash // a nullptr for the allocated pointer since we do not expect any dealloc. @@ -1009,8 +1015,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { loc, adaptor.getSource(), rewriter); // rank = ConstantOp srcRank - auto rankVal = rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(rank)); + auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(rank)); // poison = PoisonOp UnrankedMemRefDescriptor memRefDesc = UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType); @@ -1029,7 +1035,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); // struct = LoadOp ptr - auto loadOp = rewriter.create(loc, targetStructType, ptr); + auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr); rewriter.replaceOp(memRefCastOp, loadOp.getResult()); } else { llvm_unreachable("Unsupported unranked memref to unranked memref cast"); @@ -1063,32 +1069,33 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { MemRefDescriptor srcDesc(adaptor.getSource()); // Compute number of elements. - Value numElements = rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(1)); + Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1)); for (int pos = 0; pos < srcType.getRank(); ++pos) { auto size = srcDesc.size(rewriter, loc, pos); - numElements = rewriter.create(loc, numElements, size); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, size); } // Get element size. auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter); // Compute total. Value totalSize = - rewriter.create(loc, numElements, sizeInBytes); + LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes); Type elementType = typeConverter->convertType(srcType.getElementType()); Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc); Value srcOffset = srcDesc.offset(rewriter, loc); - Value srcPtr = rewriter.create( - loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset); + Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(), + elementType, srcBasePtr, srcOffset); MemRefDescriptor targetDesc(adaptor.getTarget()); Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc); Value targetOffset = targetDesc.offset(rewriter, loc); - Value targetPtr = rewriter.create( - loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset); - rewriter.create(loc, targetPtr, srcPtr, totalSize, - /*isVolatile=*/false); + Value targetPtr = + LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType, + targetBasePtr, targetOffset); + LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize, + /*isVolatile=*/false); rewriter.eraseOp(op); return success(); @@ -1103,8 +1110,8 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { // First make sure we have an unranked memref descriptor representation. auto makeUnranked = [&, this](Value ranked, MemRefType type) { - auto rank = rewriter.create(loc, getIndexType(), - type.getRank()); + auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + type.getRank()); auto *typeConverter = getTypeConverter(); auto ptr = typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); @@ -1116,7 +1123,7 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { }; // Save stack position before promoting descriptors - auto stackSaveOp = rewriter.create(loc, getPtrType()); + auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType()); auto srcMemRefType = dyn_cast(srcType); Value unrankedSource = @@ -1128,13 +1135,13 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { : adaptor.getTarget(); // Now promote the unranked descriptors to the stack. - auto one = rewriter.create(loc, getIndexType(), - rewriter.getIndexAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1)); auto promote = [&](Value desc) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); auto allocated = - rewriter.create(loc, ptrType, desc.getType(), one); - rewriter.create(loc, desc, allocated); + LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one); + LLVM::StoreOp::create(rewriter, loc, desc, allocated); return allocated; }; @@ -1149,11 +1156,11 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { sourcePtr.getType(), symbolTables); if (failed(copyFn)) return failure(); - rewriter.create(loc, copyFn.value(), - ValueRange{elemSize, sourcePtr, targetPtr}); + LLVM::CallOp::create(rewriter, loc, copyFn.value(), + ValueRange{elemSize, sourcePtr, targetPtr}); // Restore stack used for descriptors - rewriter.create(loc, stackSaveOp); + LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp); rewriter.eraseOp(op); @@ -1204,9 +1211,9 @@ struct MemorySpaceCastOpLowering MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR, descVals); descVals[0] = - rewriter.create(loc, newPtrType, descVals[0]); + LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]); descVals[1] = - rewriter.create(loc, newPtrType, descVals[1]); + LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]); Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(), resultTypeR, descVals); rewriter.replaceOp(op, result); @@ -1241,8 +1248,9 @@ struct MemorySpaceCastOpLowering UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), result, resultAddrSpace, sizes); Value resultUnderlyingSize = sizes.front(); - Value resultUnderlyingDesc = rewriter.create( - loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize); + Value resultUnderlyingDesc = + LLVM::AllocaOp::create(rewriter, loc, getPtrType(), + rewriter.getI8Type(), resultUnderlyingSize); result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc); // Copy pointers, performing address space casts. @@ -1256,10 +1264,10 @@ struct MemorySpaceCastOpLowering Value alignedPtr = sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(), sourceUnderlyingDesc, sourceElemPtrType); - allocatedPtr = rewriter.create( - loc, resultElemPtrType, allocatedPtr); - alignedPtr = rewriter.create( - loc, resultElemPtrType, alignedPtr); + allocatedPtr = LLVM::AddrSpaceCastOp::create( + rewriter, loc, resultElemPtrType, allocatedPtr); + alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc, + resultElemPtrType, alignedPtr); result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc, resultElemPtrType, allocatedPtr); @@ -1277,12 +1285,13 @@ struct MemorySpaceCastOpLowering int64_t bytesToSkip = 2 * llvm::divideCeil( getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8); - Value bytesToSkipConst = rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip)); - Value copySize = rewriter.create( - loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst); - rewriter.create(loc, resultIndexVals, sourceIndexVals, - copySize, /*isVolatile=*/false); + Value bytesToSkipConst = LLVM::ConstantOp::create( + rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip)); + Value copySize = + LLVM::SubOp::create(rewriter, loc, getIndexType(), + resultUnderlyingSize, bytesToSkipConst); + LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals, + copySize, /*isVolatile=*/false); rewriter.replaceOp(op, ValueRange{result}); return success(); @@ -1485,7 +1494,7 @@ struct MemRefReshapeOpLowering } else { Value shapeOp = reshapeOp.getShape(); Value index = createIndexAttrConstant(rewriter, loc, indexType, i); - dimSize = rewriter.create(loc, shapeOp, index); + dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index); Type indexType = getIndexType(); if (dimSize.getType() != indexType) dimSize = typeConverter->materializeTargetConversion( @@ -1497,7 +1506,7 @@ struct MemRefReshapeOpLowering desc.setStride(rewriter, loc, i, stride); // Prepare the stride value for the next dimension. - stride = rewriter.create(loc, stride, dimSize); + stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize); } *descriptor = desc; @@ -1522,8 +1531,9 @@ struct MemRefReshapeOpLowering SmallVector sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), targetDesc, addressSpace, sizes); - Value underlyingDescPtr = rewriter.create( - loc, getPtrType(), IntegerType::get(getContext(), 8), sizes.front()); + Value underlyingDescPtr = LLVM::AllocaOp::create( + rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8), + sizes.front()); targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); // Extract pointers and offset from the source memref. @@ -1554,7 +1564,7 @@ struct MemRefReshapeOpLowering Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1); Value resultRankMinusOne = - rewriter.create(loc, resultRank, oneIndex); + LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex); Block *initBlock = rewriter.getInsertionBlock(); Type indexType = getTypeConverter()->getIndexType(); @@ -1568,15 +1578,15 @@ struct MemRefReshapeOpLowering rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); rewriter.setInsertionPointToEnd(initBlock); - rewriter.create(loc, ValueRange({resultRankMinusOne, oneIndex}), - condBlock); + LLVM::BrOp::create(rewriter, loc, + ValueRange({resultRankMinusOne, oneIndex}), condBlock); rewriter.setInsertionPointToStart(condBlock); Value indexArg = condBlock->getArgument(0); Value strideArg = condBlock->getArgument(1); Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0); - Value pred = rewriter.create( - loc, IntegerType::get(rewriter.getContext(), 1), + Value pred = LLVM::ICmpOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 1), LLVM::ICmpPredicate::sge, indexArg, zeroIndex); Block *bodyBlock = @@ -1585,31 +1595,31 @@ struct MemRefReshapeOpLowering // Copy size from shape to descriptor. auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value sizeLoadGep = rewriter.create( - loc, llvmIndexPtrType, + Value sizeLoadGep = LLVM::GEPOp::create( + rewriter, loc, llvmIndexPtrType, typeConverter->convertType(shapeMemRefType.getElementType()), shapeOperandPtr, indexArg); - Value size = rewriter.create(loc, indexType, sizeLoadGep); + Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep); UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), targetSizesBase, indexArg, size); // Write stride value and compute next one. UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), targetStridesBase, indexArg, strideArg); - Value nextStride = rewriter.create(loc, strideArg, size); + Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size); // Decrement loop counter and branch back. - Value decrement = rewriter.create(loc, indexArg, oneIndex); - rewriter.create(loc, ValueRange({decrement, nextStride}), - condBlock); + Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex); + LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}), + condBlock); Block *remainder = rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); // Hook up the cond exit to the remainder. rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, pred, bodyBlock, ValueRange(), - remainder, ValueRange()); + LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(), + remainder, ValueRange()); // Reset position to beginning of new remainder block. rewriter.setInsertionPointToStart(remainder); @@ -1738,7 +1748,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]); if (nextSize) return runningStride - ? rewriter.create(loc, runningStride, nextSize) + ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize) : nextSize; assert(!runningStride); return createIndexAttrConstant(rewriter, loc, indexType, 1); @@ -1783,8 +1793,8 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { // Field 2: Copy the actual aligned pointer to payload. Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); - alignedPtr = rewriter.create( - loc, alignedPtr.getType(), + alignedPtr = LLVM::GEPOp::create( + rewriter, loc, alignedPtr.getType(), typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr, adaptor.getByteShift()); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index b866afbce98b0..7a705336bf11c 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -79,7 +79,8 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, assert(indices.size() == 2); indices.back() = builder.createOrFold(loc, lastDim, idx); Type t = typeConverter.convertType(op.getComponentPtr().getType()); - return builder.create(loc, t, op.getBasePtr(), indices); + return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(), + indices); } /// Casts the given `srcBool` into an integer of `dstType`. @@ -107,8 +108,8 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask, value = castBoolToIntN(loc, value, dstType, builder); } else { if (valueBits < targetBits) { - value = builder.create( - loc, builder.getIntegerType(targetBits), value); + value = spirv::UConvertOp::create( + builder, loc, builder.getIntegerType(targetBits), value); } value = builder.createOrFold(loc, value, mask); @@ -372,8 +373,8 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, std::string varName = std::string("__workgroup_mem__") + std::to_string(std::distance(varOps.begin(), varOps.end())); - varOp = rewriter.create(loc, spirvType, varName, - /*initializer=*/nullptr); + varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName, + /*initializer=*/nullptr); } // Get pointer to global variable at the current scope. @@ -572,8 +573,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, loadOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value loadVal = rewriter.create(loc, accessChain, - memoryAccess, alignment); + Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain, + memoryAccess, alignment); if (isBool) loadVal = castIntNToBool(loc, loadVal, rewriter); rewriter.replaceOp(loadOp, loadVal); @@ -601,8 +602,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, loadOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value spvLoadOp = rewriter.create(loc, dstType, adjustedPtr, - memoryAccess, alignment); + Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr, + memoryAccess, alignment); // Shift the bits to the rightmost. // ____XXXX________ -> ____________XXXX @@ -770,12 +771,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, if (!scope) return rewriter.notifyMatchFailure(storeOp, "atomic scope not available"); - Value result = rewriter.create( - loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, - clearBitsMask); - result = rewriter.create( - loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, - storeVal); + Value result = spirv::AtomicAndOp::create( + rewriter, loc, dstType, adjustedPtr, *scope, + spirv::MemorySemantics::AcquireRelease, clearBitsMask); + result = spirv::AtomicOrOp::create( + rewriter, loc, dstType, adjustedPtr, *scope, + spirv::MemorySemantics::AcquireRelease, storeVal); // The AtomicOrOp has no side effect. Since it is already inserted, we can // just remove the original StoreOp. Note that rewriter.replaceOp() @@ -850,12 +851,12 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite( genericPtrType = typeConverter.convertType(intermediateType); } if (sourceSc != spirv::StorageClass::Generic) { - result = - rewriter.create(loc, genericPtrType, result); + result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType, + result); } if (resultSc != spirv::StorageClass::Generic) { result = - rewriter.create(loc, resultPtrType, result); + spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result); } rewriter.replaceOp(addrCastOp, result); return success(); diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index b93128441f2b5..63b1fdabaf407 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -65,7 +65,7 @@ static SmallVector getMixedAsValues(OpBuilder b, const Location &loc, values.emplace_back(*(dyn++)); } else { TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s); - values.emplace_back(b.create(loc, type, val)); + values.emplace_back(arith::ConstantOp::create(b, loc, type, val)); } } return values; @@ -79,9 +79,9 @@ static SmallVector linearToMultiIndex(Location loc, OpBuilder b, SmallVector multiIndex(n); for (int i = n - 1; i >= 0; --i) { - multiIndex[i] = b.create(loc, linearIndex, dimensions[i]); + multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]); if (i > 0) - linearIndex = b.create(loc, linearIndex, dimensions[i]); + linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]); } return multiIndex; @@ -91,13 +91,13 @@ static SmallVector linearToMultiIndex(Location loc, OpBuilder b, Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, ValueRange dimensions) { - Value linearIndex = b.create(loc, 0); - Value stride = b.create(loc, 1); + Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0); + Value stride = arith::ConstantIndexOp::create(b, loc, 1); for (int i = multiIndex.size() - 1; i >= 0; --i) { - Value off = b.create(loc, multiIndex[i], stride); - linearIndex = b.create(loc, linearIndex, off); - stride = b.create(loc, stride, dimensions[i]); + Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride); + linearIndex = arith::AddIOp::create(b, loc, linearIndex, off); + stride = arith::MulIOp::create(b, loc, stride, dimensions[i]); } return linearIndex; @@ -144,11 +144,12 @@ struct ConvertShardingOp : public OpConversionPattern { auto i64 = rewriter.getI64Type(); std::array shape = {static_cast(splitAxes.size()), maxNAxes}; - Value resSplitAxes = rewriter.create(loc, shape, i16); + Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16); auto attr = IntegerAttr::get(i16, -1); - Value fillValue = rewriter.create(loc, i16, attr); - resSplitAxes = rewriter.create(loc, fillValue, resSplitAxes) - .getResult(0); + Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr); + resSplitAxes = + linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes) + .getResult(0); // explicitly write values into tensor row by row std::array strides = {1, 1}; @@ -162,9 +163,10 @@ struct ConvertShardingOp : public OpConversionPattern { std::array sizes = {1, size}; auto tensorType = RankedTensorType::get({size}, i16); auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef()); - auto vals = rewriter.create(loc, tensorType, attrs); - resSplitAxes = rewriter.create( - loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides); + auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs); + resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals, + resSplitAxes, empty, empty, + empty, offs, sizes, strides); } // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}. @@ -179,7 +181,7 @@ struct ConvertShardingOp : public OpConversionPattern { .create(loc, std::array{0, 0}, i64) .getResult() - : rewriter.create(loc, type, haloSizes) + : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes) .getResult(); // To hold sharded dims offsets, create Tensor with shape {nSplits, @@ -189,8 +191,8 @@ struct ConvertShardingOp : public OpConversionPattern { // MeshOp) Value resOffsets; if (adaptor.getStaticShardedDimsOffsets().empty()) { - resOffsets = rewriter.create( - loc, std::array{0, 0}, i64); + resOffsets = tensor::EmptyOp::create(rewriter, loc, + std::array{0, 0}, i64); } else { SymbolTableCollection symbolTableCollection; auto meshOp = getMesh(op, symbolTableCollection); @@ -204,12 +206,12 @@ struct ConvertShardingOp : public OpConversionPattern { assert(maxSplitSize); ++maxSplitSize; // add one for the total size - resOffsets = rewriter.create( - loc, std::array{nSplits, maxSplitSize}, i64); - Value zero = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic)); + resOffsets = tensor::EmptyOp::create( + rewriter, loc, std::array{nSplits, maxSplitSize}, i64); + Value zero = arith::ConstantOp::create( + rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic)); resOffsets = - rewriter.create(loc, zero, resOffsets).getResult(0); + linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0); SmallVector offsets = getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(), adaptor.getDynamicShardedDimsOffsets()); @@ -220,11 +222,12 @@ struct ConvertShardingOp : public OpConversionPattern { assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); ++splitSize; // add one for the total size ArrayRef values(&offsets[curr], splitSize); - Value vals = rewriter.create(loc, values); + Value vals = tensor::FromElementsOp::create(rewriter, loc, values); std::array offs = {static_cast(i), 0}; std::array sizes = {1, splitSize}; - resOffsets = rewriter.create( - loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides); + resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals, + resOffsets, empty, empty, + empty, offs, sizes, strides); curr += splitSize; } } @@ -236,10 +239,10 @@ struct ConvertShardingOp : public OpConversionPattern { return failure(); resSplitAxes = - rewriter.create(loc, resTypes[0], resSplitAxes); + tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes); resHaloSizes = - rewriter.create(loc, resTypes[1], resHaloSizes); - resOffsets = rewriter.create(loc, resTypes[2], resOffsets); + tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes); + resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets); rewriter.replaceOpWithNewOp( op, TupleType::get(op.getContext(), resTypes), @@ -269,9 +272,9 @@ struct ConvertProcessMultiIndexOp SmallVector dims; llvm::transform( meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { - return rewriter.create(loc, i).getResult(); + return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); - Value rank = rewriter.create(op.getLoc(), meshOp); + Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); // optionally extract subset of mesh axes @@ -302,7 +305,7 @@ class ConvertProcessLinearIndexOp Location loc = op.getLoc(); auto ctx = op.getContext(); Value commWorld = - rewriter.create(loc, mpi::CommType::get(ctx)); + mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); auto rank = rewriter .create( @@ -341,41 +344,41 @@ struct ConvertNeighborsLinearIndicesOp SmallVector dims; llvm::transform( meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { - return rewriter.create(loc, i).getResult(); + return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); Value dimSz = dims[axes[0]]; - Value one = rewriter.create(loc, 1); - Value minus1 = rewriter.create(loc, -1); - Value atBorder = rewriter.create( - loc, arith::CmpIPredicate::sle, orgIdx, - rewriter.create(loc, 0)); - auto down = rewriter.create( - loc, atBorder, + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1); + Value atBorder = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx, + arith::ConstantIndexOp::create(rewriter, loc, 0)); + auto down = scf::IfOp::create( + rewriter, loc, atBorder, [&](OpBuilder &builder, Location loc) { - builder.create(loc, minus1); + scf::YieldOp::create(builder, loc, minus1); }, [&](OpBuilder &builder, Location loc) { SmallVector tmp = mIdx; tmp[axes[0]] = - rewriter.create(op.getLoc(), orgIdx, one) + arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one) .getResult(); - builder.create( - loc, multiToLinearIndex(loc, rewriter, tmp, dims)); + scf::YieldOp::create(builder, loc, + multiToLinearIndex(loc, rewriter, tmp, dims)); }); - atBorder = rewriter.create( - loc, arith::CmpIPredicate::sge, orgIdx, - rewriter.create(loc, dimSz, one).getResult()); - auto up = rewriter.create( - loc, atBorder, + atBorder = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, orgIdx, + arith::SubIOp::create(rewriter, loc, dimSz, one).getResult()); + auto up = scf::IfOp::create( + rewriter, loc, atBorder, [&](OpBuilder &builder, Location loc) { - builder.create(loc, minus1); + scf::YieldOp::create(builder, loc, minus1); }, [&](OpBuilder &builder, Location loc) { SmallVector tmp = mIdx; tmp[axes[0]] = - rewriter.create(op.getLoc(), orgIdx, one); - builder.create( - loc, multiToLinearIndex(loc, rewriter, tmp, dims)); + arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one); + scf::YieldOp::create(builder, loc, + multiToLinearIndex(loc, rewriter, tmp, dims)); }); rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); return success(); @@ -447,8 +450,9 @@ struct ConvertShardShapeOp : public OpConversionPattern { rewriter, loc, sharding.getStaticShardedDimsOffsets(), sharding.getDynamicShardedDimsOffsets(), index); if (!tmp.empty()) - shardedDimsOffs = rewriter.create( - loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp); + shardedDimsOffs = tensor::FromElementsOp::create( + rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index), + tmp); } // With static mesh shape the sizes of the split axes are known. @@ -457,9 +461,9 @@ struct ConvertShardShapeOp : public OpConversionPattern { int64_t pos = 0; SmallVector shardShape; Value zero = - rewriter.create(loc, rewriter.getZeroAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index)); Value one = - rewriter.create(loc, rewriter.getOneAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index)); // Iterate over the dimensions of the tensor shape, get their split Axes, // and compute the sharded shape. @@ -469,8 +473,8 @@ struct ConvertShardShapeOp : public OpConversionPattern { auto axes = splitAxes[i]; // The current dimension might not be sharded. // Create a value from the static position in shardDimsOffsets. - Value posVal = - rewriter.create(loc, rewriter.getIndexAttr(pos)); + Value posVal = arith::ConstantOp::create(rewriter, loc, + rewriter.getIndexAttr(pos)); // Get the index of the local shard in the mesh axis. Value idx = multiIdx[axes[0]]; auto numShards = @@ -482,29 +486,29 @@ struct ConvertShardShapeOp : public OpConversionPattern { return op->emitError() << "Only single axis sharding is " << "supported for each dimension."; } - idx = rewriter.create(loc, posVal, idx); + idx = arith::AddIOp::create(rewriter, loc, posVal, idx); // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx]. Value off = - rewriter.create(loc, shardedDimsOffs, idx); - idx = rewriter.create(loc, idx, one); + tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx); + idx = arith::AddIOp::create(rewriter, loc, idx, one); Value nextOff = - rewriter.create(loc, shardedDimsOffs, idx); - Value sz = rewriter.create(loc, nextOff, off); + tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx); + Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off); shardShape.emplace_back(sz); } else { - Value numShardsVal = rewriter.create( - loc, rewriter.getIndexAttr(numShards)); + Value numShardsVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(numShards)); // Compute shard dim size by distributing odd elements to trailing // shards: // sz = dim / numShards // + (idx >= (numShards - (dim % numShards)) ? 1 : 0) - Value sz = rewriter.create(loc, dim, numShardsVal); - Value sz1 = rewriter.create(loc, dim, numShardsVal); - sz1 = rewriter.create(loc, numShardsVal, sz1); - auto cond = rewriter.create( - loc, arith::CmpIPredicate::sge, idx, sz1); - Value odd = rewriter.create(loc, cond, one, zero); - sz = rewriter.create(loc, sz, odd); + Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal); + Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal); + sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1); + auto cond = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, idx, sz1); + Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero); + sz = arith::AddIOp::create(rewriter, loc, sz, odd); shardShape.emplace_back(sz); } pos += numShards + 1; // add one for the total size. @@ -568,7 +572,7 @@ struct ConvertAllReduceOp : public OpConversionPattern { if (isa(input.getType())) { auto memrefType = MemRefType::get( inputShape, cast(input.getType()).getElementType()); - input = iBuilder.create(memrefType, input); + input = bufferization::ToBufferOp::create(iBuilder, memrefType, input); } MemRefType inType = cast(input.getType()); @@ -577,15 +581,15 @@ struct ConvertAllReduceOp : public OpConversionPattern { for (auto i = 0; i < inType.getRank(); ++i) { auto s = inputShape[i]; if (ShapedType::isDynamic(s)) - shape[i] = iBuilder.create(input, s).getResult(); + shape[i] = memref::DimOp::create(iBuilder, input, s).getResult(); else shape[i] = iBuilder.getIndexAttr(s); } // Allocate buffer and copy input to buffer. - Value buffer = iBuilder.create( - shape, cast(op.getType()).getElementType()); - iBuilder.create(input, buffer); + Value buffer = memref::AllocOp::create( + iBuilder, shape, cast(op.getType()).getElementType()); + linalg::CopyOp::create(iBuilder, input, buffer); // Get an MPI_Comm_split for the AllReduce operation. // The color is the linear index of the process in the mesh along the @@ -594,9 +598,9 @@ struct ConvertAllReduceOp : public OpConversionPattern { SmallVector indexResultTypes(meshOp.getShape().size(), iBuilder.getIndexType()); SmallVector myMultiIndex = - iBuilder.create(indexResultTypes, mesh) + ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh) .getResult(); - Value zero = iBuilder.create(0); + Value zero = arith::ConstantIndexOp::create(iBuilder, 0); SmallVector multiKey(myMultiIndex.size(), zero); auto redAxes = adaptor.getMeshAxes(); @@ -607,15 +611,15 @@ struct ConvertAllReduceOp : public OpConversionPattern { Value color = createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder); - color = iBuilder.create(iBuilder.getI32Type(), color); + color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder); - key = iBuilder.create(iBuilder.getI32Type(), key); + key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); // Finally split the communicator auto commType = mpi::CommType::get(op->getContext()); - Value commWorld = iBuilder.create(commType); + Value commWorld = mpi::CommWorldOp::create(iBuilder, commType); auto comm = - iBuilder.create(commType, commWorld, color, key) + mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key) .getNewcomm(); Value buffer1d = buffer; @@ -623,19 +627,19 @@ struct ConvertAllReduceOp : public OpConversionPattern { if (inType.getRank() > 1) { ReassociationIndices reassociation(inType.getRank()); std::iota(reassociation.begin(), reassociation.end(), 0); - buffer1d = iBuilder.create( - buffer, ArrayRef(reassociation)); + buffer1d = memref::CollapseShapeOp::create( + iBuilder, buffer, ArrayRef(reassociation)); } // Create the MPI AllReduce operation. - iBuilder.create( - TypeRange(), buffer1d, buffer1d, - getMPIReductionOp(adaptor.getReductionAttr()), comm); + mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d, + getMPIReductionOp(adaptor.getReductionAttr()), + comm); // If the destination is a memref, cast it to a tensor if (isa(op.getType())) - buffer = iBuilder.create(op.getType(), buffer, - true); + buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer, + true); rewriter.replaceOp(op, buffer); return success(); @@ -676,9 +680,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value { if (auto value = dyn_cast(v)) return value; - return rewriter.create( - loc, rewriter.getIndexAttr( - cast(cast(v)).getInt())); + return arith::ConstantOp::create( + rewriter, loc, + rewriter.getIndexAttr( + cast(cast(v)).getInt())); }; auto dest = adaptor.getDestination(); @@ -689,7 +694,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { auto mmemrefType = MemRefType::get( dstShape, cast(array.getType()).getElementType()); array = - rewriter.create(loc, mmemrefType, array); + bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array); } auto rank = cast(array.getType()).getRank(); auto opSplitAxes = adaptor.getSplitAxes().getAxes(); @@ -713,7 +718,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { for (auto i = 0; i < rank; ++i) { auto s = dstShape[i]; if (ShapedType::isDynamic(s)) - shape[i] = rewriter.create(loc, array, s).getResult(); + shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult(); else shape[i] = rewriter.getIndexAttr(s); @@ -723,12 +728,12 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { offsets[i] = haloSizes[currHaloDim * 2]; // prepare shape and offsets of highest dim's halo exchange - Value _haloSz = rewriter.create( - loc, toValue(haloSizes[currHaloDim * 2]), + Value _haloSz = arith::AddIOp::create( + rewriter, loc, toValue(haloSizes[currHaloDim * 2]), toValue(haloSizes[currHaloDim * 2 + 1])); // the halo shape of lower dims exlude the halos dimSizes[i] = - rewriter.create(loc, toValue(shape[i]), _haloSz) + arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz) .getResult(); } else { dimSizes[i] = shape[i]; @@ -736,14 +741,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { } auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something - auto tag = rewriter.create(loc, tagAttr); + auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr); auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 - auto zero = rewriter.create(loc, zeroAttr); + auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); SmallVector indexResultTypes(meshOp.getShape().size(), rewriter.getIndexType()); auto myMultiIndex = - rewriter.create(loc, indexResultTypes, mesh) + ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh) .getResult(); // traverse all split axes from high to low dim for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { @@ -758,20 +763,22 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { splitAxes) .getResults(); // MPI operates on i32... - Value neighbourIDs[2] = {rewriter.create( - loc, rewriter.getI32Type(), tmp[0]), - rewriter.create( - loc, rewriter.getI32Type(), tmp[1])}; + Value neighbourIDs[2] = { + arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), + tmp[0]), + arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), + tmp[1])}; auto lowerRecvOffset = rewriter.getIndexAttr(0); auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]); - auto upperRecvOffset = rewriter.create( - loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1])); - auto upperSendOffset = rewriter.create( - loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); + auto upperRecvOffset = + arith::SubIOp::create(rewriter, loc, toValue(shape[dim]), + toValue(haloSizes[currHaloDim * 2 + 1])); + auto upperSendOffset = arith::SubIOp::create( + rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); - Value commWorld = rewriter.create( - loc, mpi::CommType::get(op->getContext())); + Value commWorld = mpi::CommWorldOp::create( + rewriter, loc, mpi::CommType::get(op->getContext())); // Make sure we send/recv in a way that does not lead to a dead-lock. // The current approach is by far not optimal, this should be at least @@ -787,37 +794,38 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { // Processes on the mesh borders have only one neighbor auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; - auto hasFrom = rewriter.create( - loc, arith::CmpIPredicate::sge, from, zero); - auto hasTo = rewriter.create( - loc, arith::CmpIPredicate::sge, to, zero); - auto buffer = rewriter.create( - loc, dimSizes, cast(array.getType()).getElementType()); + auto hasFrom = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, from, zero); + auto hasTo = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::sge, to, zero); + auto buffer = memref::AllocOp::create( + rewriter, loc, dimSizes, + cast(array.getType()).getElementType()); // if has neighbor: copy halo data from array to buffer and send - rewriter.create( - loc, hasTo, [&](OpBuilder &builder, Location loc) { + scf::IfOp::create( + rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) { offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset) : OpFoldResult(upperSendOffset); - auto subview = builder.create( - loc, array, offsets, dimSizes, strides); - builder.create(loc, subview, buffer); - builder.create(loc, TypeRange{}, buffer, tag, to, - commWorld); - builder.create(loc); + auto subview = memref::SubViewOp::create( + builder, loc, array, offsets, dimSizes, strides); + memref::CopyOp::create(builder, loc, subview, buffer); + mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to, + commWorld); + scf::YieldOp::create(builder, loc); }); // if has neighbor: receive halo data into buffer and copy to array - rewriter.create( - loc, hasFrom, [&](OpBuilder &builder, Location loc) { + scf::IfOp::create( + rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) { offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset) : OpFoldResult(lowerRecvOffset); - builder.create(loc, TypeRange{}, buffer, tag, from, - commWorld); - auto subview = builder.create( - loc, array, offsets, dimSizes, strides); - builder.create(loc, buffer, subview); - builder.create(loc); + mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from, + commWorld); + auto subview = memref::SubViewOp::create( + builder, loc, array, offsets, dimSizes, strides); + memref::CopyOp::create(builder, loc, buffer, subview); + scf::YieldOp::create(builder, loc); }); - rewriter.create(loc, buffer); + memref::DeallocOp::create(rewriter, loc, buffer); offsets[dim] = orgOffset; }; @@ -825,16 +833,17 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown]; Value haloSz = dyn_cast(v); if (!haloSz) - haloSz = rewriter.create( - loc, rewriter.getI32IntegerAttr( - cast(cast(v)).getInt())); - auto hasSize = rewriter.create( - loc, arith::CmpIPredicate::sgt, haloSz, zero); - rewriter.create(loc, hasSize, - [&](OpBuilder &builder, Location loc) { - genSendRecv(upOrDown > 0); - builder.create(loc); - }); + haloSz = arith::ConstantOp::create( + rewriter, loc, + rewriter.getI32IntegerAttr( + cast(cast(v)).getInt())); + auto hasSize = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero); + scf::IfOp::create(rewriter, loc, hasSize, + [&](OpBuilder &builder, Location loc) { + genSendRecv(upOrDown > 0); + scf::YieldOp::create(builder, loc); + }); }; doSendRecv(0); @@ -852,8 +861,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { rewriter.replaceOp(op, array); } else { assert(isa(op.getResult().getType())); - rewriter.replaceOp(op, rewriter.create( - loc, op.getResult().getType(), array, + rewriter.replaceOp(op, bufferization::ToTensorOp::create( + rewriter, loc, op.getResult().getType(), array, /*restrict=*/true, /*writable=*/true)); } return success(); diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 80b3d85488495..905287e107b0b 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -53,7 +53,7 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { assert(llvm::isa(type) && "expected an integer Value"); if (type.getIntOrFloatBitWidth() <= 32) return value; - return b.create(b.getI32Type(), value); + return LLVM::TruncOp::create(b, b.getI32Type(), value); } /// Returns the type for the intrinsic given the vectorResultType of the @@ -113,8 +113,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type f32x1Ty = VectorType::get(1, f32Ty); auto makeConst = [&](int32_t index) -> Value { - return rewriter.create(loc, IntegerType::get(ctx, 32), - rewriter.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32), + rewriter.getI32IntegerAttr(index)); }; if (arrayType) { @@ -126,7 +126,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, arrayType.getElementType() == f32x1Ty) { for (unsigned i = 0; i < structType.getBody().size(); i++) { Value el = - rewriter.create(loc, intrinsicResult, i); + LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i); el = rewriter.createOrFold( loc, arrayType.getElementType(), el); elements.push_back(el); @@ -143,24 +143,24 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { Value vec = - rewriter.create(loc, arrayType.getElementType()); + LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType()); Value x1 = - rewriter.create(loc, intrinsicResult, i * 2); - Value x2 = rewriter.create(loc, intrinsicResult, - i * 2 + 1); - vec = rewriter.create(loc, vec.getType(), vec, - x1, makeConst(0)); - vec = rewriter.create(loc, vec.getType(), vec, - x2, makeConst(1)); + LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2); + Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, + i * 2 + 1); + vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec, + x1, makeConst(0)); + vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec, + x2, makeConst(1)); elements.push_back(vec); } } // Create the final vectorized result. - Value result = rewriter.create(loc, arrayType); + Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType); for (const auto &el : llvm::enumerate(elements)) { - result = rewriter.create(loc, result, el.value(), - el.index()); + result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(), + el.index()); } return result; } @@ -187,7 +187,7 @@ static SmallVector unpackOperandVector(ImplicitLocOpBuilder &b, auto arrayTy = cast(operand.getType()); for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { - Value toUse = b.create(operand, i); + Value toUse = LLVM::ExtractValueOp::create(b, operand, i); // For 4xi8 vectors, the intrinsic expects these to be provided as i32 // scalar types. @@ -195,7 +195,7 @@ static SmallVector unpackOperandVector(ImplicitLocOpBuilder &b, arrayTy.getElementType() == i4x8Ty || (arrayTy.getElementType() == f32x1Ty && operandPtxType == NVVM::MMATypes::tf32)) { - result.push_back(b.create(i32Ty, toUse)); + result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse)); continue; } @@ -208,9 +208,9 @@ static SmallVector unpackOperandVector(ImplicitLocOpBuilder &b, innerArrayTy.getElementType() == f32Ty)) { for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); idx < innerSize; idx++) { - result.push_back(b.create( - toUse, - b.create(i64Ty, b.getI64IntegerAttr(idx)))); + result.push_back(LLVM::ExtractElementOp::create( + b, toUse, + LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx)))); } continue; } @@ -285,8 +285,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); - Value ldMatrixResult = b.create( - ldMatrixResultType, srcPtr, + Value ldMatrixResult = NVVM::LdMatrixOp::create( + b, ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row); @@ -296,13 +296,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { // actual vector type (still of width 32b) and repack them into a result // struct. Type finalResultType = typeConverter->convertType(vectorResultType); - Value result = b.create(finalResultType); + Value result = LLVM::PoisonOp::create(b, finalResultType); for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { Value i32Register = - num32BitRegs > 1 ? b.create(ldMatrixResult, i) + num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i) : ldMatrixResult; - Value casted = b.create(innerVectorType, i32Register); - result = b.create(result, casted, i); + Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register); + result = LLVM::InsertValueOp::create(b, result, casted, i); } rewriter.replaceOp(op, result); @@ -375,16 +375,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern { Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); Type intrinsicResTy = inferIntrinsicResultType( typeConverter->convertType(op->getResultTypes()[0])); - Value intrinsicResult = b.create( - intrinsicResTy, matA, matB, matC, - /*shape=*/gemmShape, - /*b1Op=*/std::nullopt, - /*intOverflow=*/overflow, - /*multiplicandPtxTypes=*/ - std::array{*ptxTypeA, *ptxTypeB}, - /*multiplicandLayouts=*/ - std::array{NVVM::MMALayout::row, - NVVM::MMALayout::col}); + Value intrinsicResult = + NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC, + /*shape=*/gemmShape, + /*b1Op=*/std::nullopt, + /*intOverflow=*/overflow, + /*multiplicandPtxTypes=*/ + std::array{*ptxTypeA, *ptxTypeB}, + /*multiplicandLayouts=*/ + std::array{ + NVVM::MMALayout::row, NVVM::MMALayout::col}); rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, intrinsicResult, rewriter)); @@ -565,15 +565,16 @@ static FailureOr emitMmaSparseSyncOpAsm( llvm::append_range(asmVals, args); asmVals.push_back(indexData); - return b.create( - /*resultTypes=*/intrinsicResultType, - /*operands=*/asmVals, - /*asm_string=*/asmStr, - /*constraints=*/constraintStr, - /*has_side_effects=*/true, - /*is_align_stack=*/false, LLVM::TailCallKind::None, - /*asm_dialect=*/asmDialectAttr, - /*operand_attrs=*/ArrayAttr()); + return LLVM::InlineAsmOp::create(b, + /*resultTypes=*/intrinsicResultType, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/constraintStr, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::TailCallKind::None, + /*asm_dialect=*/asmDialectAttr, + /*operand_attrs=*/ArrayAttr()); } /// Lowers `nvgpu.mma.sp.sync` to inline assembly. @@ -631,7 +632,7 @@ struct NVGPUMmaSparseSyncLowering return op->emitOpError() << "Expected metadata type to be LLVM " "VectorType of 2 i16 elements"; sparseMetadata = - b.create(rewriter.getI32Type(), sparseMetadata); + LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata); FailureOr intrinsicResult = emitMmaSparseSyncOpAsm( b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, @@ -682,7 +683,7 @@ struct NVGPUAsyncCopyLowering // Intrinsics takes a global pointer so we need an address space cast. auto srcPointerGlobalType = LLVM::LLVMPointerType::get( op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace); - scrPtr = b.create(srcPointerGlobalType, scrPtr); + scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr); int64_t dstElements = adaptor.getDstElements().getZExtValue(); int64_t sizeInBytes = (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8; @@ -697,13 +698,13 @@ struct NVGPUAsyncCopyLowering // The rest of the DstElements in the destination (shared memory) are // filled with zeros. Value c3I32 = - b.create(b.getI32Type(), b.getI32IntegerAttr(3)); - Value bitwidth = b.create( - b.getI32Type(), + LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3)); + Value bitwidth = LLVM::ConstantOp::create( + b, b.getI32Type(), b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth())); - Value srcElementsI32 = b.create(b.getI32Type(), srcBytes); - srcBytes = b.create( - b.create(bitwidth, srcElementsI32), c3I32); + Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes); + srcBytes = LLVM::LShrOp::create( + b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32); } // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than // 16 dst bytes. @@ -712,14 +713,15 @@ struct NVGPUAsyncCopyLowering ? NVVM::LoadCacheModifierKind::CG : NVVM::LoadCacheModifierKind::CA; - b.create( - dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), + NVVM::CpAsyncOp::create( + b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier), srcBytes); // Drop the result token. - Value zero = b.create( - IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0)); + Value zero = + LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); rewriter.replaceOp(op, zero); return success(); } @@ -733,11 +735,11 @@ struct NVGPUAsyncCreateGroupLowering LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.create(op.getLoc()); + NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc()); // Drop the result token. - Value zero = rewriter.create( - op->getLoc(), IntegerType::get(op.getContext(), 32), - rewriter.getI32IntegerAttr(0)); + Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(), + IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); rewriter.replaceOp(op, zero); return success(); } @@ -753,7 +755,7 @@ struct NVGPUAsyncWaitLowering ConversionPatternRewriter &rewriter) const override { // If numGroup is not present pick 0 as a conservative correct value. int32_t numGroups = adaptor.getNumGroups().value_or(0); - rewriter.create(op.getLoc(), numGroups); + NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups); rewriter.eraseOp(op); return success(); } @@ -771,8 +773,8 @@ struct NVGPUMBarrierCreateLowering SymbolTable symbolTable(moduleOp); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(&moduleOp.front()); - auto global = rewriter.create( - funcOp->getLoc(), "__mbarrier", + auto global = memref::GlobalOp::create( + rewriter, funcOp->getLoc(), "__mbarrier", /*sym_visibility=*/rewriter.getStringAttr("private"), /*type=*/barrierType, /*initial_value=*/ElementsAttr(), @@ -974,7 +976,7 @@ struct NVGPUMBarrierTryWaitParityLowering adaptor.getMbarId(), rewriter); Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = - b.create(b.getI32Type(), adaptor.getPhaseParity()); + LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); if (isMbarrierShared(op.getBarriers().getType())) { rewriter.replaceOpWithNewOp( @@ -1063,16 +1065,16 @@ struct NVGPUGenerateWarpgroupDescriptorLowering auto ti64 = b.getIntegerType(64); auto makeConst = [&](uint64_t index) -> Value { - return b.create(ti64, b.getI64IntegerAttr(index)); + return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index)); }; auto shiftLeft = [&](Value value, unsigned shift) -> Value { - return b.create(ti64, value, makeConst(shift)); + return LLVM::ShlOp::create(b, ti64, value, makeConst(shift)); }; auto shiftRight = [&](Value value, unsigned shift) -> Value { - return b.create(ti64, value, makeConst(shift)); + return LLVM::LShrOp::create(b, ti64, value, makeConst(shift)); }; auto insertBit = [&](Value desc, Value val, int startBit) { - return b.create(ti64, desc, shiftLeft(val, startBit)); + return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit)); }; int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); @@ -1086,7 +1088,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering Value baseAddr = getStridedElementPtr( rewriter, op->getLoc(), cast(op.getTensor().getType()), adaptor.getTensor(), {}); - Value basePtr = b.create(ti64, baseAddr); + Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr); // Just use 14 bits for base address Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); @@ -1118,8 +1120,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering }; static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) { - return b.create(b.getIntegerType(64), - b.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(b, b.getIntegerType(64), + b.getI32IntegerAttr(index)); } /// Returns a Value that holds data type enum that is expected by CUDA driver. @@ -1182,12 +1184,12 @@ struct NVGPUTmaCreateDescriptorOpLowering auto promotedOperands = getTypeConverter()->promoteOperands( b.getLoc(), op->getOperands(), adaptor.getOperands(), b); - Value boxArrayPtr = b.create(llvmPointerType, llvmInt64Type, - makeI64Const(b, 5)); + Value boxArrayPtr = LLVM::AllocaOp::create( + b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5)); for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) { - Value gep = b.create(llvmPointerType, llvmPointerType, - boxArrayPtr, makeI64Const(b, index)); - b.create(value, gep); + Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType, + boxArrayPtr, makeI64Const(b, index)); + LLVM::StoreOp::create(b, value, gep); } nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType(); @@ -1337,7 +1339,7 @@ struct NVGPUWarpgroupMmaOpLowering /// Basic function to generate Add Value makeAdd(Value lhs, Value rhs) { - return b.create(lhs.getType(), lhs, rhs); + return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs); }; /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. @@ -1430,29 +1432,30 @@ struct NVGPUWarpgroupMmaOpLowering auto overflow = NVVM::MMAIntOverflowAttr::get( op->getContext(), NVVM::MMAIntOverflow::wrapped); - return b.create( - matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, - itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, + return NVVM::WgmmaMmaAsyncOp::create( + b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape, + itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); } /// Generates multiple wgmma instructions to complete the given GEMM shape Value generateWgmmaGroup() { Value wgmmaResult = - b.create(adaptor.getMatrixC().getType()); + LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType()); // Perform GEMM SmallVector wgmmaResults; for (int i = 0; i < iterationM; ++i) { - Value matrixC = b.create(adaptor.getMatrixC(), i); + Value matrixC = + LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i); for (int j = 0; j < iterationN; ++j) for (int k = 0; k < iterationK; ++k) matrixC = generateWgmma(i, j, k, matrixC); wgmmaResults.push_back(matrixC); } for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) { - wgmmaResult = b.create(wgmmaResult.getType(), - wgmmaResult, matrix, idx); + wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(), + wgmmaResult, matrix, idx); } return wgmmaResult; } @@ -1486,10 +1489,10 @@ struct NVGPUWarpgroupMmaOpLowering /// (WgmmaGroupSyncAlignedOp) for group synchronization /// (WgmmaWaitGroupSyncOp) after the instructions. Value generateWarpgroupMma() { - b.create(); + NVVM::WgmmaFenceAlignedOp::create(b); Value wgmmaResult = generateWgmmaGroup(); - b.create(); - b.create(op.getWaitGroup()); + NVVM::WgmmaGroupSyncAlignedOp::create(b); + NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup()); return wgmmaResult; } }; @@ -1557,7 +1560,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering Type i32 = b.getI32Type(); auto makeConst = [&](int32_t index) -> Value { - return b.create(i32, b.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index)); }; Value c1 = makeConst(1); Value c2 = makeConst(2); @@ -1567,29 +1570,29 @@ struct NVGPUWarpgroupMmaStoreOpLowering Value warpSize = makeConst(kWarpSize); auto makeMul = [&](Value lhs, Value rhs) -> Value { - return b.create(lhs.getType(), lhs, rhs); + return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs); }; auto makeAdd = [&](Value lhs, Value rhs) -> Value { - return b.create(lhs.getType(), lhs, rhs); + return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs); }; auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, TypedValue<::mlir::MemRefType> memref) { Type it = b.getIndexType(); - Value idx = b.create(it, x); - Value idy0 = b.create(it, y); - Value idy1 = b.create(it, makeAdd(y, c1)); - Value d0 = b.create(wgmmaResult, i); - Value d1 = b.create(wgmmaResult, i + 1); - b.create(d0, memref, ValueRange{idx, idy0}); - b.create(d1, memref, ValueRange{idx, idy1}); + Value idx = arith::IndexCastOp::create(b, it, x); + Value idy0 = arith::IndexCastOp::create(b, it, y); + Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1)); + Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i); + Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1); + memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0}); + memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1}); }; - Value tidx = b.create(i32); - Value laneId = b.create(i32, tidx, warpSize); - Value warpId = b.create(i32, tidx, warpSize); - Value lane4Id = b.create(i32, laneId, c4); - Value lane4modId = b.create(i32, laneId, c4); + Value tidx = NVVM::ThreadIdXOp::create(b, i32); + Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize); + Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize); + Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4); + Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4); Value tj = makeMul(lane4modId, c2); Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); @@ -1626,7 +1629,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering auto stype = cast(matriDValue.getType()); for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) { auto structType = cast(matrixD); - Value innerStructValue = b.create(matriDValue, idx); + Value innerStructValue = + LLVM::ExtractValueOp::create(b, matriDValue, idx); storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); offset += structType.getBody().size(); } @@ -1648,23 +1652,23 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering Type elemType = cast(packStructType.getBody().front()) .getBody() .front(); - Value zero = b.create(elemType, b.getZeroAttr(elemType)); - Value packStruct = b.create(packStructType); + Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType)); + Value packStruct = LLVM::PoisonOp::create(b, packStructType); SmallVector innerStructs; // Unpack the structs and set all values to zero for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { auto structType = cast(s); - Value structValue = b.create(packStruct, idx); + Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx); for (unsigned i = 0; i < structType.getBody().size(); ++i) { - structValue = b.create( - structType, structValue, zero, ArrayRef({i})); + structValue = LLVM::InsertValueOp::create(b, structType, structValue, + zero, ArrayRef({i})); } innerStructs.push_back(structValue); } // Pack the inner structs into a single struct for (auto [idx, matrix] : llvm::enumerate(innerStructs)) { - packStruct = b.create(packStruct.getType(), - packStruct, matrix, idx); + packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(), + packStruct, matrix, idx); } rewriter.replaceOp(op, packStruct); return success(); @@ -1681,7 +1685,7 @@ struct NVGPUTmaFenceOpLowering ImplicitLocOpBuilder b(op->getLoc(), rewriter); auto i32Ty = b.getI32Type(); Value tensormapSize = - b.create(i32Ty, rewriter.getI32IntegerAttr(128)); + LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128)); auto memscope = NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS); @@ -1716,13 +1720,13 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern { VectorType inTy = op.getIn().getType(); // apply rcp.approx.ftz.f on each element in vector. auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) { - Value ret1DVec = b.create(llvm1DVectorTy); + Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy); int numElems = llvm::cast(llvm1DVectorTy).getNumElements(); for (int i = 0; i < numElems; i++) { - Value idx = b.create(i64Ty, b.getI64IntegerAttr(i)); - Value elem = b.create(inVec, idx); - Value dst = b.create(f32Ty, elem); - ret1DVec = b.create(ret1DVec, dst, idx); + Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i)); + Value elem = LLVM::ExtractElementOp::create(b, inVec, idx); + Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem); + ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx); } return ret1DVec; }; diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp index 479725aae8afd..f5b3689c88d26 100644 --- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp +++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp @@ -39,8 +39,8 @@ class ExpandIfCondition : public OpRewritePattern { IntegerAttr constAttr; if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) { - auto ifOp = rewriter.create(op.getLoc(), TypeRange(), - op.getIfCond(), false); + auto ifOp = scf::IfOp::create(rewriter, op.getLoc(), TypeRange(), + op.getIfCond(), false); rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener()); thenBodyBuilder.clone(*op.getOperation()); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 7ac9687c4eeda..021e31a8ecd97 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -95,8 +95,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern { } // Create new operation. - auto newOp = rewriter.create(op.getLoc(), resTypes, convertedOperands, - convertedAttrs); + auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands, + convertedAttrs); // Translate regions. for (auto [originalRegion, convertedRegion] : diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index 7d20109b3db59..b711e33cfc0d6 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -196,7 +196,7 @@ Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion, // finalize. if (isa(node)) { builder.setInsertionPointToEnd(block); - builder.create(matcherFunc.getLoc()); + pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc()); return block; } @@ -272,8 +272,8 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { auto *operationPos = cast(pos); if (operationPos->isOperandDefiningOp()) // Standard (downward) traversal which directly follows the defining op. - value = builder.create( - loc, builder.getType(), parentVal); + value = pdl_interp::GetDefiningOpOp::create( + builder, loc, builder.getType(), parentVal); else // A passthrough operation position. value = parentVal; @@ -287,23 +287,23 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { // requested to use a representative value (e.g., upward traversal). if (isa(parentVal.getType()) && usersPos->useRepresentative()) - value = builder.create(loc, parentVal, 0); + value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0); else value = parentVal; // The second operation retrieves the users. - value = builder.create(loc, value); + value = pdl_interp::GetUsersOp::create(builder, loc, value); break; } case Predicates::ForEachPos: { assert(!failureBlockStack.empty() && "expected valid failure block"); - auto foreach = builder.create( - loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); + auto foreach = pdl_interp::ForEachOp::create( + builder, loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); value = foreach.getLoopVariable(); // Create the continuation block. Block *continueBlock = builder.createBlock(&foreach.getRegion()); - builder.create(loc); + pdl_interp::ContinueOp::create(builder, loc); failureBlockStack.push_back(continueBlock); currentBlock = &foreach.getRegion().front(); @@ -311,62 +311,64 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { } case Predicates::OperandPos: { auto *operandPos = cast(pos); - value = builder.create( - loc, builder.getType(), parentVal, + value = pdl_interp::GetOperandOp::create( + builder, loc, builder.getType(), parentVal, operandPos->getOperandNumber()); break; } case Predicates::OperandGroupPos: { auto *operandPos = cast(pos); Type valueTy = builder.getType(); - value = builder.create( - loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, + value = pdl_interp::GetOperandsOp::create( + builder, loc, + operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, parentVal, operandPos->getOperandGroupNumber()); break; } case Predicates::AttributePos: { auto *attrPos = cast(pos); - value = builder.create( - loc, builder.getType(), parentVal, + value = pdl_interp::GetAttributeOp::create( + builder, loc, builder.getType(), parentVal, attrPos->getName().strref()); break; } case Predicates::TypePos: { if (isa(parentVal.getType())) - value = builder.create(loc, parentVal); + value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal); else - value = builder.create(loc, parentVal); + value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal); break; } case Predicates::ResultPos: { auto *resPos = cast(pos); - value = builder.create( - loc, builder.getType(), parentVal, + value = pdl_interp::GetResultOp::create( + builder, loc, builder.getType(), parentVal, resPos->getResultNumber()); break; } case Predicates::ResultGroupPos: { auto *resPos = cast(pos); Type valueTy = builder.getType(); - value = builder.create( - loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, + value = pdl_interp::GetResultsOp::create( + builder, loc, + resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, parentVal, resPos->getResultGroupNumber()); break; } case Predicates::AttributeLiteralPos: { auto *attrPos = cast(pos); - value = - builder.create(loc, attrPos->getValue()); + value = pdl_interp::CreateAttributeOp::create(builder, loc, + attrPos->getValue()); break; } case Predicates::TypeLiteralPos: { auto *typePos = cast(pos); Attribute rawTypeAttr = typePos->getValue(); if (TypeAttr typeAttr = dyn_cast(rawTypeAttr)) - value = builder.create(loc, typeAttr); + value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr); else - value = builder.create( - loc, cast(rawTypeAttr)); + value = pdl_interp::CreateTypesOp::create(builder, loc, + cast(rawTypeAttr)); break; } case Predicates::ConstraintResultPos: { @@ -413,56 +415,59 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, Predicates::Kind kind = question->getKind(); switch (kind) { case Predicates::IsNotNullQuestion: - builder.create(loc, val, success, failure); + pdl_interp::IsNotNullOp::create(builder, loc, val, success, failure); break; case Predicates::OperationNameQuestion: { auto *opNameAnswer = cast(answer); - builder.create( - loc, val, opNameAnswer->getValue().getStringRef(), success, failure); + pdl_interp::CheckOperationNameOp::create( + builder, loc, val, opNameAnswer->getValue().getStringRef(), success, + failure); break; } case Predicates::TypeQuestion: { auto *ans = cast(answer); if (isa(val.getType())) - builder.create( - loc, val, llvm::cast(ans->getValue()), success, failure); + pdl_interp::CheckTypesOp::create(builder, loc, val, + llvm::cast(ans->getValue()), + success, failure); else - builder.create( - loc, val, llvm::cast(ans->getValue()), success, failure); + pdl_interp::CheckTypeOp::create(builder, loc, val, + llvm::cast(ans->getValue()), + success, failure); break; } case Predicates::AttributeQuestion: { auto *ans = cast(answer); - builder.create(loc, val, ans->getValue(), - success, failure); + pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(), + success, failure); break; } case Predicates::OperandCountAtLeastQuestion: case Predicates::OperandCountQuestion: - builder.create( - loc, val, cast(answer)->getValue(), + pdl_interp::CheckOperandCountOp::create( + builder, loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, success, failure); break; case Predicates::ResultCountAtLeastQuestion: case Predicates::ResultCountQuestion: - builder.create( - loc, val, cast(answer)->getValue(), + pdl_interp::CheckResultCountOp::create( + builder, loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, success, failure); break; case Predicates::EqualToQuestion: { bool trueAnswer = isa(answer); - builder.create(loc, val, args.front(), - trueAnswer ? success : failure, - trueAnswer ? failure : success); + pdl_interp::AreEqualOp::create(builder, loc, val, args.front(), + trueAnswer ? success : failure, + trueAnswer ? failure : success); break; } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast(question); - auto applyConstraintOp = builder.create( - loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, - cstQuestion->getIsNegated(), success, failure); + auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create( + builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(), + args, cstQuestion->getIsNegated(), success, failure); constraintOpMap.insert({cstQuestion, applyConstraintOp}); break; @@ -487,7 +492,7 @@ static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, blocks.push_back(it.second); values.push_back(cast(it.first)->getValue()); } - builder.create(val.getLoc(), val, values, defaultDest, blocks); + OpT::create(builder, val.getLoc(), val, values, defaultDest, blocks); } void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, @@ -536,12 +541,14 @@ void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, unsigned ans = cast(child.first)->getValue(); switch (kind) { case Predicates::OperandCountAtLeastQuestion: - builder.create( - loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans, + /*compareAtLeast=*/true, + childBlock, defaultDest); break; case Predicates::ResultCountAtLeastQuestion: - builder.create( - loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + pdl_interp::CheckResultCountOp::create(builder, loc, val, ans, + /*compareAtLeast=*/true, + childBlock, defaultDest); break; default: llvm_unreachable("Generating invalid AtLeast operation"); @@ -619,8 +626,8 @@ void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { rootKindAttr = builder.getStringAttr(*rootKind); builder.setInsertionPointToEnd(currentBlock); - auto matchOp = builder.create( - pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), + auto matchOp = pdl_interp::RecordMatchOp::create( + builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), failureBlockStack.back()); @@ -632,8 +639,8 @@ void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { SymbolRefAttr PatternLowering::generateRewriter( pdl::PatternOp pattern, SmallVectorImpl &usedMatchValues) { builder.setInsertionPointToEnd(rewriterModule.getBody()); - auto rewriterFunc = builder.create( - pattern.getLoc(), "pdl_generated_rewriter", + auto rewriterFunc = pdl_interp::FuncOp::create( + builder, pattern.getLoc(), "pdl_generated_rewriter", builder.getFunctionType({}, {})); rewriterSymbolTable.insert(rewriterFunc); @@ -651,18 +658,18 @@ SymbolRefAttr PatternLowering::generateRewriter( Operation *oldOp = oldValue.getDefiningOp(); if (pdl::AttributeOp attrOp = dyn_cast(oldOp)) { if (Attribute value = attrOp.getValueAttr()) { - return newValue = builder.create( - attrOp.getLoc(), value); + return newValue = pdl_interp::CreateAttributeOp::create( + builder, attrOp.getLoc(), value); } } else if (pdl::TypeOp typeOp = dyn_cast(oldOp)) { if (TypeAttr type = typeOp.getConstantTypeAttr()) { - return newValue = builder.create( - typeOp.getLoc(), type); + return newValue = pdl_interp::CreateTypeOp::create( + builder, typeOp.getLoc(), type); } } else if (pdl::TypesOp typeOp = dyn_cast(oldOp)) { if (ArrayAttr type = typeOp.getConstantTypesAttr()) { - return newValue = builder.create( - typeOp.getLoc(), typeOp.getType(), type); + return newValue = pdl_interp::CreateTypesOp::create( + builder, typeOp.getLoc(), typeOp.getType(), type); } } @@ -684,8 +691,9 @@ SymbolRefAttr PatternLowering::generateRewriter( auto mappedArgs = llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); args.append(mappedArgs.begin(), mappedArgs.end()); - builder.create( - rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); + pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(), + /*resultTypes=*/TypeRange(), rewriteName, + args); } else { // Otherwise this is a dag rewriter defined using PDL operations. for (Operation &rewriteOp : *rewriter.getBody()) { @@ -703,7 +711,7 @@ SymbolRefAttr PatternLowering::generateRewriter( llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), /*results=*/{})); - builder.create(rewriter.getLoc()); + pdl_interp::FinalizeOp::create(builder, rewriter.getLoc()); return SymbolRefAttr::get( builder.getContext(), pdl_interp::PDLInterpDialect::getRewriterModuleName(), @@ -716,9 +724,9 @@ void PatternLowering::generateRewriter( SmallVector arguments; for (Value argument : rewriteOp.getArgs()) arguments.push_back(mapRewriteValue(argument)); - auto interpOp = builder.create( - rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(), - arguments); + auto interpOp = pdl_interp::ApplyRewriteOp::create( + builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(), + rewriteOp.getNameAttr(), arguments); for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) rewriteValues[std::get<0>(it)] = std::get<1>(it); } @@ -726,16 +734,16 @@ void PatternLowering::generateRewriter( void PatternLowering::generateRewriter( pdl::AttributeOp attrOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { - Value newAttr = builder.create( - attrOp.getLoc(), attrOp.getValueAttr()); + Value newAttr = pdl_interp::CreateAttributeOp::create( + builder, attrOp.getLoc(), attrOp.getValueAttr()); rewriteValues[attrOp] = newAttr; } void PatternLowering::generateRewriter( pdl::EraseOp eraseOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { - builder.create(eraseOp.getLoc(), - mapRewriteValue(eraseOp.getOpValue())); + pdl_interp::EraseOp::create(builder, eraseOp.getLoc(), + mapRewriteValue(eraseOp.getOpValue())); } void PatternLowering::generateRewriter( @@ -756,9 +764,9 @@ void PatternLowering::generateRewriter( // Create the new operation. Location loc = operationOp.getLoc(); - Value createdOp = builder.create( - loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands, - attributes, operationOp.getAttributeValueNames()); + Value createdOp = pdl_interp::CreateOperationOp::create( + builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes, + operands, attributes, operationOp.getAttributeValueNames()); rewriteValues[operationOp.getOp()] = createdOp; // Generate accesses for any results that have their types constrained. @@ -768,8 +776,8 @@ void PatternLowering::generateRewriter( if (resultTys.size() == 1 && isa(resultTys[0].getType())) { Value &type = rewriteValues[resultTys[0]]; if (!type) { - auto results = builder.create(loc, createdOp); - type = builder.create(loc, results); + auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp); + type = pdl_interp::GetValueTypeOp::create(builder, loc, results); } return; } @@ -789,12 +797,13 @@ void PatternLowering::generateRewriter( // groups because the exact index of the result is not statically known. Value resultVal; if (seenVariableLength) - resultVal = builder.create( - loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); + resultVal = pdl_interp::GetResultsOp::create( + builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp, + it.index()); else - resultVal = builder.create( - loc, valueTy, createdOp, it.index()); - type = builder.create(loc, resultVal); + resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy, + createdOp, it.index()); + type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal); } } @@ -804,8 +813,8 @@ void PatternLowering::generateRewriter( SmallVector replOperands; for (Value operand : rangeOp.getArguments()) replOperands.push_back(mapRewriteValue(operand)); - rewriteValues[rangeOp] = builder.create( - rangeOp.getLoc(), rangeOp.getType(), replOperands); + rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create( + builder, rangeOp.getLoc(), rangeOp.getType(), replOperands); } void PatternLowering::generateRewriter( @@ -820,8 +829,8 @@ void PatternLowering::generateRewriter( // Don't use replace if we know the replaced operation has no results. auto opOp = replaceOp.getOpValue().getDefiningOp(); if (!opOp || !opOp.getTypeValues().empty()) { - replOperands.push_back(builder.create( - replOp.getLoc(), mapRewriteValue(replOp))); + replOperands.push_back(pdl_interp::GetResultsOp::create( + builder, replOp.getLoc(), mapRewriteValue(replOp))); } } else { for (Value operand : replaceOp.getReplValues()) @@ -830,29 +839,29 @@ void PatternLowering::generateRewriter( // If there are no replacement values, just create an erase instead. if (replOperands.empty()) { - builder.create( - replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue())); + pdl_interp::EraseOp::create(builder, replaceOp.getLoc(), + mapRewriteValue(replaceOp.getOpValue())); return; } - builder.create(replaceOp.getLoc(), - mapRewriteValue(replaceOp.getOpValue()), - replOperands); + pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(), + mapRewriteValue(replaceOp.getOpValue()), + replOperands); } void PatternLowering::generateRewriter( pdl::ResultOp resultOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { - rewriteValues[resultOp] = builder.create( - resultOp.getLoc(), builder.getType(), + rewriteValues[resultOp] = pdl_interp::GetResultOp::create( + builder, resultOp.getLoc(), builder.getType(), mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); } void PatternLowering::generateRewriter( pdl::ResultsOp resultOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { - rewriteValues[resultOp] = builder.create( - resultOp.getLoc(), resultOp.getType(), + rewriteValues[resultOp] = pdl_interp::GetResultsOp::create( + builder, resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); } @@ -863,7 +872,7 @@ void PatternLowering::generateRewriter( // type. if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) { rewriteValues[typeOp] = - builder.create(typeOp.getLoc(), typeAttr); + pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr); } } @@ -873,8 +882,8 @@ void PatternLowering::generateRewriter( // If the type isn't constant, the users (e.g. OperationOp) will resolve this // type. if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) { - rewriteValues[typeOp] = builder.create( - typeOp.getLoc(), typeOp.getType(), typeAttr); + rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create( + builder, typeOp.getLoc(), typeOp.getType(), typeAttr); } } @@ -939,10 +948,10 @@ void PatternLowering::generateOperationResultTypeRewriter( !replacedOp->isBeforeInBlock(op)) continue; - Value replacedOpResults = builder.create( - replacedOp->getLoc(), mapRewriteValue(replOpVal)); - types.push_back(builder.create( - replacedOp->getLoc(), replacedOpResults)); + Value replacedOpResults = pdl_interp::GetResultsOp::create( + builder, replacedOp->getLoc(), mapRewriteValue(replOpVal)); + types.push_back(pdl_interp::GetValueTypeOp::create( + builder, replacedOp->getLoc(), replacedOpResults)); return; } @@ -985,16 +994,18 @@ void PDLToPDLInterpPass::runOnOperation() { // Create the main matcher function This function contains all of the match // related functionality from patterns in the module. OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); - auto matcherFunc = builder.create( - module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), + auto matcherFunc = pdl_interp::FuncOp::create( + builder, module.getLoc(), + pdl_interp::PDLInterpDialect::getMatcherFunctionName(), builder.getFunctionType(builder.getType(), /*results=*/{}), /*attrs=*/ArrayRef()); // Create a nested module to hold the functions invoked for rewriting the IR // after a successful match. - ModuleOp rewriterModule = builder.create( - module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); + ModuleOp rewriterModule = + ModuleOp::create(builder, module.getLoc(), + pdl_interp::PDLInterpDialect::getRewriterModuleName()); // Generate the code for the patterns within the module. PatternLowering generator(matcherFunc, rewriterModule, configMap);