diff --git a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp index 70b22386f1eea..14fbb9bf09545 100644 --- a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp +++ b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp @@ -23,41 +23,43 @@ void mlir::complex::convertDivToLLVMUsingAlgebraic( ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm) { - Value rhsSqNorm = rewriter.create( - loc, rewriter.create(loc, rhsRe, rhsRe, fmf), - rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); + Value rhsSqNorm = LLVM::FAddOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf); - Value realNumerator = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRe, fmf), - rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); + Value realNumerator = LLVM::FAddOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf); - Value imagNumerator = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRe, fmf), - rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + Value imagNumerator = LLVM::FSubOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); - *resultRe = rewriter.create(loc, realNumerator, rhsSqNorm, fmf); - *resultIm = rewriter.create(loc, imagNumerator, rhsSqNorm, fmf); + *resultRe = + LLVM::FDivOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf); + *resultIm = + LLVM::FDivOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf); } void mlir::complex::convertDivToStandardUsingAlgebraic( ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm) { - Value rhsSqNorm = rewriter.create( - loc, rewriter.create(loc, rhsRe, rhsRe, fmf), - rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); + Value rhsSqNorm = arith::AddFOp::create( + rewriter, loc, arith::MulFOp::create(rewriter, loc, rhsRe, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf); - Value realNumerator = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRe, fmf), - rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); - Value imagNumerator = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRe, fmf), - rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + Value realNumerator = arith::AddFOp::create( + rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsRe, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf); + Value imagNumerator = arith::SubFOp::create( + rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); *resultRe = - rewriter.create(loc, realNumerator, rhsSqNorm, fmf); + arith::DivFOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf); *resultIm = - rewriter.create(loc, imagNumerator, rhsSqNorm, fmf); + arith::DivFOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf); } // Smith's algorithm to divide complex numbers. It is just a bit smarter @@ -94,181 +96,185 @@ void mlir::complex::convertDivToLLVMUsingRangeReduction( auto elementType = cast(rhsRe.getType()); Value rhsRealImagRatio = - rewriter.create(loc, rhsRe, rhsIm, fmf); - Value rhsRealImagDenom = rewriter.create( - loc, rhsIm, - rewriter.create(loc, rhsRealImagRatio, rhsRe, fmf), fmf); - Value realNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRealImagRatio, fmf), - lhsIm, fmf); - Value resultReal1 = - rewriter.create(loc, realNumerator1, rhsRealImagDenom, fmf); - Value imagNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRealImagRatio, fmf), - lhsRe, fmf); - Value resultImag1 = - rewriter.create(loc, imagNumerator1, rhsRealImagDenom, fmf); + LLVM::FDivOp::create(rewriter, loc, rhsRe, rhsIm, fmf); + Value rhsRealImagDenom = LLVM::FAddOp::create( + rewriter, loc, rhsIm, + LLVM::FMulOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf); + Value realNumerator1 = LLVM::FAddOp::create( + rewriter, loc, + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm, + fmf); + Value resultReal1 = LLVM::FDivOp::create(rewriter, loc, realNumerator1, + rhsRealImagDenom, fmf); + Value imagNumerator1 = LLVM::FSubOp::create( + rewriter, loc, + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe, + fmf); + Value resultImag1 = LLVM::FDivOp::create(rewriter, loc, imagNumerator1, + rhsRealImagDenom, fmf); Value rhsImagRealRatio = - rewriter.create(loc, rhsIm, rhsRe, fmf); - Value rhsImagRealDenom = rewriter.create( - loc, rhsRe, - rewriter.create(loc, rhsImagRealRatio, rhsIm, fmf), fmf); - Value realNumerator2 = rewriter.create( - loc, lhsRe, - rewriter.create(loc, lhsIm, rhsImagRealRatio, fmf), fmf); - Value resultReal2 = - rewriter.create(loc, realNumerator2, rhsImagRealDenom, fmf); - Value imagNumerator2 = rewriter.create( - loc, lhsIm, - rewriter.create(loc, lhsRe, rhsImagRealRatio, fmf), fmf); - Value resultImag2 = - rewriter.create(loc, imagNumerator2, rhsImagRealDenom, fmf); + LLVM::FDivOp::create(rewriter, loc, rhsIm, rhsRe, fmf); + Value rhsImagRealDenom = LLVM::FAddOp::create( + rewriter, loc, rhsRe, + LLVM::FMulOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf); + Value realNumerator2 = LLVM::FAddOp::create( + rewriter, loc, lhsRe, + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf); + Value resultReal2 = LLVM::FDivOp::create(rewriter, loc, realNumerator2, + rhsImagRealDenom, fmf); + Value imagNumerator2 = LLVM::FSubOp::create( + rewriter, loc, lhsIm, + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf); + Value resultImag2 = LLVM::FDivOp::create(rewriter, loc, imagNumerator2, + rhsImagRealDenom, fmf); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. - Value zero = rewriter.create( - loc, elementType, rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create(loc, rhsRe, fmf); - Value rhsRealIsZero = rewriter.create( - loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create(loc, rhsIm, fmf); - Value rhsImagIsZero = rewriter.create( - loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero); - Value lhsRealIsNotNaN = - rewriter.create(loc, LLVM::FCmpPredicate::ord, lhsRe, zero); - Value lhsImagIsNotNaN = - rewriter.create(loc, LLVM::FCmpPredicate::ord, lhsIm, zero); + Value zero = LLVM::ConstantOp::create(rewriter, loc, elementType, + rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, rhsRe, fmf); + Value rhsRealIsZero = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero); + Value rhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, rhsIm, fmf); + Value rhsImagIsZero = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero); + Value lhsRealIsNotNaN = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::ord, lhsRe, zero); + Value lhsImagIsNotNaN = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::ord, lhsIm, zero); Value lhsContainsNotNaNValue = - rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create( - loc, lhsContainsNotNaNValue, - rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create( - loc, elementType, + LLVM::OrOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = LLVM::AndOp::create( + rewriter, loc, lhsContainsNotNaNValue, + LLVM::AndOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = LLVM::ConstantOp::create( + rewriter, loc, elementType, rewriter.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfrhsReal = - rewriter.create(loc, inf, rhsRe); + LLVM::CopySignOp::create(rewriter, loc, inf, rhsRe); Value infinityResultReal = - rewriter.create(loc, infWithSignOfrhsReal, lhsRe, fmf); + LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsRe, fmf); Value infinityResultImag = - rewriter.create(loc, infWithSignOfrhsReal, lhsIm, fmf); + LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsIm, fmf); // Case 2. Infinite numerator, finite denominator. - Value rhsRealFinite = rewriter.create( - loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf); - Value rhsImagFinite = rewriter.create( - loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf); + Value rhsRealFinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf); + Value rhsImagFinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf); Value rhsFinite = - rewriter.create(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create(loc, lhsRe, fmf); - Value lhsRealInfinite = rewriter.create( - loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create(loc, lhsIm, fmf); - Value lhsImagInfinite = rewriter.create( - loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf); + LLVM::AndOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, lhsRe, fmf); + Value lhsRealInfinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf); + Value lhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, lhsIm, fmf); + Value lhsImagInfinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + LLVM::OrOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create( - loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsRealInfinite, one, zero), - lhsRe); - Value lhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsImagInfinite, one, zero), - lhsIm); + LLVM::AndOp::create(rewriter, loc, lhsInfinite, rhsFinite); + Value one = LLVM::ConstantOp::create(rewriter, loc, elementType, + rewriter.getFloatAttr(elementType, 1)); + Value lhsRealIsInfWithSign = LLVM::CopySignOp::create( + rewriter, loc, + LLVM::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe); + Value lhsImagIsInfWithSign = LLVM::CopySignOp::create( + rewriter, loc, + LLVM::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm); Value lhsRealIsInfWithSignTimesrhsReal = - rewriter.create(loc, lhsRealIsInfWithSign, rhsRe, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf); Value lhsImagIsInfWithSignTimesrhsImag = - rewriter.create(loc, lhsImagIsInfWithSign, rhsIm, fmf); - Value resultReal3 = rewriter.create( - loc, inf, - rewriter.create(loc, lhsRealIsInfWithSignTimesrhsReal, - lhsImagIsInfWithSignTimesrhsImag, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf); + Value resultReal3 = LLVM::FMulOp::create( + rewriter, loc, inf, + LLVM::FAddOp::create(rewriter, loc, lhsRealIsInfWithSignTimesrhsReal, + lhsImagIsInfWithSignTimesrhsImag, fmf), fmf); Value lhsRealIsInfWithSignTimesrhsImag = - rewriter.create(loc, lhsRealIsInfWithSign, rhsIm, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf); Value lhsImagIsInfWithSignTimesrhsReal = - rewriter.create(loc, lhsImagIsInfWithSign, rhsRe, fmf); - Value resultImag3 = rewriter.create( - loc, inf, - rewriter.create(loc, lhsImagIsInfWithSignTimesrhsReal, - lhsRealIsInfWithSignTimesrhsImag, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf); + Value resultImag3 = LLVM::FMulOp::create( + rewriter, loc, inf, + LLVM::FSubOp::create(rewriter, loc, lhsImagIsInfWithSignTimesrhsReal, + lhsRealIsInfWithSignTimesrhsImag, fmf), fmf); // Case 3: Finite numerator, infinite denominator. - Value lhsRealFinite = rewriter.create( - loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf); - Value lhsImagFinite = rewriter.create( - loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf); + Value lhsRealFinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf); + Value lhsImagFinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf); Value lhsFinite = - rewriter.create(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = rewriter.create( - loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf); - Value rhsImagInfinite = rewriter.create( - loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf); + LLVM::AndOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf); + Value rhsImagInfinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + LLVM::OrOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = - rewriter.create(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsRealInfinite, one, zero), - rhsRe); - Value rhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsImagInfinite, one, zero), - rhsIm); + LLVM::AndOp::create(rewriter, loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = LLVM::CopySignOp::create( + rewriter, loc, + LLVM::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe); + Value rhsImagIsInfWithSign = LLVM::CopySignOp::create( + rewriter, loc, + LLVM::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm); Value rhsRealIsInfWithSignTimeslhsReal = - rewriter.create(loc, lhsRe, rhsRealIsInfWithSign, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimeslhsImag = - rewriter.create(loc, lhsIm, rhsImagIsInfWithSign, fmf); - Value resultReal4 = rewriter.create( - loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimeslhsReal, - rhsImagIsInfWithSignTimeslhsImag, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf); + Value resultReal4 = LLVM::FMulOp::create( + rewriter, loc, zero, + LLVM::FAddOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsReal, + rhsImagIsInfWithSignTimeslhsImag, fmf), fmf); Value rhsRealIsInfWithSignTimeslhsImag = - rewriter.create(loc, lhsIm, rhsRealIsInfWithSign, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimeslhsReal = - rewriter.create(loc, lhsRe, rhsImagIsInfWithSign, fmf); - Value resultImag4 = rewriter.create( - loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimeslhsImag, - rhsImagIsInfWithSignTimeslhsReal, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf); + Value resultImag4 = LLVM::FMulOp::create( + rewriter, loc, zero, + LLVM::FSubOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsImag, + rhsImagIsInfWithSignTimeslhsReal, fmf), fmf); - Value realAbsSmallerThanImagAbs = rewriter.create( - loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs); - Value resultReal5 = rewriter.create( - loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); - Value resultImag5 = rewriter.create( - loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); - Value resultRealSpecialCase3 = rewriter.create( - loc, finiteNumInfiniteDenom, resultReal4, resultReal5); - Value resultImagSpecialCase3 = rewriter.create( - loc, finiteNumInfiniteDenom, resultImag4, resultImag5); - Value resultRealSpecialCase2 = rewriter.create( - loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); - Value resultImagSpecialCase2 = rewriter.create( - loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); - Value resultRealSpecialCase1 = rewriter.create( - loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); - Value resultImagSpecialCase1 = rewriter.create( - loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); + Value realAbsSmallerThanImagAbs = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs); + Value resultReal5 = LLVM::SelectOp::create( + rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); + Value resultImag5 = LLVM::SelectOp::create( + rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); + Value resultRealSpecialCase3 = LLVM::SelectOp::create( + rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5); + Value resultImagSpecialCase3 = LLVM::SelectOp::create( + rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5); + Value resultRealSpecialCase2 = LLVM::SelectOp::create( + rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); + Value resultImagSpecialCase2 = LLVM::SelectOp::create( + rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); + Value resultRealSpecialCase1 = + LLVM::SelectOp::create(rewriter, loc, resultIsInfinity, + infinityResultReal, resultRealSpecialCase2); + Value resultImagSpecialCase1 = + LLVM::SelectOp::create(rewriter, loc, resultIsInfinity, + infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = rewriter.create( - loc, LLVM::FCmpPredicate::uno, resultReal5, zero); - Value resultImagIsNaN = rewriter.create( - loc, LLVM::FCmpPredicate::uno, resultImag5, zero); + Value resultRealIsNaN = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::uno, resultReal5, zero); + Value resultImagIsNaN = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::uno, resultImag5, zero); Value resultIsNaN = - rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); + LLVM::AndOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN); - *resultRe = rewriter.create( - loc, resultIsNaN, resultRealSpecialCase1, resultReal5); - *resultIm = rewriter.create( - loc, resultIsNaN, resultImagSpecialCase1, resultImag5); + *resultRe = LLVM::SelectOp::create(rewriter, loc, resultIsNaN, + resultRealSpecialCase1, resultReal5); + *resultIm = LLVM::SelectOp::create(rewriter, loc, resultIsNaN, + resultImagSpecialCase1, resultImag5); } void mlir::complex::convertDivToStandardUsingRangeReduction( @@ -278,179 +284,187 @@ void mlir::complex::convertDivToStandardUsingRangeReduction( auto elementType = cast(rhsRe.getType()); Value rhsRealImagRatio = - rewriter.create(loc, rhsRe, rhsIm, fmf); - Value rhsRealImagDenom = rewriter.create( - loc, rhsIm, - rewriter.create(loc, rhsRealImagRatio, rhsRe, fmf), fmf); - Value realNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRealImagRatio, fmf), - lhsIm, fmf); - Value resultReal1 = rewriter.create(loc, realNumerator1, - rhsRealImagDenom, fmf); - Value imagNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRealImagRatio, fmf), - lhsRe, fmf); - Value resultImag1 = rewriter.create(loc, imagNumerator1, - rhsRealImagDenom, fmf); + arith::DivFOp::create(rewriter, loc, rhsRe, rhsIm, fmf); + Value rhsRealImagDenom = arith::AddFOp::create( + rewriter, loc, rhsIm, + arith::MulFOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf); + Value realNumerator1 = arith::AddFOp::create( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm, + fmf); + Value resultReal1 = arith::DivFOp::create(rewriter, loc, realNumerator1, + rhsRealImagDenom, fmf); + Value imagNumerator1 = arith::SubFOp::create( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe, + fmf); + Value resultImag1 = arith::DivFOp::create(rewriter, loc, imagNumerator1, + rhsRealImagDenom, fmf); Value rhsImagRealRatio = - rewriter.create(loc, rhsIm, rhsRe, fmf); - Value rhsImagRealDenom = rewriter.create( - loc, rhsRe, - rewriter.create(loc, rhsImagRealRatio, rhsIm, fmf), fmf); - Value realNumerator2 = rewriter.create( - loc, lhsRe, - rewriter.create(loc, lhsIm, rhsImagRealRatio, fmf), fmf); - Value resultReal2 = rewriter.create(loc, realNumerator2, - rhsImagRealDenom, fmf); - Value imagNumerator2 = rewriter.create( - loc, lhsIm, - rewriter.create(loc, lhsRe, rhsImagRealRatio, fmf), fmf); - Value resultImag2 = rewriter.create(loc, imagNumerator2, - rhsImagRealDenom, fmf); + arith::DivFOp::create(rewriter, loc, rhsIm, rhsRe, fmf); + Value rhsImagRealDenom = arith::AddFOp::create( + rewriter, loc, rhsRe, + arith::MulFOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf); + Value realNumerator2 = arith::AddFOp::create( + rewriter, loc, lhsRe, + arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf); + Value resultReal2 = arith::DivFOp::create(rewriter, loc, realNumerator2, + rhsImagRealDenom, fmf); + Value imagNumerator2 = arith::SubFOp::create( + rewriter, loc, lhsIm, + arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf); + Value resultImag2 = arith::DivFOp::create(rewriter, loc, imagNumerator2, + rhsImagRealDenom, fmf); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. - Value zero = rewriter.create( - loc, elementType, rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create(loc, rhsRe, fmf); - Value rhsRealIsZero = rewriter.create( - loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create(loc, rhsIm, fmf); - Value rhsImagIsZero = rewriter.create( - loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); - Value lhsRealIsNotNaN = rewriter.create( - loc, arith::CmpFPredicate::ORD, lhsRe, zero); - Value lhsImagIsNotNaN = rewriter.create( - loc, arith::CmpFPredicate::ORD, lhsIm, zero); + Value zero = arith::ConstantOp::create(rewriter, loc, elementType, + rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = math::AbsFOp::create(rewriter, loc, rhsRe, fmf); + Value rhsRealIsZero = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); + Value rhsImagAbs = math::AbsFOp::create(rewriter, loc, rhsIm, fmf); + Value rhsImagIsZero = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); + Value lhsRealIsNotNaN = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ORD, lhsRe, zero); + Value lhsImagIsNotNaN = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ORD, lhsIm, zero); Value lhsContainsNotNaNValue = - rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create( - loc, lhsContainsNotNaNValue, - rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create( - loc, elementType, + arith::OrIOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = arith::AndIOp::create( + rewriter, loc, lhsContainsNotNaNValue, + arith::AndIOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = arith::ConstantOp::create( + rewriter, loc, elementType, rewriter.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfRhsReal = - rewriter.create(loc, inf, rhsRe); + math::CopySignOp::create(rewriter, loc, inf, rhsRe); Value infinityResultReal = - rewriter.create(loc, infWithSignOfRhsReal, lhsRe, fmf); + arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsRe, fmf); Value infinityResultImag = - rewriter.create(loc, infWithSignOfRhsReal, lhsIm, fmf); + arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsIm, fmf); // Case 2. Infinite numerator, finite denominator. - Value rhsRealFinite = rewriter.create( - loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); - Value rhsImagFinite = rewriter.create( - loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); + Value rhsRealFinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); + Value rhsImagFinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); Value rhsFinite = - rewriter.create(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create(loc, lhsRe, fmf); - Value lhsRealInfinite = rewriter.create( - loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create(loc, lhsIm, fmf); - Value lhsImagInfinite = rewriter.create( - loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); + arith::AndIOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = math::AbsFOp::create(rewriter, loc, lhsRe, fmf); + Value lhsRealInfinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagAbs = math::AbsFOp::create(rewriter, loc, lhsIm, fmf); + Value lhsImagInfinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + arith::OrIOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create( - loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsRealInfinite, one, zero), + arith::AndIOp::create(rewriter, loc, lhsInfinite, rhsFinite); + Value one = arith::ConstantOp::create(rewriter, loc, elementType, + rewriter.getFloatAttr(elementType, 1)); + Value lhsRealIsInfWithSign = math::CopySignOp::create( + rewriter, loc, + arith::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe); - Value lhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsImagInfinite, one, zero), + Value lhsImagIsInfWithSign = math::CopySignOp::create( + rewriter, loc, + arith::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm); Value lhsRealIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsRealIsInfWithSign, rhsRe, fmf); + arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf); Value lhsImagIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsImagIsInfWithSign, rhsIm, fmf); - Value resultReal3 = rewriter.create( - loc, inf, - rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, - lhsImagIsInfWithSignTimesRhsImag, fmf), + arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf); + Value resultReal3 = arith::MulFOp::create( + rewriter, loc, inf, + arith::AddFOp::create(rewriter, loc, lhsRealIsInfWithSignTimesRhsReal, + lhsImagIsInfWithSignTimesRhsImag, fmf), fmf); Value lhsRealIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsRealIsInfWithSign, rhsIm, fmf); + arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf); Value lhsImagIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsImagIsInfWithSign, rhsRe, fmf); - Value resultImag3 = rewriter.create( - loc, inf, - rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, - lhsRealIsInfWithSignTimesRhsImag, fmf), + arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf); + Value resultImag3 = arith::MulFOp::create( + rewriter, loc, inf, + arith::SubFOp::create(rewriter, loc, lhsImagIsInfWithSignTimesRhsReal, + lhsRealIsInfWithSignTimesRhsImag, fmf), fmf); // Case 3: Finite numerator, infinite denominator. - Value lhsRealFinite = rewriter.create( - loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); - Value lhsImagFinite = rewriter.create( - loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); + Value lhsRealFinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); + Value lhsImagFinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); Value lhsFinite = - rewriter.create(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = rewriter.create( - loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagInfinite = rewriter.create( - loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); + arith::AndIOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagInfinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + arith::OrIOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = - rewriter.create(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsRealInfinite, one, zero), + arith::AndIOp::create(rewriter, loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = math::CopySignOp::create( + rewriter, loc, + arith::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe); - Value rhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsImagInfinite, one, zero), + Value rhsImagIsInfWithSign = math::CopySignOp::create( + rewriter, loc, + arith::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm); Value rhsRealIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsRe, rhsRealIsInfWithSign, fmf); + arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsIm, rhsImagIsInfWithSign, fmf); - Value resultReal4 = rewriter.create( - loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, - rhsImagIsInfWithSignTimesLhsImag, fmf), + arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf); + Value resultReal4 = arith::MulFOp::create( + rewriter, loc, zero, + arith::AddFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsReal, + rhsImagIsInfWithSignTimesLhsImag, fmf), fmf); Value rhsRealIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsIm, rhsRealIsInfWithSign, fmf); + arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsRe, rhsImagIsInfWithSign, fmf); - Value resultImag4 = rewriter.create( - loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, - rhsImagIsInfWithSignTimesLhsReal, fmf), + arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf); + Value resultImag4 = arith::MulFOp::create( + rewriter, loc, zero, + arith::SubFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsImag, + rhsImagIsInfWithSignTimesLhsReal, fmf), fmf); - Value realAbsSmallerThanImagAbs = rewriter.create( - loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); - Value resultReal5 = rewriter.create( - loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); - Value resultImag5 = rewriter.create( - loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); - Value resultRealSpecialCase3 = rewriter.create( - loc, finiteNumInfiniteDenom, resultReal4, resultReal5); - Value resultImagSpecialCase3 = rewriter.create( - loc, finiteNumInfiniteDenom, resultImag4, resultImag5); - Value resultRealSpecialCase2 = rewriter.create( - loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); - Value resultImagSpecialCase2 = rewriter.create( - loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); - Value resultRealSpecialCase1 = rewriter.create( - loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); - Value resultImagSpecialCase1 = rewriter.create( - loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); + Value realAbsSmallerThanImagAbs = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); + Value resultReal5 = arith::SelectOp::create( + rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); + Value resultImag5 = arith::SelectOp::create( + rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); + Value resultRealSpecialCase3 = arith::SelectOp::create( + rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5); + Value resultImagSpecialCase3 = arith::SelectOp::create( + rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5); + Value resultRealSpecialCase2 = arith::SelectOp::create( + rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); + Value resultImagSpecialCase2 = arith::SelectOp::create( + rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); + Value resultRealSpecialCase1 = + arith::SelectOp::create(rewriter, loc, resultIsInfinity, + infinityResultReal, resultRealSpecialCase2); + Value resultImagSpecialCase1 = + arith::SelectOp::create(rewriter, loc, resultIsInfinity, + infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = rewriter.create( - loc, arith::CmpFPredicate::UNO, resultReal5, zero); - Value resultImagIsNaN = rewriter.create( - loc, arith::CmpFPredicate::UNO, resultImag5, zero); + Value resultRealIsNaN = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UNO, resultReal5, zero); + Value resultImagIsNaN = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UNO, resultImag5, zero); Value resultIsNaN = - rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); + arith::AndIOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN); - *resultRe = rewriter.create( - loc, resultIsNaN, resultRealSpecialCase1, resultReal5); - *resultIm = rewriter.create( - loc, resultIsNaN, resultImagSpecialCase1, resultImag5); + *resultRe = arith::SelectOp::create(rewriter, loc, resultIsNaN, + resultRealSpecialCase1, resultReal5); + *resultIm = arith::SelectOp::create(rewriter, loc, resultIsNaN, + resultImagSpecialCase1, resultImag5); } diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index e5e862315941d..86d02e6c6209f 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -35,7 +35,7 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder, Location loc, Type type) { - Value val = builder.create(loc, type); + Value val = LLVM::PoisonOp::create(builder, loc, type); return ComplexStructBuilder(val); } @@ -79,9 +79,9 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern { LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); - Value sqNorm = rewriter.create( - loc, rewriter.create(loc, real, real, fmf), - rewriter.create(loc, imag, imag, fmf), fmf); + Value sqNorm = LLVM::FAddOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf), + LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf); rewriter.replaceOpWithNewOp(op, sqNorm); return success(); @@ -191,10 +191,10 @@ struct AddOpConversion : public ConvertOpToLLVMPattern { LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); - Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); - Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); + Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(), + arg.rhs.real(), fmf); + Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(), + arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -278,13 +278,13 @@ struct MulOpConversion : public ConvertOpToLLVMPattern { Value lhsRe = arg.lhs.real(); Value lhsIm = arg.lhs.imag(); - Value real = rewriter.create( - loc, rewriter.create(loc, rhsRe, lhsRe, fmf), - rewriter.create(loc, rhsIm, lhsIm, fmf), fmf); + Value real = LLVM::FSubOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf); - Value imag = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRe, fmf), - rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + Value imag = LLVM::FAddOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -313,10 +313,10 @@ struct SubOpConversion : public ConvertOpToLLVMPattern { LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); - Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); - Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); + Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(), + arg.rhs.real(), fmf); + Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(), + arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp index 56269d189873a..f83cac751ff05 100644 --- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp +++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp @@ -84,8 +84,8 @@ LogicalResult ScalarOpToLibmCall::matchAndRewrite( rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create(rewriter.getUnknownLoc(), name, - opFunctionTy); + opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name, + opFunctionTy); opFunc.setPrivate(); } assert(isa(SymbolTable::lookupSymbolIn(module, name))); diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 99d5424aef79a..6f0fc2965e6fd 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -44,8 +44,8 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern { rewriter.setInsertionPointToStart(&symTable->getRegion(0).front()); auto funcTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create(rewriter.getUnknownLoc(), funcName, - funcTy); + opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), + funcName, funcTy); opFunc.setPrivate(); } rewriter.replaceOpWithNewOp(op, funcName, op.getType(), diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0c832c452718b..eeff8a93e7a72 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -31,44 +31,45 @@ enum class AbsFn { abs, sqrt, rsqrt }; // Returns the absolute value, its square root or its reciprocal square root. Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { - Value one = b.create(real.getType(), - b.getFloatAttr(real.getType(), 1.0)); + Value one = arith::ConstantOp::create(b, real.getType(), + b.getFloatAttr(real.getType(), 1.0)); - Value absReal = b.create(real, fmf); - Value absImag = b.create(imag, fmf); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value max = b.create(absReal, absImag, fmf); - Value min = b.create(absReal, absImag, fmf); + Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf); + Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf); // The lowering below requires NaNs and infinities to work correctly. arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); - Value ratio = b.create(min, max, fmfWithNaNInf); - Value ratioSq = b.create(ratio, ratio, fmfWithNaNInf); - Value ratioSqPlusOne = b.create(ratioSq, one, fmfWithNaNInf); + Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf); + Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf); + Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf); Value result; if (fn == AbsFn::rsqrt) { - ratioSqPlusOne = b.create(ratioSqPlusOne, fmfWithNaNInf); - min = b.create(min, fmfWithNaNInf); - max = b.create(max, fmfWithNaNInf); + ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf); + min = math::RsqrtOp::create(b, min, fmfWithNaNInf); + max = math::RsqrtOp::create(b, max, fmfWithNaNInf); } if (fn == AbsFn::sqrt) { - Value quarter = b.create( - real.getType(), b.getFloatAttr(real.getType(), 0.25)); + Value quarter = arith::ConstantOp::create( + b, real.getType(), b.getFloatAttr(real.getType(), 0.25)); // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. - Value sqrt = b.create(max, fmfWithNaNInf); - Value p025 = b.create(ratioSqPlusOne, quarter, fmfWithNaNInf); - result = b.create(sqrt, p025, fmfWithNaNInf); + Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf); + Value p025 = + math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf); + result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf); } else { - Value sqrt = b.create(ratioSqPlusOne, fmfWithNaNInf); - result = b.create(max, sqrt, fmfWithNaNInf); + Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf); + result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf); } - Value isNaN = b.create(arith::CmpFPredicate::UNO, result, - result, fmfWithNaNInf); - return b.create(isNaN, min, result); + Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result, + result, fmfWithNaNInf); + return arith::SelectOp::create(b, isNaN, min, result); } struct AbsOpConversion : public OpConversionPattern { @@ -81,8 +82,8 @@ struct AbsOpConversion : public OpConversionPattern { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); - Value real = b.create(adaptor.getComplex()); - Value imag = b.create(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); return success(); @@ -105,28 +106,28 @@ struct Atan2OpConversion : public OpConversionPattern { Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); - Value rhsSquared = b.create(type, rhs, rhs, fmf); - Value lhsSquared = b.create(type, lhs, lhs, fmf); + Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf); + Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf); Value rhsSquaredPlusLhsSquared = - b.create(type, rhsSquared, lhsSquared, fmf); + complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf); Value sqrtOfRhsSquaredPlusLhsSquared = - b.create(type, rhsSquaredPlusLhsSquared, fmf); + complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf); Value zero = - b.create(elementType, b.getZeroAttr(elementType)); - Value one = b.create(elementType, - b.getFloatAttr(elementType, 1)); - Value i = b.create(type, zero, one); - Value iTimesLhs = b.create(i, lhs, fmf); - Value rhsPlusILhs = b.create(rhs, iTimesLhs, fmf); + arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType)); + Value one = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 1)); + Value i = complex::CreateOp::create(b, type, zero, one); + Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf); + Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf); - Value divResult = b.create( - rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); - Value logResult = b.create(divResult, fmf); + Value divResult = complex::DivOp::create( + b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); + Value logResult = complex::LogOp::create(b, divResult, fmf); - Value negativeOne = b.create( - elementType, b.getFloatAttr(elementType, -1)); - Value negativeI = b.create(type, zero, negativeOne); + Value negativeOne = arith::ConstantOp::create( + b, elementType, b.getFloatAttr(elementType, -1)); + Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne); rewriter.replaceOpWithNewOp(op, negativeI, logResult, fmf); return success(); @@ -146,14 +147,18 @@ struct ComparisonOpConversion : public OpConversionPattern { auto loc = op.getLoc(); auto type = cast(adaptor.getLhs().getType()).getElementType(); - Value realLhs = rewriter.create(loc, type, adaptor.getLhs()); - Value imagLhs = rewriter.create(loc, type, adaptor.getLhs()); - Value realRhs = rewriter.create(loc, type, adaptor.getRhs()); - Value imagRhs = rewriter.create(loc, type, adaptor.getRhs()); + Value realLhs = + complex::ReOp::create(rewriter, loc, type, adaptor.getLhs()); + Value imagLhs = + complex::ImOp::create(rewriter, loc, type, adaptor.getLhs()); + Value realRhs = + complex::ReOp::create(rewriter, loc, type, adaptor.getRhs()); + Value imagRhs = + complex::ImOp::create(rewriter, loc, type, adaptor.getRhs()); Value realComparison = - rewriter.create(loc, p, realLhs, realRhs); + arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs); Value imagComparison = - rewriter.create(loc, p, imagLhs, imagRhs); + arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, realComparison, imagComparison); @@ -176,14 +181,14 @@ struct BinaryComplexOpConversion : public OpConversionPattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value realLhs = b.create(elementType, adaptor.getLhs()); - Value realRhs = b.create(elementType, adaptor.getRhs()); - Value resultReal = b.create(elementType, realLhs, realRhs, - fmf.getValue()); - Value imagLhs = b.create(elementType, adaptor.getLhs()); - Value imagRhs = b.create(elementType, adaptor.getRhs()); - Value resultImag = b.create(elementType, imagLhs, imagRhs, - fmf.getValue()); + Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs()); + Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs()); + Value resultReal = BinaryStandardOp::create(b, elementType, realLhs, + realRhs, fmf.getValue()); + Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs()); + Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs()); + Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs, + imagRhs, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); @@ -205,20 +210,20 @@ struct TrigonometricOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); // Trigonometric ops use a set of common building blocks to convert to real // ops. Here we create these building blocks and call into an op-specific // implementation in the subclass to combine them. - Value half = rewriter.create( - loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); - Value exp = rewriter.create(loc, imag, fmf); - Value scaledExp = rewriter.create(loc, half, exp, fmf); - Value reciprocalExp = rewriter.create(loc, half, exp, fmf); - Value sin = rewriter.create(loc, real, fmf); - Value cos = rewriter.create(loc, real, fmf); + Value half = arith::ConstantOp::create( + rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); + Value exp = math::ExpOp::create(rewriter, loc, imag, fmf); + Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf); + Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf); + Value sin = math::SinOp::create(rewriter, loc, real, fmf); + Value cos = math::CosOp::create(rewriter, loc, real, fmf); auto resultPair = combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); @@ -251,11 +256,11 @@ struct CosOpConversion : public TrigonometricOpConversion { // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x Value sum = - rewriter.create(loc, reciprocalExp, scaledExp, fmf); - Value resultReal = rewriter.create(loc, sum, cos, fmf); + arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf); + Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf); Value diff = - rewriter.create(loc, reciprocalExp, scaledExp, fmf); - Value resultImag = rewriter.create(loc, diff, sin, fmf); + arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf); + Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf); return {resultReal, resultImag}; } }; @@ -275,13 +280,13 @@ struct DivOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value lhsReal = - rewriter.create(loc, elementType, adaptor.getLhs()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs()); Value lhsImag = - rewriter.create(loc, elementType, adaptor.getLhs()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs()); Value rhsReal = - rewriter.create(loc, elementType, adaptor.getRhs()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs()); Value rhsImag = - rewriter.create(loc, elementType, adaptor.getRhs()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs()); Value resultReal, resultImag; @@ -318,16 +323,16 @@ struct ExpOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create(loc, elementType, adaptor.getComplex()); - Value expReal = rewriter.create(loc, real, fmf.getValue()); - Value cosImag = rewriter.create(loc, imag, fmf.getValue()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue()); + Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue()); Value resultReal = - rewriter.create(loc, expReal, cosImag, fmf.getValue()); - Value sinImag = rewriter.create(loc, imag, fmf.getValue()); + arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue()); + Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue()); Value resultImag = - rewriter.create(loc, expReal, sinImag, fmf.getValue()); + arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -340,11 +345,11 @@ Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, arith::FastMathFlagsAttr fmf) { auto argType = mlir::cast(arg.getType()); Value poly = - b.create(b.getFloatAttr(argType, coefficients[0])); + arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0])); for (unsigned i = 1; i < coefficients.size(); ++i) { - poly = b.create( - poly, arg, - b.create(b.getFloatAttr(argType, coefficients[i])), + poly = math::FmaOp::create( + b, poly, arg, + arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])), fmf); } return poly; @@ -365,26 +370,26 @@ struct Expm1OpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(adaptor.getComplex()); - Value imag = b.create(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); - Value zero = b.create(b.getFloatAttr(elemType, 0.0)); - Value one = b.create(b.getFloatAttr(elemType, 1.0)); + Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0)); + Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0)); - Value expm1Real = b.create(real, fmf); - Value expReal = b.create(expm1Real, one, fmf); + Value expm1Real = math::ExpM1Op::create(b, real, fmf); + Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf); - Value sinImag = b.create(imag, fmf); + Value sinImag = math::SinOp::create(b, imag, fmf); Value cosm1Imag = emitCosm1(imag, fmf, b); - Value cosImag = b.create(cosm1Imag, one, fmf); + Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf); - Value realResult = b.create( - b.create(expm1Real, cosImag, fmf), cosm1Imag, fmf); + Value realResult = arith::AddFOp::create( + b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf); - Value imagIsZero = b.create(arith::CmpFPredicate::OEQ, imag, - zero, fmf.getValue()); - Value imagResult = b.create( - imagIsZero, zero, b.create(expReal, sinImag, fmf)); + Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, + zero, fmf.getValue()); + Value imagResult = arith::SelectOp::create( + b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf)); rewriter.replaceOpWithNewOp(op, type, realResult, imagResult); @@ -395,8 +400,8 @@ struct Expm1OpConversion : public OpConversionPattern { Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, ImplicitLocOpBuilder &b) const { auto argType = mlir::cast(arg.getType()); - auto negHalf = b.create(b.getFloatAttr(argType, -0.5)); - auto negOne = b.create(b.getFloatAttr(argType, -1.0)); + auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5)); + auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0)); // Algorithm copied from cephes cosm1. SmallVector kCoeffs{ @@ -405,23 +410,23 @@ struct Expm1OpConversion : public OpConversionPattern { 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, 4.1666666666666666609054E-2, }; - Value cos = b.create(arg, fmf); - Value forLargeArg = b.create(cos, negOne, fmf); + Value cos = math::CosOp::create(b, arg, fmf); + Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf); - Value argPow2 = b.create(arg, arg, fmf); - Value argPow4 = b.create(argPow2, argPow2, fmf); + Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf); + Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf); Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); auto forSmallArg = - b.create(b.create(argPow4, poly, fmf), - b.create(negHalf, argPow2, fmf)); + arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf), + arith::MulFOp::create(b, negHalf, argPow2, fmf)); // (pi/4)^2 is approximately 0.61685 Value piOver4Pow2 = - b.create(b.getFloatAttr(argType, 0.61685)); - Value cond = b.create(arith::CmpFPredicate::OGE, argPow2, - piOver4Pow2, fmf.getValue()); - return b.create(cond, forLargeArg, forSmallArg); + arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685)); + Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2, + piOver4Pow2, fmf.getValue()); + return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg); } }; @@ -436,13 +441,13 @@ struct LogOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value abs = b.create(elementType, adaptor.getComplex(), - fmf.getValue()); - Value resultReal = b.create(elementType, abs, fmf.getValue()); - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); + Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(), + fmf.getValue()); + Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value resultImag = - b.create(elementType, imag, real, fmf.getValue()); + math::Atan2Op::create(b, elementType, imag, real, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); @@ -460,40 +465,42 @@ struct Log1pOpConversion : public OpConversionPattern { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(adaptor.getComplex()); - Value imag = b.create(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); - Value half = b.create(elementType, - b.getFloatAttr(elementType, 0.5)); - Value one = b.create(elementType, - b.getFloatAttr(elementType, 1)); - Value realPlusOne = b.create(real, one, fmf); - Value absRealPlusOne = b.create(realPlusOne, fmf); - Value absImag = b.create(imag, fmf); + Value half = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.5)); + Value one = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 1)); + Value realPlusOne = arith::AddFOp::create(b, real, one, fmf); + Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value maxAbs = b.create(absRealPlusOne, absImag, fmf); - Value minAbs = b.create(absRealPlusOne, absImag, fmf); + Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf); + Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf); - Value useReal = b.create(arith::CmpFPredicate::OGT, - realPlusOne, absImag, fmf); - Value maxMinusOne = b.create(maxAbs, one, fmf); + Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, + realPlusOne, absImag, fmf); + Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf); Value maxAbsOfRealPlusOneAndImagMinusOne = - b.create(useReal, real, maxMinusOne); + arith::SelectOp::create(b, useReal, real, maxMinusOne); arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); - Value minMaxRatio = b.create(minAbs, maxAbs, fmfWithNaNInf); + Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf); Value logOfMaxAbsOfRealPlusOneAndImag = - b.create(maxAbsOfRealPlusOneAndImagMinusOne, fmf); - Value logOfSqrtPart = b.create( - b.create(minMaxRatio, minMaxRatio, fmfWithNaNInf), + math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf); + Value logOfSqrtPart = math::Log1pOp::create( + b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf), fmfWithNaNInf); - Value r = b.create( - b.create(half, logOfSqrtPart, fmfWithNaNInf), + Value r = arith::AddFOp::create( + b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf), logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf); - Value resultReal = b.create( - b.create(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), + Value resultReal = arith::SelectOp::create( + b, + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r, + fmfWithNaNInf), minAbs, r); - Value resultImag = b.create(imag, realPlusOne, fmf); + Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); @@ -511,22 +518,22 @@ struct MulOpConversion : public OpConversionPattern { auto elementType = cast(type.getElementType()); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); auto fmfValue = fmf.getValue(); - Value lhsReal = b.create(elementType, adaptor.getLhs()); - Value lhsImag = b.create(elementType, adaptor.getLhs()); - Value rhsReal = b.create(elementType, adaptor.getRhs()); - Value rhsImag = b.create(elementType, adaptor.getRhs()); + Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs()); + Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs()); + Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs()); + Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs()); Value lhsRealTimesRhsReal = - b.create(lhsReal, rhsReal, fmfValue); + arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue); Value lhsImagTimesRhsImag = - b.create(lhsImag, rhsImag, fmfValue); - Value real = b.create(lhsRealTimesRhsReal, - lhsImagTimesRhsImag, fmfValue); + arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue); + Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal, + lhsImagTimesRhsImag, fmfValue); Value lhsImagTimesRhsReal = - b.create(lhsImag, rhsReal, fmfValue); + arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue); Value lhsRealTimesRhsImag = - b.create(lhsReal, rhsImag, fmfValue); - Value imag = b.create(lhsImagTimesRhsReal, - lhsRealTimesRhsImag, fmfValue); + arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue); + Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal, + lhsRealTimesRhsImag, fmfValue); rewriter.replaceOpWithNewOp(op, type, real, imag); return success(); } @@ -543,11 +550,11 @@ struct NegOpConversion : public OpConversionPattern { auto elementType = cast(type.getElementType()); Value real = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create(loc, elementType, adaptor.getComplex()); - Value negReal = rewriter.create(loc, real); - Value negImag = rewriter.create(loc, imag); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value negReal = arith::NegFOp::create(rewriter, loc, real); + Value negImag = arith::NegFOp::create(rewriter, loc, imag); rewriter.replaceOpWithNewOp(op, type, negReal, negImag); return success(); } @@ -570,11 +577,11 @@ struct SinOpConversion : public TrigonometricOpConversion { // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x Value sum = - rewriter.create(loc, scaledExp, reciprocalExp, fmf); - Value resultReal = rewriter.create(loc, sum, sin, fmf); + arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf); + Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf); Value diff = - rewriter.create(loc, scaledExp, reciprocalExp, fmf); - Value resultImag = rewriter.create(loc, diff, cos, fmf); + arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf); + Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf); return {resultReal, resultImag}; } }; @@ -593,64 +600,65 @@ struct SqrtOpConversion : public OpConversionPattern { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { - return b.create(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); - Value half = b.create(elementType, - b.getFloatAttr(elementType, 0.5)); + Value half = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.5)); - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); - Value argArg = b.create(imag, real, fmf); - Value sqrtArg = b.create(argArg, half, fmf); - Value cos = b.create(sqrtArg, fmf); - Value sin = b.create(sqrtArg, fmf); + Value argArg = math::Atan2Op::create(b, imag, real, fmf); + Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf); + Value cos = math::CosOp::create(b, sqrtArg, fmf); + Value sin = math::SinOp::create(b, sqrtArg, fmf); // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply // 0 * inf. Value sinIsZero = - b.create(arith::CmpFPredicate::OEQ, sin, zero, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf); - Value resultReal = b.create(absSqrt, cos, fmf); - Value resultImag = b.create( - sinIsZero, zero, b.create(absSqrt, sin, fmf)); + Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf); + Value resultImag = arith::SelectOp::create( + b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf)); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { Value inf = cst(APFloat::getInf(floatSemantics)); Value negInf = cst(APFloat::getInf(floatSemantics, true)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value absImag = b.create(elementType, imag, fmf); + Value absImag = math::AbsFOp::create(b, elementType, imag, fmf); - Value absImagIsInf = - b.create(arith::CmpFPredicate::OEQ, absImag, inf, fmf); - Value absImagIsNotInf = - b.create(arith::CmpFPredicate::ONE, absImag, inf, fmf); + Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absImag, inf, fmf); + Value absImagIsNotInf = arith::CmpFOp::create( + b, arith::CmpFPredicate::ONE, absImag, inf, fmf); Value realIsInf = - b.create(arith::CmpFPredicate::OEQ, real, inf, fmf); - Value realIsNegInf = - b.create(arith::CmpFPredicate::OEQ, real, negInf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf); + Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + real, negInf, fmf); - resultReal = b.create( - b.create(realIsNegInf, absImagIsNotInf), zero, + resultReal = arith::SelectOp::create( + b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero, resultReal); - resultReal = b.create( - b.create(absImagIsInf, realIsInf), inf, resultReal); + resultReal = arith::SelectOp::create( + b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal); - Value imagSignInf = b.create(inf, imag, fmf); - resultImag = b.create( - b.create(arith::CmpFPredicate::UNO, absSqrt, absSqrt), + Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf); + resultImag = arith::SelectOp::create( + b, + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt), nan, resultImag); - resultImag = b.create( - b.create(absImagIsInf, realIsNegInf), imagSignInf, + resultImag = arith::SelectOp::create( + b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf, resultImag); } Value resultIsZero = - b.create(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); - resultReal = b.create(resultIsZero, zero, resultReal); - resultImag = b.create(resultIsZero, zero, resultImag); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); + resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal); + resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -669,19 +677,20 @@ struct SignOpConversion : public OpConversionPattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value zero = - b.create(elementType, b.getZeroAttr(elementType)); + arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType)); Value realIsZero = - b.create(arith::CmpFPredicate::OEQ, real, zero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero); Value imagIsZero = - b.create(arith::CmpFPredicate::OEQ, imag, zero); - Value isZero = b.create(realIsZero, imagIsZero); - auto abs = b.create(elementType, adaptor.getComplex(), fmf); - Value realSign = b.create(real, abs, fmf); - Value imagSign = b.create(imag, abs, fmf); - Value sign = b.create(type, realSign, imagSign); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero); + Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero); + auto abs = + complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf); + Value realSign = arith::DivFOp::create(b, real, abs, fmf); + Value imagSign = arith::DivFOp::create(b, imag, abs, fmf); + Value sign = complex::CreateOp::create(b, type, realSign, imagSign); rewriter.replaceOpWithNewOp(op, isZero, adaptor.getComplex(), sign); return success(); @@ -703,84 +712,84 @@ struct TanTanhOpConversion : public OpConversionPattern { const auto &floatSemantics = elementType.getFloatSemantics(); Value real = - b.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(b, loc, elementType, adaptor.getComplex()); Value imag = - b.create(loc, elementType, adaptor.getComplex()); - Value negOne = b.create( - elementType, b.getFloatAttr(elementType, -1.0)); + complex::ImOp::create(b, loc, elementType, adaptor.getComplex()); + Value negOne = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, -1.0)); if constexpr (std::is_same_v) { // tan(x+yi) = -i*tanh(-y + xi) std::swap(real, imag); - real = b.create(real, negOne, fmf); + real = arith::MulFOp::create(b, real, negOne, fmf); } auto cst = [&](APFloat v) { - return b.create(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; Value inf = cst(APFloat::getInf(floatSemantics)); - Value four = b.create(elementType, - b.getFloatAttr(elementType, 4.0)); - Value twoReal = b.create(real, real, fmf); - Value negTwoReal = b.create(negOne, twoReal, fmf); - - Value expTwoRealMinusOne = b.create(twoReal, fmf); - Value expNegTwoRealMinusOne = b.create(negTwoReal, fmf); - Value realNum = - b.create(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); - - Value cosImag = b.create(imag, fmf); - Value cosImagSq = b.create(cosImag, cosImag, fmf); - Value twoCosTwoImagPlusOne = b.create(cosImagSq, four, fmf); - Value sinImag = b.create(imag, fmf); - - Value imagNum = b.create( - four, b.create(cosImag, sinImag, fmf), fmf); - - Value expSumMinusTwo = - b.create(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); + Value four = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 4.0)); + Value twoReal = arith::AddFOp::create(b, real, real, fmf); + Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf); + + Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf); + Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf); + Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne, + expNegTwoRealMinusOne, fmf); + + Value cosImag = math::CosOp::create(b, imag, fmf); + Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf); + Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf); + Value sinImag = math::SinOp::create(b, imag, fmf); + + Value imagNum = arith::MulFOp::create( + b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf); + + Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne, + expNegTwoRealMinusOne, fmf); Value denom = - b.create(expSumMinusTwo, twoCosTwoImagPlusOne, fmf); + arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf); - Value isInf = b.create(arith::CmpFPredicate::OEQ, - expSumMinusTwo, inf, fmf); - Value realLimit = b.create(negOne, real, fmf); + Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + expSumMinusTwo, inf, fmf); + Value realLimit = math::CopySignOp::create(b, negOne, real, fmf); - Value resultReal = b.create( - isInf, realLimit, b.create(realNum, denom, fmf)); - Value resultImag = b.create(imagNum, denom, fmf); + Value resultReal = arith::SelectOp::create( + b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf)); + Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { - Value absReal = b.create(real, fmf); - Value zero = b.create( - elementType, b.getFloatAttr(elementType, 0.0)); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value zero = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.0)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value absRealIsInf = - b.create(arith::CmpFPredicate::OEQ, absReal, inf, fmf); + Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absReal, inf, fmf); Value imagIsZero = - b.create(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value absRealIsNotInf = b.create( - absRealIsInf, b.create(true, /*width=*/1)); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); + Value absRealIsNotInf = arith::XOrIOp::create( + b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1)); - Value imagNumIsNaN = b.create(arith::CmpFPredicate::UNO, - imagNum, imagNum, fmf); + Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, + imagNum, imagNum, fmf); Value resultRealIsNaN = - b.create(imagNumIsNaN, absRealIsNotInf); - Value resultImagIsZero = b.create( - imagIsZero, b.create(absRealIsInf, imagNumIsNaN)); + arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf); + Value resultImagIsZero = arith::OrIOp::create( + b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN)); - resultReal = b.create(resultRealIsNaN, nan, resultReal); + resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal); resultImag = - b.create(resultImagIsZero, zero, resultImag); + arith::SelectOp::create(b, resultImagIsZero, zero, resultImag); } if constexpr (std::is_same_v) { // tan(x+yi) = -i*tanh(-y + xi) std::swap(resultReal, resultImag); - resultImag = b.create(resultImag, negOne, fmf); + resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf); } rewriter.replaceOpWithNewOp(op, type, resultReal, @@ -799,10 +808,10 @@ struct ConjOpConversion : public OpConversionPattern { auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); Value real = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create(loc, elementType, adaptor.getComplex()); - Value negImag = rewriter.create(loc, elementType, imag); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag); rewriter.replaceOpWithNewOp(op, type, real, negImag); @@ -818,97 +827,102 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, arith::FastMathFlags fmf) { auto elementType = cast(type.getElementType()); - Value a = builder.create(lhs); - Value b = builder.create(lhs); + Value a = complex::ReOp::create(builder, lhs); + Value b = complex::ImOp::create(builder, lhs); - Value abs = builder.create(lhs, fmf); - Value absToC = builder.create(abs, c, fmf); + Value abs = complex::AbsOp::create(builder, lhs, fmf); + Value absToC = math::PowFOp::create(builder, abs, c, fmf); - Value negD = builder.create(d, fmf); - Value argLhs = builder.create(b, a, fmf); - Value negDArgLhs = builder.create(negD, argLhs, fmf); - Value expNegDArgLhs = builder.create(negDArgLhs, fmf); + Value negD = arith::NegFOp::create(builder, d, fmf); + Value argLhs = math::Atan2Op::create(builder, b, a, fmf); + Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf); + Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf); - Value coeff = builder.create(absToC, expNegDArgLhs, fmf); - Value lnAbs = builder.create(abs, fmf); - Value cArgLhs = builder.create(c, argLhs, fmf); - Value dLnAbs = builder.create(d, lnAbs, fmf); - Value q = builder.create(cArgLhs, dLnAbs, fmf); - Value cosQ = builder.create(q, fmf); - Value sinQ = builder.create(q, fmf); + Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf); + Value lnAbs = math::LogOp::create(builder, abs, fmf); + Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf); + Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf); + Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf); + Value cosQ = math::CosOp::create(builder, q, fmf); + Value sinQ = math::SinOp::create(builder, q, fmf); - Value inf = builder.create( - elementType, + Value inf = arith::ConstantOp::create( + builder, elementType, builder.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); - Value zero = builder.create( - elementType, builder.getFloatAttr(elementType, 0.0)); - Value one = builder.create( - elementType, builder.getFloatAttr(elementType, 1.0)); - Value complexOne = builder.create(type, one, zero); - Value complexZero = builder.create(type, zero, zero); - Value complexInf = builder.create(type, inf, zero); + Value zero = arith::ConstantOp::create( + builder, elementType, builder.getFloatAttr(elementType, 0.0)); + Value one = arith::ConstantOp::create(builder, elementType, + builder.getFloatAttr(elementType, 1.0)); + Value complexOne = complex::CreateOp::create(builder, type, one, zero); + Value complexZero = complex::CreateOp::create(builder, type, zero, zero); + Value complexInf = complex::CreateOp::create(builder, type, inf, zero); // Case 0: // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. Value absEqZero = - builder.create(arith::CmpFPredicate::OEQ, abs, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf); Value dEqZero = - builder.create(arith::CmpFPredicate::OEQ, d, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf); Value cEqZero = - builder.create(arith::CmpFPredicate::OEQ, c, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf); Value bEqZero = - builder.create(arith::CmpFPredicate::OEQ, b, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf); Value zeroLeC = - builder.create(arith::CmpFPredicate::OLE, zero, c, fmf); - Value coeffCosQ = builder.create(coeff, cosQ, fmf); - Value coeffSinQ = builder.create(coeff, sinQ, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf); + Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf); + Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf); Value complexOneOrZero = - builder.create(cEqZero, complexOne, complexZero); + arith::SelectOp::create(builder, cEqZero, complexOne, complexZero); Value coeffCosSin = - builder.create(type, coeffCosQ, coeffSinQ); - Value cutoff0 = builder.create( - builder.create( - builder.create(absEqZero, dEqZero), zeroLeC), + complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ); + Value cutoff0 = arith::SelectOp::create( + builder, + arith::AndIOp::create( + builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC), complexOneOrZero, coeffCosSin); // Case 1: // x^0 is defined to be 1 for any x, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. - Value rhsEqZero = builder.create(cEqZero, dEqZero); + Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero); Value cutoff1 = - builder.create(rhsEqZero, complexOne, cutoff0); + arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0); // Case 2: // 1^(c + d*i) = 1 + 0*i - Value lhsEqOne = builder.create( - builder.create(arith::CmpFPredicate::OEQ, a, one, fmf), + Value lhsEqOne = arith::AndIOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf), bEqZero); Value cutoff2 = - builder.create(lhsEqOne, complexOne, cutoff1); + arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1); // Case 3: // inf^(c + 0*i) = inf + 0*i, c > 0 - Value lhsEqInf = builder.create( - builder.create(arith::CmpFPredicate::OEQ, a, inf, fmf), + Value lhsEqInf = arith::AndIOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf), bEqZero); - Value rhsGt0 = builder.create( - dEqZero, - builder.create(arith::CmpFPredicate::OGT, c, zero, fmf)); - Value cutoff3 = builder.create( - builder.create(lhsEqInf, rhsGt0), complexInf, cutoff2); + Value rhsGt0 = arith::AndIOp::create( + builder, dEqZero, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf)); + Value cutoff3 = arith::SelectOp::create( + builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf, + cutoff2); // Case 4: // inf^(c + 0*i) = 0 + 0*i, c < 0 - Value rhsLt0 = builder.create( - dEqZero, - builder.create(arith::CmpFPredicate::OLT, c, zero, fmf)); - Value cutoff4 = builder.create( - builder.create(lhsEqInf, rhsLt0), complexZero, cutoff3); + Value rhsLt0 = arith::AndIOp::create( + builder, dEqZero, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf)); + Value cutoff4 = arith::SelectOp::create( + builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero, + cutoff3); return cutoff4; } @@ -923,8 +937,8 @@ struct PowOpConversion : public OpConversionPattern { auto type = cast(adaptor.getLhs().getType()); auto elementType = cast(type.getElementType()); - Value c = builder.create(elementType, adaptor.getRhs()); - Value d = builder.create(elementType, adaptor.getRhs()); + Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs()); + Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs()); rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(), c, d, op.getFastmath())}); @@ -945,64 +959,64 @@ struct RsqrtOpConversion : public OpConversionPattern { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { - return b.create(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); Value inf = cst(APFloat::getInf(floatSemantics)); - Value negHalf = b.create( - elementType, b.getFloatAttr(elementType, -0.5)); + Value negHalf = arith::ConstantOp::create( + b, elementType, b.getFloatAttr(elementType, -0.5)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); - Value argArg = b.create(imag, real, fmf); - Value rsqrtArg = b.create(argArg, negHalf, fmf); - Value cos = b.create(rsqrtArg, fmf); - Value sin = b.create(rsqrtArg, fmf); + Value argArg = math::Atan2Op::create(b, imag, real, fmf); + Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf); + Value cos = math::CosOp::create(b, rsqrtArg, fmf); + Value sin = math::SinOp::create(b, rsqrtArg, fmf); - Value resultReal = b.create(absRsqrt, cos, fmf); - Value resultImag = b.create(absRsqrt, sin, fmf); + Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf); + Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { - Value negOne = b.create( - elementType, b.getFloatAttr(elementType, -1)); + Value negOne = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, -1)); - Value realSignedZero = b.create(zero, real, fmf); - Value imagSignedZero = b.create(zero, imag, fmf); + Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf); + Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf); Value negImagSignedZero = - b.create(negOne, imagSignedZero, fmf); + arith::MulFOp::create(b, negOne, imagSignedZero, fmf); - Value absReal = b.create(real, fmf); - Value absImag = b.create(imag, fmf); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value absImagIsInf = - b.create(arith::CmpFPredicate::OEQ, absImag, inf, fmf); + Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absImag, inf, fmf); Value realIsNan = - b.create(arith::CmpFPredicate::UNO, real, real, fmf); - Value realIsInf = - b.create(arith::CmpFPredicate::OEQ, absReal, inf, fmf); - Value inIsNanInf = b.create(absImagIsInf, realIsNan); + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf); + Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absReal, inf, fmf); + Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan); - Value resultIsZero = b.create(inIsNanInf, realIsInf); + Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf); resultReal = - b.create(resultIsZero, realSignedZero, resultReal); - resultImag = b.create(resultIsZero, negImagSignedZero, - resultImag); + arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal); + resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero, + resultImag); } Value isRealZero = - b.create(arith::CmpFPredicate::OEQ, real, zero, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf); Value isImagZero = - b.create(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value isZero = b.create(isRealZero, isImagZero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); + Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero); - resultReal = b.create(isZero, inf, resultReal); - resultImag = b.create(isZero, nan, resultImag); + resultReal = arith::SelectOp::create(b, isZero, inf, resultReal); + resultImag = arith::SelectOp::create(b, isZero, nan, resultImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -1021,9 +1035,9 @@ struct AngleOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create(loc, type, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, type, adaptor.getComplex()); Value imag = - rewriter.create(loc, type, adaptor.getComplex()); + complex::ImOp::create(rewriter, loc, type, adaptor.getComplex()); rewriter.replaceOpWithNewOp(op, imag, real, fmf); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 13a084407e53f..ff6d369176393 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -73,13 +73,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); - abortFunc = rewriter.create(rewriter.getUnknownLoc(), - "abort", abortFuncTy); + abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(), + "abort", abortFuncTy); } - rewriter.create(loc, abortFunc, ValueRange()); - rewriter.create(loc); + LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange()); + LLVM::UnreachableOp::create(rewriter, loc); } else { - rewriter.create(loc, ValueRange(), continuationBlock); + LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock); } // Generate assertion test. diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index 9831dcaaaccc8..c8311eb5a6433 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -33,8 +33,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( MutableArrayRef regions) { if (auto condBrOp = dyn_cast(controlFlowCondOp)) { assert(regions.size() == 2); - auto ifOp = builder.create(controlFlowCondOp->getLoc(), - resultTypes, condBrOp.getCondition()); + auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(), + resultTypes, condBrOp.getCondition()); ifOp.getThenRegion().takeBody(regions[0]); ifOp.getElseRegion().takeBody(regions[1]); return ifOp.getOperation(); @@ -43,8 +43,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( if (auto switchOp = dyn_cast(controlFlowCondOp)) { // `getCFGSwitchValue` returns an i32 that we need to convert to index // fist. - auto cast = builder.create( - controlFlowCondOp->getLoc(), builder.getIndexType(), + auto cast = arith::IndexCastUIOp::create( + builder, controlFlowCondOp->getLoc(), builder.getIndexType(), switchOp.getFlag()); SmallVector cases; if (auto caseValues = switchOp.getCaseValues()) @@ -55,8 +55,9 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( assert(regions.size() == cases.size() + 1); - auto indexSwitchOp = builder.create( - controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size()); + auto indexSwitchOp = + scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(), + resultTypes, cast, cases, cases.size()); indexSwitchOp.getDefaultRegion().takeBody(regions[0]); for (auto &&[targetRegion, sourceRegion] : @@ -75,7 +76,7 @@ LogicalResult ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) { - builder.create(loc, results); + scf::YieldOp::create(builder, loc, results); return success(); } @@ -84,23 +85,24 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) { Location loc = replacedOp->getLoc(); - auto whileOp = builder.create(loc, loopVariablesInit.getTypes(), - loopVariablesInit); + auto whileOp = scf::WhileOp::create( + builder, loc, loopVariablesInit.getTypes(), loopVariablesInit); whileOp.getBefore().takeBody(loopBody); builder.setInsertionPointToEnd(&whileOp.getBefore().back()); // `getCFGSwitchValue` returns a i32. We therefore need to truncate the // condition to i1 first. It is guaranteed to be either 0 or 1 already. - builder.create( - loc, builder.create(loc, builder.getI1Type(), condition), + scf::ConditionOp::create( + builder, loc, + arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition), loopVariablesNextIter); Block *afterBlock = builder.createBlock(&whileOp.getAfter()); afterBlock->addArguments( loopVariablesInit.getTypes(), SmallVector(loopVariablesInit.size(), loc)); - builder.create(loc, afterBlock->getArguments()); + scf::YieldOp::create(builder, loc, afterBlock->getArguments()); return whileOp.getOperation(); } @@ -108,8 +110,8 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned int value) { - return builder.create(loc, - builder.getI32IntegerAttr(value)); + return arith::ConstantOp::create(builder, loc, + builder.getI32IntegerAttr(value)); } void ControlFlowToSCFTransformation::createCFGSwitchOp( @@ -117,15 +119,15 @@ void ControlFlowToSCFTransformation::createCFGSwitchOp( ArrayRef caseValues, BlockRange caseDestinations, ArrayRef caseArguments, Block *defaultDest, ValueRange defaultArgs) { - builder.create(loc, flag, defaultDest, defaultArgs, - llvm::to_vector_of(caseValues), - caseDestinations, caseArguments); + cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs, + llvm::to_vector_of(caseValues), + caseDestinations, caseArguments); } Value ControlFlowToSCFTransformation::getUndefValue(Location loc, OpBuilder &builder, Type type) { - return builder.create(loc, type, nullptr); + return ub::PoisonOp::create(builder, loc, type, nullptr); } FailureOr diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp index f8dc06f41ab87..197caeb4ffbfa 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -99,8 +99,8 @@ class FuncOpConversion final : public OpConversionPattern { } // Create the converted `emitc.func` op. - emitc::FuncOp newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), + emitc::FuncOp newFuncOp = emitc::FuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), FunctionType::get(rewriter.getContext(), signatureConverter.getConvertedTypes(), resultType ? TypeRange(resultType) : TypeRange())); diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 36235636d6ba2..67bb1c14c99a2 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -115,8 +115,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, SmallVector attributes; filterFuncAttributes(funcOp, attributes); - auto wrapperFuncOp = rewriter.create( - loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), + auto wrapperFuncOp = LLVM::LLVMFuncOp::create( + rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp); @@ -129,14 +129,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, for (auto [index, argType] : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(index + argOffset); if (auto memrefType = dyn_cast(argType)) { - Value loaded = rewriter.create( - loc, typeConverter.convertType(memrefType), arg); + Value loaded = LLVM::LoadOp::create( + rewriter, loc, typeConverter.convertType(memrefType), arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } if (isa(argType)) { - Value loaded = rewriter.create( - loc, typeConverter.convertType(argType), arg); + Value loaded = LLVM::LoadOp::create( + rewriter, loc, typeConverter.convertType(argType), arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); continue; } @@ -144,14 +144,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, args.push_back(arg); } - auto call = rewriter.create(loc, newFuncOp, args); + auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args); if (resultStructType) { - rewriter.create(loc, call.getResult(), - wrapperFuncOp.getArgument(0)); - rewriter.create(loc, ValueRange{}); + LLVM::StoreOp::create(rewriter, loc, call.getResult(), + wrapperFuncOp.getArgument(0)); + LLVM::ReturnOp::create(rewriter, loc, ValueRange{}); } else { - rewriter.create(loc, call.getResults()); + LLVM::ReturnOp::create(rewriter, loc, call.getResults()); } } @@ -182,8 +182,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, filterFuncAttributes(funcOp, attributes); // Create the auxiliary function. - auto wrapperFunc = builder.create( - loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), + auto wrapperFunc = LLVM::LLVMFuncOp::create( + builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc); @@ -201,11 +201,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, if (resultStructType) { // Allocate the struct on the stack and pass the pointer. Type resultType = cast(wrapperType).getParamType(0); - Value one = builder.create( - loc, typeConverter.convertType(builder.getIndexType()), + Value one = LLVM::ConstantOp::create( + builder, loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); Value result = - builder.create(loc, resultType, resultStructType, one); + LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one); args.push_back(result); } @@ -229,12 +229,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, wrapperArgsRange.take_front(numToDrop)); auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); - Value one = builder.create( - loc, typeConverter.convertType(builder.getIndexType()), + Value one = LLVM::ConstantOp::create( + builder, loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); - Value allocated = builder.create( - loc, ptrTy, packed.getType(), one, /*alignment=*/0); - builder.create(loc, packed, allocated); + Value allocated = LLVM::AllocaOp::create( + builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0); + LLVM::StoreOp::create(builder, loc, packed, allocated); arg = allocated; } else { arg = wrapperArgsRange[0]; @@ -245,14 +245,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, } assert(wrapperArgsRange.empty() && "did not map some of the arguments"); - auto call = builder.create(loc, wrapperFunc, args); + auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args); if (resultStructType) { Value result = - builder.create(loc, resultStructType, args.front()); - builder.create(loc, result); + LLVM::LoadOp::create(builder, loc, resultStructType, args.front()); + LLVM::ReturnOp::create(builder, loc, result); } else { - builder.create(loc, call.getResults()); + LLVM::ReturnOp::create(builder, loc, call.getResults()); } } @@ -283,7 +283,7 @@ static void restoreByValRefArgumentType( Type resTy = typeConverter.convertType( cast(byValRefAttr->getValue()).getValue()); - Value valueArg = rewriter.create(arg.getLoc(), resTy, arg); + Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); rewriter.replaceUsesOfBlockArgument(arg, valueArg); } } @@ -357,8 +357,8 @@ FailureOr mlir::convertFuncOpToLLVMFuncOp( symbolTable.remove(funcOp); } - auto newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), llvmType, linkage, + auto newFuncOp = LLVM::LLVMFuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); @@ -509,7 +509,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "failed to convert result type"); auto newOp = - rewriter.create(op.getLoc(), type, op.getValue()); + LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue()); for (const NamedAttribute &attr : op->getAttrs()) { if (attr.getName().strref() == "value") continue; @@ -556,9 +556,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { auto promoted = this->getTypeConverter()->promoteOperands( callOp.getLoc(), /*opOperands=*/callOp->getOperands(), adaptor.getOperands(), rewriter, useBarePtrCallConv); - auto newOp = rewriter.create( - callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), - promoted, callOp->getAttrs()); + auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(), + packedResult ? TypeRange(packedResult) + : TypeRange(), + promoted, callOp->getAttrs()); newOp.getProperties().operandSegmentSizes = { static_cast(promoted.size()), 0}; @@ -573,8 +574,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { // Extract individual results from the structure and return them as list. results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create( - callOp.getLoc(), newOp->getResult(0), i)); + results.push_back(LLVM::ExtractValueOp::create( + rewriter, callOp.getLoc(), newOp->getResult(0), i)); } } @@ -726,9 +727,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "could not convert result types"); } - Value packed = rewriter.create(loc, packedType); + Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { - packed = rewriter.create(loc, packed, operand, idx); + packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, op->getAttrs()); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 01ca5e99a9aff..1037e296c8128 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, if (!(ret = moduleOp.template lookupSymbol(name))) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); - ret = b.create(loc, name, type, LLVM::Linkage::External); + ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); } return ret; } @@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); - return b.create(loc, globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, - name, attr, alignment, addrSpace); + return LLVM::GlobalOp::create(b, loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, + name, attr, alignment, addrSpace); } LogicalResult @@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, gpuFuncOp.getWorkgroupAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - auto globalOp = rewriter.create( - gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, + auto globalOp = LLVM::GlobalOp::create( + rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, workgroupAddrSpace); workgroupBuffers.push_back(globalOp); @@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, LLVM::CConv callingConvention = gpuFuncOp.isKernel() ? kernelCallingConvention : nonKernelCallingConvention; - auto llvmFuncOp = rewriter.create( - gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, + auto llvmFuncOp = LLVM::LLVMFuncOp::create( + rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, /*comdat=*/nullptr, attributes); @@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()); - Value address = rewriter.create( - loc, ptrType, global.getSymNameAttr()); + Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType, + global.getSymNameAttr()); Value memory = - rewriter.create(loc, ptrType, global.getType(), - address, ArrayRef{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(), + address, ArrayRef{0, 0}); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than @@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, Type elementType = typeConverter->convertType(type.getElementType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); - Value numElements = rewriter.create( - gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); + Value numElements = LLVM::ConstantOp::create( + rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); uint64_t alignment = 0; if (auto alignAttr = dyn_cast_or_null(gpuFuncOp.getPrivateAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - Value allocated = rewriter.create( - gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); + Value allocated = + LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType, + elementType, numElements, alignment); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( @@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); /// Start the printf hostcall - Value zeroI64 = rewriter.create(loc, llvmI64, 0); - auto printfBeginCall = rewriter.create(loc, ocklBegin, zeroI64); + Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0); + auto printfBeginCall = + LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. @@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() - Value globalPtr = rewriter.create( - loc, + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef{0, 0}); - Value stringLen = rewriter.create( - loc, llvmI64, cast(global.getValueAttr()).size()); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef{0, 0}); + Value stringLen = LLVM::ConstantOp::create( + rewriter, loc, llvmI64, cast(global.getValueAttr()).size()); - Value oneI32 = rewriter.create(loc, llvmI32, 1); - Value zeroI32 = rewriter.create(loc, llvmI32, 0); + Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1); + Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0); - auto appendFormatCall = rewriter.create( - loc, ocklAppendStringN, + auto appendFormatCall = LLVM::CallOp::create( + rewriter, loc, ocklAppendStringN, ValueRange{printfDesc, stringStart, stringLen, adaptor.getArgs().empty() ? oneI32 : zeroI32}); printfDesc = appendFormatCall.getResult(); @@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( SmallVector arguments; arguments.push_back(printfDesc); arguments.push_back( - rewriter.create(loc, llvmI32, numArgsThisCall)); + LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall)); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.getArgs()[i]; if (auto floatType = dyn_cast(arg.getType())) { if (!floatType.isF64()) - arg = rewriter.create( - loc, typeConverter->convertType(rewriter.getF64Type()), arg); - arg = rewriter.create(loc, llvmI64, arg); + arg = LLVM::FPExtOp::create( + rewriter, loc, typeConverter->convertType(rewriter.getF64Type()), + arg); + arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg); } if (arg.getType().getIntOrFloatBitWidth() != 64) - arg = rewriter.create(loc, llvmI64, arg); + arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg); arguments.push_back(arg); } @@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( auto isLast = (bound == nArgs) ? oneI32 : zeroI32; arguments.push_back(isLast); - auto call = rewriter.create(loc, ocklAppendArgs, arguments); + auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments); printfDesc = call.getResult(); } rewriter.eraseOp(gpuPrintfOp); @@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create( - loc, + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef{0, 0}); // Construct arguments and function call auto argsRange = adaptor.getArgs(); @@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); - rewriter.create(loc, printfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create(loc, global); + Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global); Value stringStart = - rewriter.create(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef{0, 0}); SmallVector types; SmallVector args; // Promote and pack the arguments into a stack allocation. @@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( assert(type.isIntOrFloat()); if (isa(type)) { type = rewriter.getF64Type(); - promotedArg = rewriter.create(loc, type, arg); + promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg); } types.push_back(type); args.push_back(promotedArg); } Type structType = LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); - Value one = rewriter.create(loc, rewriter.getI64Type(), - rewriter.getIndexAttr(1)); + Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + rewriter.getIndexAttr(1)); Value tempAlloc = - rewriter.create(loc, ptrType, structType, one, - /*alignment=*/0); + LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one, + /*alignment=*/0); for (auto [index, arg] : llvm::enumerate(args)) { - Value ptr = rewriter.create( - loc, ptrType, structType, tempAlloc, + Value ptr = LLVM::GEPOp::create( + rewriter, loc, ptrType, structType, tempAlloc, ArrayRef{0, static_cast(index)}); - rewriter.create(loc, arg, ptr); + LLVM::StoreOp::create(rewriter, loc, arg, ptr); } std::array printfArgs = {stringStart, tempAlloc}; - rewriter.create(loc, vprintfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands, TypeRange operandTypes(operands); VectorType vectorType = cast(llvm1DVectorTy); Location loc = op->getLoc(); - Value result = rewriter.create(loc, vectorType); + Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType); Type indexType = converter.convertType(rewriter.getIndexType()); StringAttr name = op->getName().getIdentifier(); Type elementType = vectorType.getElementType(); for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { - Value index = rewriter.create(loc, indexType, i); + Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i); auto extractElement = [&](Value operand) -> Value { if (!isa(operand.getType())) return operand; - return rewriter.create(loc, operand, index); + return LLVM::ExtractElementOp::create(rewriter, loc, operand, index); }; auto scalarOperands = llvm::map_to_vector(operands, extractElement); Operation *scalarOp = rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); - result = rewriter.create( - loc, result, scalarOp->getResult(0), index); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + scalarOp->getResult(0), index); } return result; } @@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol( auto zeroSizedArrayType = LLVM::LLVMArrayType::get( typeConverter->convertType(memrefType.getElementType()), 0); - return rewriter.create( - op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, - LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, - addressSpace.value()); + return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType, + /*isConstant=*/false, LLVM::Linkage::Internal, + symName, /*value=*/Attribute(), alignmentByte, + addressSpace.value()); } LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( @@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( // Step 3. Get address of the global symbol OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); - auto basePtr = rewriter.create(loc, shmemOp); + auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp); Type baseType = basePtr->getResultTypes().front(); // Step 4. Generate GEP using offsets SmallVector gepArgs = {0}; - Value shmemPtr = rewriter.create(loc, baseType, elementType, - basePtr, gepArgs); + Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType, + basePtr, gepArgs); // Step 5. Create a memref descriptor SmallVector shape, strides; Value sizeBytes; @@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite( return rewriter.notifyMatchFailure(op, "could not convert result types"); } - Value packed = rewriter.create(loc, packedType); + Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { - packed = rewriter.create(loc, packed, operand, idx); + packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, op->getAttrs()); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 167cabbc57db9..63eb6c58e87a7 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -79,8 +79,8 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern { uint64_t rank = type.getRank(); Value numElements = desc.size(rewriter, loc, /*pos=*/0); for (unsigned i = 1; i < rank; i++) - numElements = rewriter.create( - loc, numElements, desc.size(rewriter, loc, /*pos=*/i)); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, + desc.size(rewriter, loc, /*pos=*/i)); return numElements; } @@ -582,7 +582,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, return OpBuilder::atBlockEnd(module.getBody()) .create(loc, functionName, functionType); }(); - return builder.create(loc, function, arguments); + return LLVM::CallOp::create(builder, loc, function, arguments); } // Corresponding to cusparseIndexType_t defined in cusparse.h. @@ -780,13 +780,13 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - auto nullPtr = rewriter.create(loc, llvmPointerType); + auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? nullPtr : adaptor.getAsyncDependencies().front(); - auto isHostShared = rewriter.create( - loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); + auto isHostShared = mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); Value allocatedPtr = allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared}) @@ -1012,8 +1012,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( uint64_t staticSize = static_cast(bitwidth / 8) * static_cast(memrefTy.getNumElements()); - Value sizeArg = rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(staticSize)); + Value sizeArg = LLVM::ConstantOp::create( + rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize)); llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. llvmArgumentsWithSizes.push_back(sizeArg); } @@ -1025,8 +1025,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), adaptor.getClusterSizeZ()}; } - rewriter.create( - launchOp.getLoc(), launchOp.getKernelAttr(), + gpu::LaunchFuncOp::create( + rewriter, launchOp.getLoc(), launchOp.getKernelAttr(), gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()}, gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), @@ -1048,8 +1048,8 @@ static Value bitAndAddrspaceCast(Location loc, const LLVMTypeConverter &typeConverter) { auto sourceTy = cast(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) - sourcePtr = rewriter.create( - loc, + sourcePtr = LLVM::AddrSpaceCastOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), destinationType.getAddressSpace()), sourcePtr); @@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); Type elementPtrType = getElementPtrType(memRefType); - Value nullPtr = rewriter.create(loc, elementPtrType); - Value gepPtr = rewriter.create( - loc, elementPtrType, + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType); + Value gepPtr = LLVM::GEPOp::create( + rewriter, loc, elementPtrType, typeConverter->convertType(memRefType.getElementType()), nullPtr, numElements); auto sizeBytes = - rewriter.create(loc, getIndexType(), gepPtr); + LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr); auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, srcDesc.alignedPtr(rewriter, loc), @@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); auto value = - rewriter.create(loc, bitCastType, adaptor.getValue()); + LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue()); auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, dstDesc.alignedPtr(rewriter, loc), *getTypeConverter()); @@ -1150,15 +1150,15 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( template static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) { Type llvmInt32Type = builder.getIntegerType(32); - return builder.create(loc, llvmInt32Type, - static_cast(tValue)); + return LLVM::ConstantOp::create(builder, loc, llvmInt32Type, + static_cast(tValue)); } template static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) { Type llvmFloat32Type = builder.getF32Type(); - return builder.create( - loc, llvmFloat32Type, + return LLVM::ConstantOp::create( + builder, loc, llvmFloat32Type, builder.getF32FloatAttr(static_cast(tValue))); } @@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( // the dnmat is used with spmat with 2:4 sparsity if (dims.size() == 2) { if (isSpMMCusparseLtOp(op.getDnTensor())) { - auto handleSz = rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(11032)); - handle = rewriter.create( - loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create(loc, llvmPointerType, handle); + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(11032)); + handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, + llvmInt8Type, handleSz, /*alignment=*/16); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); createLtDnMatCallBuilder .create(loc, rewriter, @@ -1351,11 +1351,11 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); // CUDA runner asserts the size is 44104 bytes. - auto handleSz = rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(44104)); - Value handle = rewriter.create( - loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create(loc, llvmPointerType, handle); + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(44104)); + Value handle = LLVM::AllocaOp::create( + rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); create2To4SpMatCallBuilder .create(loc, rewriter, @@ -1441,10 +1441,11 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA())); auto computeType = genConstInt32From( rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType())); - auto three = rewriter.create(loc, getIndexType(), - rewriter.getIndexAttr(3)); - auto bufferSize = rewriter.create( - loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16); + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(3)); + auto bufferSize = + LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType, + three, /*alignment=*/16); createCuSparseLtSpMMBufferSizeBuilder .create(loc, rewriter, {bufferSize, modeA, modeB, adaptor.getSpmatA(), @@ -1452,20 +1453,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( pruneFlag, stream}) .getResult(); - auto bufferSizePtr1 = rewriter.create( - loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(1))}); - auto bufferSizePtr2 = rewriter.create( - loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(2))}); + auto bufferSizePtr1 = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1))}); + auto bufferSizePtr2 = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(2))}); auto bufferSize0 = - rewriter.create(loc, llvmInt64Type, bufferSize); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize); auto bufferSize1 = - rewriter.create(loc, llvmInt64Type, bufferSizePtr1); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1); auto bufferSize2 = - rewriter.create(loc, llvmInt64Type, bufferSizePtr2); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2); rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream}); } else { @@ -1669,28 +1670,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite( Location loc = op.getLoc(); auto stream = adaptor.getAsyncDependencies().front(); - auto three = rewriter.create(loc, getIndexType(), - rewriter.getIndexAttr(3)); - auto buffer = rewriter.create( - loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16); - - auto rowsPtr = rewriter.create( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create(loc, getIndexType(), - rewriter.getIndexAttr(0))}); - auto colsPtr = rewriter.create( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create(loc, getIndexType(), - rewriter.getIndexAttr(1))}); - auto nnzsPtr = rewriter.create( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create(loc, getIndexType(), - rewriter.getIndexAttr(2))}); + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(3)); + auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, + llvmInt64Type, three, /*alignment=*/16); + + auto rowsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(0))}); + auto colsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1))}); + auto nnzsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(2))}); createSpMatGetSizeBuilder.create( loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream}); - auto rows = rewriter.create(loc, llvmInt64Type, rowsPtr); - auto cols = rewriter.create(loc, llvmInt64Type, colsPtr); - auto nnzs = rewriter.create(loc, llvmInt64Type, nnzsPtr); + auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr); + auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr); + auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr); rewriter.replaceOp(op, {rows, cols, nnzs, stream}); return success(); diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index aab2409ed6328..91c43e8bd1117 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -59,13 +59,13 @@ struct OpLowering : public ConvertOpToLLVMPattern { Operation *newOp; switch (op.getDimension()) { case gpu::Dimension::x: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::y: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::z: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32)); break; } @@ -124,11 +124,13 @@ struct OpLowering : public ConvertOpToLLVMPattern { rewriter.getContext(), 32, min, max)); } if (indexBitwidth > 32) { - newOp = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); + newOp = LLVM::SExtOp::create(rewriter, loc, + IntegerType::get(context, indexBitwidth), + newOp->getResult(0)); } else if (indexBitwidth < 32) { - newOp = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); + newOp = LLVM::TruncOp::create(rewriter, loc, + IntegerType::get(context, indexBitwidth), + newOp->getResult(0)); } rewriter.replaceOp(op, newOp->getResults()); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 64cf09e600b88..9f36e5c369d06 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -103,7 +103,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); auto callOp = - rewriter.create(op->getLoc(), funcOp, castedOperands); + LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands); if (resultType == adaptor.getOperands().front().getType()) { rewriter.replaceOp(op, {callOp.getResult()}); @@ -115,19 +115,20 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { // there is no guarantee of a specific value being used to indicate true, // compare for inequality with zero (rather than truncate or shift). if (isResultBool) { - Value zero = rewriter.create( - op->getLoc(), rewriter.getIntegerType(32), - rewriter.getI32IntegerAttr(0)); - Value truncated = rewriter.create( - op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero); + Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(), + rewriter.getIntegerType(32), + rewriter.getI32IntegerAttr(0)); + Value truncated = + LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne, + callOp.getResult(), zero); rewriter.replaceOp(op, {truncated}); return success(); } assert(callOp.getResult().getType().isF32() && "only f32 types are supposed to be truncated back"); - Value truncated = rewriter.create( - op->getLoc(), adaptor.getOperands().front().getType(), + Value truncated = LLVM::FPTruncOp::create( + rewriter, op->getLoc(), adaptor.getOperands().front().getType(), callOp.getResult()); rewriter.replaceOp(op, {truncated}); return success(); @@ -142,8 +143,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { if (!f16Func.empty() && isa(type)) return operand; - return rewriter.create( - operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); + return LLVM::FPExtOp::create(rewriter, operand.getLoc(), + Float32Type::get(rewriter.getContext()), + operand); } Type getFunctionType(Type resultType, ValueRange operands) const { @@ -169,7 +171,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { // location as debug info metadata inside of a function cannot be used // outside of that function. auto globalloc = op->getLoc()->findInstanceOfOrUnknown(); - return b.create(globalloc, funcName, funcType); + return LLVMFuncOp::create(b, globalloc, funcName, funcType); } StringRef getFunctionName(Type type, SourceOp op) const { diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 8b6b553f6eed0..c2363a1a40294 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -54,8 +54,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, SymbolTable::lookupSymbolIn(symbolTable, name)); if (!func) { OpBuilder b(symbolTable->getRegion(0)); - func = b.create( - symbolTable->getLoc(), name, + func = LLVM::LLVMFuncOp::create( + b, symbolTable->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes)); func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); func.setNoUnwind(true); @@ -79,7 +79,7 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args) { - auto call = rewriter.create(loc, func, args); + auto call = LLVM::CallOp::create(rewriter, loc, func, args); call.setCConv(func.getCConv()); call.setConvergentAttr(func.getConvergentAttr()); call.setNoUnwindAttr(func.getNoUnwindAttr()); @@ -121,7 +121,7 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern { constexpr int64_t localMemFenceFlag = 1; Location loc = op->getLoc(); Value flag = - rewriter.create(loc, flagTy, localMemFenceFlag); + LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag); rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag)); return success(); } @@ -162,8 +162,8 @@ struct LaunchConfigConversion : ConvertToLLVMPattern { Location loc = op->getLoc(); gpu::Dimension dim = getDimension(op); - Value dimVal = rewriter.create(loc, dimTy, - static_cast(dim)); + Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy, + static_cast(dim)); rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal)); return success(); } @@ -291,13 +291,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) { return TypeSwitch(oldVal.getType()) .Case([&](BFloat16Type) { - return rewriter.create(loc, rewriter.getI16Type(), - oldVal); + return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(), + oldVal); }) .Case([&](IntegerType intTy) -> Value { if (intTy.getWidth() == 1) - return rewriter.create(loc, rewriter.getI8Type(), - oldVal); + return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(), + oldVal); return oldVal; }) .Default(oldVal); @@ -308,11 +308,11 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) { return TypeSwitch(newTy) .Case([&](BFloat16Type) { - return rewriter.create(loc, newTy, oldVal); + return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal); }) .Case([&](IntegerType intTy) -> Value { if (intTy.getWidth() == 1) - return rewriter.create(loc, newTy, oldVal); + return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal); return oldVal; }) .Default(oldVal); @@ -349,7 +349,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter); Value trueVal = - rewriter.create(loc, rewriter.getI1Type(), true); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), true); rewriter.replaceOp(op, {resultOrConversion, trueVal}); return success(); } @@ -426,7 +426,7 @@ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern { if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) { return failure(); } - result = rewriter.create(loc, indexTy, result); + result = LLVM::ZExtOp::create(rewriter, loc, indexTy, result); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 1ef6edea93c58..317bfc2970cf5 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -118,10 +118,10 @@ struct GPUSubgroupReduceOpLowering Location loc = op->getLoc(); auto int32Type = IntegerType::get(rewriter.getContext(), 32); - Value offset = rewriter.create(loc, int32Type, -1); + Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); - auto reduxOp = rewriter.create(loc, int32Type, op.getValue(), - mode.value(), offset); + auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type, + op.getValue(), mode.value(), offset); rewriter.replaceOp(op, reduxOp->getResult(0)); return success(); @@ -158,22 +158,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { auto int32Type = IntegerType::get(rewriter.getContext(), 32); auto predTy = IntegerType::get(rewriter.getContext(), 1); - Value one = rewriter.create(loc, int32Type, 1); - Value minusOne = rewriter.create(loc, int32Type, -1); - Value thirtyTwo = rewriter.create(loc, int32Type, 32); - Value numLeadInactiveLane = rewriter.create( - loc, int32Type, thirtyTwo, adaptor.getWidth()); + Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1); + Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); + Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32); + Value numLeadInactiveLane = LLVM::SubOp::create( + rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth()); // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`. - Value activeMask = rewriter.create(loc, int32Type, minusOne, - numLeadInactiveLane); + Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne, + numLeadInactiveLane); Value maskAndClamp; if (op.getMode() == gpu::ShuffleMode::UP) { // Clamp lane: `32 - activeWidth` maskAndClamp = numLeadInactiveLane; } else { // Clamp lane: `activeWidth - 1` - maskAndClamp = - rewriter.create(loc, int32Type, adaptor.getWidth(), one); + maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type, + adaptor.getWidth(), one); } bool predIsUsed = !op->getResult(1).use_empty(); @@ -184,13 +184,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), {valueTy, predTy}); } - Value shfl = rewriter.create( - loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(), - maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr); + Value shfl = NVVM::ShflOp::create( + rewriter, loc, resultTy, activeMask, adaptor.getValue(), + adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()), + returnValueAndIsValidAttr); if (predIsUsed) { - Value shflValue = rewriter.create(loc, shfl, 0); + Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0); Value isActiveSrcLane = - rewriter.create(loc, shfl, 1); + LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); } else { rewriter.replaceOp(op, {shfl, nullptr}); @@ -215,16 +216,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern { bounds = rewriter.getAttr( /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize); Value newOp = - rewriter.create(loc, rewriter.getI32Type(), bounds); + NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - newOp = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = LLVM::SExtOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } else if (indexBitwidth < 32) { - newOp = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = LLVM::TruncOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } rewriter.replaceOp(op, {newOp}); return success(); @@ -271,10 +272,10 @@ struct AssertOpToAssertfailLowering Block *afterBlock = rewriter.splitBlock(assertBlock, ++assertOp->getIterator()); rewriter.setInsertionPointToEnd(beforeBlock); - rewriter.create(loc, adaptor.getArg(), afterBlock, - assertBlock); + cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock, + assertBlock); rewriter.setInsertionPointToEnd(assertBlock); - rewriter.create(loc, afterBlock); + cf::BranchOp::create(rewriter, loc, afterBlock); // Continue cf.assert lowering. rewriter.setInsertionPoint(assertOp); @@ -301,12 +302,12 @@ struct AssertOpToAssertfailLowering // Create constants. auto getGlobal = [&](LLVM::GlobalOp global) { // Get a pointer to the format string's first element. - Value globalPtr = rewriter.create( - loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), global.getSymNameAttr()); Value start = - rewriter.create(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef{0, 0}); return start; }; Value assertMessage = getGlobal(getOrCreateStringConstant( @@ -316,8 +317,8 @@ struct AssertOpToAssertfailLowering Value assertFunc = getGlobal(getOrCreateStringConstant( rewriter, loc, moduleOp, i8Type, "assert_func_", funcName)); Value assertLine = - rewriter.create(loc, i32Type, fileLine); - Value c1 = rewriter.create(loc, i64Type, 1); + LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine); + Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1); // Insert function call to __assertfail. SmallVector arguments{assertMessage, assertFile, assertLine, diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 45fd933d58857..99c059cb03299 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -126,8 +126,8 @@ struct WmmaLoadOpToNVVMLowering cast(subgroupMmaLoadMatrixOp.getSrcMemref().getType()), adaptor.getSrcMemref(), adaptor.getIndices()); - Value leadingDim = rewriter.create( - loc, rewriter.getI32Type(), + Value leadingDim = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), subgroupMmaLoadMatrixOp.getLeadDimensionAttr()); rewriter.replaceOpWithNewOp( op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag); @@ -173,7 +173,7 @@ struct WmmaStoreOpToNVVMLowering auto matrixType = cast(adaptor.getSrc().getType()); for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) { Value toUse = - rewriter.create(loc, adaptor.getSrc(), i); + LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i); storeOpOperands.push_back(toUse); } @@ -181,8 +181,8 @@ struct WmmaStoreOpToNVVMLowering rewriter, loc, cast(subgroupMmaStoreMatrixOp.getDstMemref().getType()), adaptor.getDstMemref(), adaptor.getIndices()); - Value leadingDim = rewriter.create( - loc, rewriter.getI32Type(), + Value leadingDim = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), subgroupMmaStoreMatrixOp.getLeadDimensionAttr()); rewriter.replaceOpWithNewOp( op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim); @@ -216,7 +216,7 @@ struct WmmaMmaOpToNVVMLowering auto unpackOp = [&](Value operand) { auto structType = cast(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { - Value toUse = rewriter.create(loc, operand, i); + Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i); unpackedOps.push_back(toUse); } }; @@ -280,19 +280,19 @@ struct WmmaConstantOpToNVVMLowering cast(subgroupMmaConstantOp.getType())); // If the element type is a vector create a vector from the operand. if (auto vecType = dyn_cast(type.getBody()[0])) { - Value vecCst = rewriter.create(loc, vecType); + Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { - Value idx = rewriter.create( - loc, rewriter.getI32Type(), vecEl); - vecCst = rewriter.create(loc, vecType, vecCst, - cst, idx); + Value idx = LLVM::ConstantOp::create(rewriter, loc, + rewriter.getI32Type(), vecEl); + vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst, + cst, idx); } cst = vecCst; } - Value matrixStruct = rewriter.create(loc, type); + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type); for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { matrixStruct = - rewriter.create(loc, matrixStruct, cst, i); + LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i); } rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct); return success(); @@ -305,17 +305,17 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Type i1Type = builder.getI1Type(); if (auto vecType = dyn_cast(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); - Value cmp = builder.create( - loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, - lhs, rhs); - Value sel = builder.create(loc, cmp, lhs, rhs); - Value isNan = builder.create( - loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); - Value nan = builder.create( - loc, lhs.getType(), + Value cmp = LLVM::FCmpOp::create( + builder, loc, i1Type, + isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs); + Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs); + Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type, + LLVM::FCmpPredicate::uno, lhs, rhs); + Value nan = LLVM::ConstantOp::create( + builder, loc, lhs.getType(), builder.getFloatAttr(floatType, APFloat::getQNaN(floatType.getFloatSemantics()))); - return builder.create(loc, isNan, nan, sel); + return LLVM::SelectOp::create(builder, loc, isNan, nan, sel); } static Value createScalarOp(OpBuilder &builder, Location loc, @@ -323,11 +323,11 @@ static Value createScalarOp(OpBuilder &builder, Location loc, ArrayRef operands) { switch (op) { case gpu::MMAElementwiseOp::ADDF: - return builder.create(loc, operands[0].getType(), operands); + return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::MULF: - return builder.create(loc, operands[0].getType(), operands); + return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::DIVF: - return builder.create(loc, operands[0].getType(), operands); + return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::MAXF: return createMinMaxF(builder, loc, operands[0], operands[1], /*isMin=*/false); @@ -356,18 +356,18 @@ struct WmmaElementwiseOpToNVVMLowering size_t numOperands = adaptor.getOperands().size(); LLVM::LLVMStructType destType = convertMMAToLLVMType( cast(subgroupMmaElementwiseOp.getType())); - Value matrixStruct = rewriter.create(loc, destType); + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType); for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { SmallVector extractedOperands; for (size_t opIdx = 0; opIdx < numOperands; opIdx++) { - extractedOperands.push_back(rewriter.create( - loc, adaptor.getOperands()[opIdx], i)); + extractedOperands.push_back(LLVM::ExtractValueOp::create( + rewriter, loc, adaptor.getOperands()[opIdx], i)); } Value element = createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(), extractedOperands); matrixStruct = - rewriter.create(loc, matrixStruct, element, i); + LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i); } rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct); return success(); diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 456bfaba980ca..d22364e1ef441 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -61,10 +61,10 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth()); // TODO: use <=> in C++20. if (indexBitwidth > intWidth) { - return rewriter.create(loc, indexBitwidthType, value); + return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value); } if (indexBitwidth < intWidth) { - return rewriter.create(loc, indexBitwidthType, value); + return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value); } return value; } @@ -82,12 +82,12 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, const unsigned indexBitwidth) { auto int32Type = IntegerType::get(rewriter.getContext(), 32); - Value zero = rewriter.create(loc, 0, 32); - Value minus1 = rewriter.create(loc, -1, 32); - Value mbcntLo = rewriter.create(loc, int32Type, - ValueRange{minus1, zero}); - Value laneId = rewriter.create(loc, int32Type, - ValueRange{minus1, mbcntLo}); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); + Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, + ValueRange{minus1, zero}); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, + ValueRange{minus1, mbcntLo}); return laneId; } static constexpr StringLiteral amdgcnDataLayout = @@ -110,21 +110,21 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern { // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo) Type intTy = IntegerType::get(context, 32); - Value zero = rewriter.create(loc, 0, 32); - Value minus1 = rewriter.create(loc, -1, 32); - Value mbcntLo = - rewriter.create(loc, intTy, ValueRange{minus1, zero}); - Value laneId = rewriter.create( - loc, intTy, ValueRange{minus1, mbcntLo}); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); + Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy, + ValueRange{minus1, zero}); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy, + ValueRange{minus1, mbcntLo}); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - laneId = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), laneId); + laneId = LLVM::SExtOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); } else if (indexBitwidth < 32) { - laneId = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), laneId); + laneId = LLVM::TruncOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); } rewriter.replaceOp(op, {laneId}); return success(); @@ -149,8 +149,8 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern { /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32, /*upper=*/op.getUpperBoundAttr().getInt() + 1); } - Value wavefrontOp = rewriter.create( - op.getLoc(), rewriter.getI32Type(), bounds); + Value wavefrontOp = ROCDL::WavefrontSizeOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), bounds); wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp, *getTypeConverter()); rewriter.replaceOp(op, {wavefrontOp}); @@ -190,44 +190,44 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { auto int32Type = IntegerType::get(rewriter.getContext(), 32); Value width = adaptor.getWidth(); - Value zero = rewriter.create(loc, int32Type, 0); - Value negwidth = rewriter.create(loc, int32Type, zero, width); - Value add = rewriter.create(loc, int32Type, srcLaneId, width); + Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0); + Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width); + Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width); Value widthOrZeroIfOutside = - rewriter.create(loc, int32Type, add, negwidth); + LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth); Value dstLane; switch (op.getMode()) { case gpu::ShuffleMode::UP: - dstLane = rewriter.create(loc, int32Type, srcLaneId, - adaptor.getOffset()); + dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId, + adaptor.getOffset()); break; case gpu::ShuffleMode::DOWN: - dstLane = rewriter.create(loc, int32Type, srcLaneId, - adaptor.getOffset()); + dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, + adaptor.getOffset()); break; case gpu::ShuffleMode::XOR: - dstLane = rewriter.create(loc, int32Type, srcLaneId, - adaptor.getOffset()); + dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId, + adaptor.getOffset()); break; case gpu::ShuffleMode::IDX: dstLane = adaptor.getOffset(); break; } - Value isActiveSrcLane = rewriter.create( - loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside); - Value selectDstLane = rewriter.create(loc, isActiveSrcLane, - dstLane, srcLaneId); - Value two = rewriter.create(loc, int32Type, 2); + Value isActiveSrcLane = LLVM::ICmpOp::create( + rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside); + Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane, + dstLane, srcLaneId); + Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2); Value dwordAlignedDstLane = - rewriter.create(loc, int32Type, selectDstLane, two); + LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two); SmallVector decomposed = LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type); SmallVector swizzled; for (Value v : decomposed) { - Value res = rewriter.create(loc, int32Type, - dwordAlignedDstLane, v); + Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type, + dwordAlignedDstLane, v); swizzled.emplace_back(res); } Value shflValue = diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index b99ed261ecfa3..a19194eb181fb 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion::matchAndRewrite( Value vector = spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter); - Value dim = rewriter.create( - op.getLoc(), builtinType, vector, + Value dim = spirv::CompositeExtractOp::create( + rewriter, op.getLoc(), builtinType, vector, rewriter.getI32ArrayAttr({static_cast(op.getDimension())})); if (forShader && builtinType != indexType) - dim = rewriter.create(op.getLoc(), indexType, dim); + dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim); rewriter.replaceOp(op, dim); return success(); } @@ -198,8 +198,8 @@ SingleDimLaunchConfigConversion::matchAndRewrite( Value builtinValue = spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter); if (i32Type != indexType) - builtinValue = rewriter.create(op.getLoc(), indexType, - builtinValue); + builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, + builtinValue); rewriter.replaceOp(op, builtinValue); return success(); } @@ -257,8 +257,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, signatureConverter.addInputs(argType.index(), convertedType); } } - auto newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), + auto newFuncOp = spirv::FuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {})); for (const auto &namedAttr : funcOp->getAttrs()) { if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() || @@ -367,8 +367,8 @@ LogicalResult GPUModuleConversion::matchAndRewrite( // Add a keyword to the module name to avoid symbolic conflict. std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str(); - auto spvModule = rewriter.create( - moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt, + auto spvModule = spirv::ModuleOp::create( + rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt, StringRef(spvModuleName)); // Move the region from the module op into the SPIR-V module. @@ -452,42 +452,42 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( switch (shuffleOp.getMode()) { case gpu::ShuffleMode::XOR: { - result = rewriter.create( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleXorOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), shuffleOp.getLoc(), rewriter); break; } case gpu::ShuffleMode::IDX: { - result = rewriter.create( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), shuffleOp.getLoc(), rewriter); break; } case gpu::ShuffleMode::DOWN: { - result = rewriter.create( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleDownOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); - Value laneId = rewriter.create(loc, widthAttr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); Value resultLaneId = - rewriter.create(loc, laneId, adaptor.getOffset()); - validVal = rewriter.create(loc, arith::CmpIPredicate::ult, - resultLaneId, adaptor.getWidth()); + arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset()); + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, + resultLaneId, adaptor.getWidth()); break; } case gpu::ShuffleMode::UP: { - result = rewriter.create( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleUpOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); - Value laneId = rewriter.create(loc, widthAttr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); Value resultLaneId = - rewriter.create(loc, laneId, adaptor.getOffset()); + arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset()); auto i32Type = rewriter.getIntegerType(32); - validVal = rewriter.create( - loc, arith::CmpIPredicate::sge, resultLaneId, - rewriter.create( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 0))); + validVal = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, resultLaneId, + arith::ConstantOp::create(rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, 0))); break; } } @@ -516,15 +516,16 @@ LogicalResult GPURotateConversion::matchAndRewrite( Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr(spirv::Scope::Subgroup); - Value rotateResult = rewriter.create( - loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth()); + Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(), + adaptor.getWidth()); Value validVal; if (widthAttr.getValue().getZExtValue() == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { - Value laneId = rewriter.create(loc, widthAttr); - validVal = rewriter.create(loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, + laneId, adaptor.getWidth()); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); @@ -548,14 +549,14 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, ? spirv::GroupOperation::ClusteredReduce : spirv::GroupOperation::Reduce); if (isUniform) { - return builder.create(loc, type, scope, groupOp, arg) + return UniformOp::create(builder, loc, type, scope, groupOp, arg) .getResult(); } Value clusterSizeValue; if (clusterSize.has_value()) - clusterSizeValue = builder.create( - loc, builder.getI32Type(), + clusterSizeValue = spirv::ConstantOp::create( + builder, loc, builder.getI32Type(), builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); return builder @@ -740,8 +741,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( std::string specCstName = makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc"); - return rewriter.create( - loc, rewriter.getStringAttr(specCstName), attr); + return spirv::SpecConstantOp::create( + rewriter, loc, rewriter.getStringAttr(specCstName), attr); }; { Operation *parent = @@ -774,8 +775,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( std::string specCstCompositeName = (llvm::Twine(globalVarName) + "_scc").str(); - specCstComposite = rewriter.create( - loc, TypeAttr::get(globalType), + specCstComposite = spirv::SpecConstantCompositeOp::create( + rewriter, loc, TypeAttr::get(globalType), rewriter.getStringAttr(specCstCompositeName), rewriter.getArrayAttr(constituents)); @@ -785,23 +786,24 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( // Define a GlobalVarOp initialized using specialized constants // that is used to specify the printf format string // to be passed to the SPIRV CLPrintfOp. - globalVar = rewriter.create( - loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite)); + globalVar = spirv::GlobalVariableOp::create( + rewriter, loc, ptrType, globalVarName, + FlatSymbolRefAttr::get(specCstComposite)); globalVar->setAttr("Constant", rewriter.getUnitAttr()); } // Get SSA value of Global variable and create pointer to i8 to point to // the format string. - Value globalPtr = rewriter.create(loc, globalVar); - Value fmtStr = rewriter.create( - loc, + Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar); + Value fmtStr = spirv::BitcastOp::create( + rewriter, loc, spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant), globalPtr); // Get printf arguments. auto printfArgs = llvm::to_vector_of(adaptor.getArgs()); - rewriter.create(loc, i32Type, fmtStr, printfArgs); + spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs); // Need to erase the gpu.printf op as gpu.printf does not use result vs // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index 0b2c06a08db2d..a344f88326089 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -144,11 +144,12 @@ void GPUToSPIRVPass::runOnOperation() { if (targetEnvSupportsKernelCapability(moduleOp)) { moduleOp.walk([&](gpu::GPUFuncOp funcOp) { builder.setInsertionPoint(funcOp); - auto newFuncOp = builder.create( - funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); + auto newFuncOp = + func::FuncOp::create(builder, funcOp.getLoc(), funcOp.getName(), + funcOp.getFunctionType()); auto entryBlock = newFuncOp.addEntryBlock(); builder.setInsertionPointToEnd(entryBlock); - builder.create(funcOp.getLoc()); + func::ReturnOp::create(builder, funcOp.getLoc()); newFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), builder.getUnitAttr()); funcOp.erase(); diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp index 7bb86b5ce1ddd..51dc50048024f 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -283,8 +283,8 @@ struct WmmaLoadOpToSPIRVLowering final int64_t stride = op.getLeadDimension().getSExtValue(); IntegerType i32Type = rewriter.getI32Type(); - auto strideValue = rewriter.create( - loc, i32Type, IntegerAttr::get(i32Type, stride)); + auto strideValue = spirv::ConstantOp::create( + rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride)); bool isColMajor = op.getTranspose().value_or(false); auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor @@ -315,8 +315,8 @@ struct WmmaStoreOpToSPIRVLowering final int64_t stride = op.getLeadDimension().getSExtValue(); IntegerType i32Type = rewriter.getI32Type(); - auto strideValue = rewriter.create( - loc, i32Type, IntegerAttr::get(i32Type, stride)); + auto strideValue = spirv::ConstantOp::create( + rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride)); bool isColMajor = op.getTranspose().value_or(false); auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor