diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bafeca924e4c5..8d45c40a93e2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2435,6 +2435,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
                        OptionalAttr<UnitAttr>:$vectorize_nd_extract,
                        OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+                       OptionalAttr<UnitAttr>:$create_named_contraction,
                        DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9e62d0dcc7890..38e53648e7c34 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,12 +876,15 @@ struct VectorizationResult {
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to a vector.contract operation.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
           bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
-          bool assumeDynamicDimsMatchVecSizes = false);
+          bool assumeDynamicDimsMatchVecSizes = false,
+          bool createNamedContraction = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
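For programmatic use outside the transform dialect, the new flag is the last parameter of `linalg::vectorize`. A minimal caller sketch, assuming a surrounding pass or pattern that already provides `rewriter` and a matched contraction op in `op` (only the signature above is taken as given):

```cpp
// Sketch: opt in to direct vector.contract generation during vectorization.
// `rewriter` and `op` are assumed to come from the caller's context.
FailureOr<VectorizationResult> result =
    linalg::vectorize(rewriter, op,
                      /*inputVectorSizes=*/{},
                      /*inputScalableVecDims=*/{},
                      /*vectorizeNDExtract=*/false,
                      /*flatten1DDepthwiseConv=*/false,
                      /*assumeDynamicDimsMatchVecSizes=*/false,
                      /*createNamedContraction=*/true);
if (failed(result))
  return failure();
```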
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..9b765d0b8ede6 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c959310136319..109e5b7f95ec0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,8 +3920,10 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   }
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false), false,
-                        getAssumeDynamicDimsMatchVecSizes().value_or(false));
+                        getVectorizeNdExtract().value_or(false),
+                        /*flatten1DDepthwiseConv=*/false,
+                        getAssumeDynamicDimsMatchVecSizes().value_or(false),
+                        getCreateNamedContraction().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4add50f4b36e5..77c85abab9aa0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1709,10 +1710,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());
@@ -2118,6 +2122,92 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs contraction
+///   vector::TransferWriteOp - Writes the result vector back to the
+///                             destination
+/// The operand shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within the contraction op.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  // For simplicity, contraction vectorization is limited to linalg named ops.
+  // Generic op is ignored as not every arbitrary contraction body can be
+  // expressed by a vector.contract.
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind) {
+    LDBG("Failed to determine contraction combining kind.\n");
+    return failure();
+  }
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be decomposed before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
+    LDBG("Contractions with broadcasts are not supported.\n");
+    return failure();
+  }
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as a part
+    // of the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Operation *contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
+
+  // Store result.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, contractOp->getResult(0), outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
@@ -2557,7 +2647,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
+    bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2604,6 +2695,11 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
             return failure();
           }
 
+          if (createNamedContraction &&
+              isa<ContractionOpInterface>(linalgOp.getOperation()))
+            return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
+                                                results);
+
           LDBG("Vectorize generic by broadcasting to the canonical vector "
                "shape\n");
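To make the new path concrete, this is the rough shape of the IR that `vectorizeAsLinalgContraction` produces for a static 8x4 * 4x16 `linalg.matmul`; it mirrors the first test case below (value names are illustrative, and `in_bounds` attributes are elided):

```mlir
// Sketch: reads of the unpermuted operands, one vector.contract carrying the
// matmul's indexing maps, and a write of the result back to the init tensor.
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%lhs = vector.transfer_read %A[%c0, %c0], %pad : tensor<8x4xf32>, vector<8x4xf32>
%rhs = vector.transfer_read %B[%c0, %c0], %pad : tensor<4x16xf32>, vector<4x16xf32>
%acc = vector.transfer_read %C[%c0, %c0], %pad : tensor<8x16xf32>, vector<8x16xf32>
%res = vector.contract {
         indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %lhs, %rhs, %acc : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
%out = vector.transfer_write %res, %C[%c0, %c0] : vector<8x16xf32>, tensor<8x16xf32>
```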
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7e4984582b373..9b055853fc8b0 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   if (llvm::equal(inputVectorSizes, sourceShape) ||
       useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
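The `scalableDims` parameter mirrors the scalability flags of `VectorType`: one bool per vector dimension, with `true` marking a scalable dimension. A hedged usage sketch against the signature above (the builder context, `source`, and `padValue` are assumed to be provided by the caller):

```cpp
// Sketch: read a dynamically-shaped tensor<?x?xf32> into a vector<8x[16]xf32>,
// marking the second dimension as scalable. A masked transfer_read is emitted
// because the vector sizes cannot be proven to match the source shape.
SmallVector<int64_t> sizes = {8, 16};
SmallVector<bool> scalable = {false, true};
Value read = vector::createReadOrMaskedRead(
    builder, loc, source, sizes, padValue,
    /*useInBoundsInsteadOfMasking=*/false, scalable);
```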
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
new file mode 100644
index 0000000000000..d8f897cca958d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -0,0 +1,484 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+///----------------------------------------------------------------------------------------
+/// Tests for vectorizing operations implementing the contraction op interface.
+/// Ops implementing the contraction interface are vectorized directly to their
+/// vector dialect named counterparts.
+///----------------------------------------------------------------------------------------
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+
+/// Get the contraction dimensions
+// CHECK: %[[MATMUL_DIM_M_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[MATMUL_DIM_M:.*]] = tensor.dim %[[A]], %[[MATMUL_DIM_M_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MATMUL_DIM_N_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[MATMUL_DIM_N:.*]] = tensor.dim %[[B]], %[[MATMUL_DIM_N_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MATMUL_DIM_K_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[MATMUL_DIM_K:.*]] = tensor.dim %[[A]], %[[MATMUL_DIM_K_IDX]] : tensor<?x?xf32>
+
+/// Create a mask for the A matrix
+// CHECK: %[[A_OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[A_DIM_M_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[A_DIM_M:.*]] = tensor.dim %[[A]], %[[A_DIM_M_IDX]] : tensor<?x?xf32>
+// CHECK: %[[A_DIM_K_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[A_DIM_K:.*]] = tensor.dim %[[A]], %[[A_DIM_K_IDX]] : tensor<?x?xf32>
+// CHECK: %[[LOAD_A_MASK:.*]] = vector.create_mask
+// CHECK-SAME:    %[[A_DIM_M]], %[[A_DIM_K]] : vector<8x4xi1>
+/// Read the A matrix
+// CHECK: %[[LOAD_A:.*]] = vector.mask %[[LOAD_A_MASK]]
+// CHECK-SAME:  { vector.transfer_read %[[A]]{{\[}}%[[A_OFFSET]], %[[A_OFFSET]]{{\]}}
+// CHECK-SAME:    : tensor<?x?xf32>, vector<8x4xf32> }
+// CHECK-SAME:    : vector<8x4xi1> -> vector<8x4xf32>
+
+/// Create a mask for the B matrix
+// CHECK: %[[B_OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[B_DIM_K_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[B_DIM_K:.*]] = tensor.dim %[[B]], %[[B_DIM_K_IDX]] : tensor<?x?xf32>
+// CHECK: %[[B_DIM_N_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[B_DIM_N:.*]] = tensor.dim %[[B]], %[[B_DIM_N_IDX]] : tensor<?x?xf32>
+// CHECK: %[[LOAD_B_MASK:.*]] = vector.create_mask
+// CHECK-SAME:    %[[B_DIM_K]], %[[B_DIM_N]] : vector<4x16xi1>
+/// Read the B matrix
+// CHECK: %[[LOAD_B:.*]] = vector.mask %[[LOAD_B_MASK]]
+// CHECK-SAME:  { vector.transfer_read %[[B]]{{\[}}%[[B_OFFSET]], %[[B_OFFSET]]{{\]}}
+// CHECK-SAME:    : tensor<?x?xf32>, vector<4x16xf32> }
+// CHECK-SAME:    : vector<4x16xi1> -> vector<4x16xf32>
+
+/// Create a mask for the C matrix
+// CHECK: %[[C_OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[C_DIM_M_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[C_DIM_M:.*]] = tensor.dim %[[C]], %[[C_DIM_M_IDX]] : tensor<?x?xf32>
+// CHECK: %[[C_DIM_N_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[C_DIM_N:.*]] = tensor.dim %[[C]], %[[C_DIM_N_IDX]] : tensor<?x?xf32>
+// CHECK: %[[LOAD_C_MASK:.*]] = vector.create_mask
+// CHECK-SAME:    %[[C_DIM_M]], %[[C_DIM_N]] : vector<8x16xi1>
+/// Read the C matrix
+// CHECK: %[[LOAD_C:.*]] = vector.mask %[[LOAD_C_MASK]]
+// CHECK-SAME:  { vector.transfer_read %[[C]]{{\[}}%[[C_OFFSET]], %[[C_OFFSET]]{{\]}}
+// CHECK-SAME:    : tensor<?x?xf32>, vector<8x16xf32> }
+// CHECK-SAME:    : vector<8x16xi1> -> vector<8x16xf32>
+
+/// Create a mask for the contraction
+// CHECK: %[[CONTRACTION_MASK:.*]] = vector.create_mask
+// CHECK-SAME:    %[[MATMUL_DIM_M]], %[[MATMUL_DIM_N]], %[[MATMUL_DIM_K]]
+// CHECK-SAME:    : vector<8x16x4xi1>
+/// Perform the contraction
+// CHECK: %[[D:.*]] = vector.mask %[[CONTRACTION_MASK]]
+// CHECK-SAME:  { vector.contract
+// CHECK-SAME:    indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:    kind = #vector.kind<add>
+// CHECK-SAME:    %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK-SAME:  } : vector<8x16x4xi1> -> vector<8x16xf32>
+
+/// Create a mask for the result
+// CHECK: %[[D_OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[D_DIM_M_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[D_DIM_M:.*]] = tensor.dim %[[C]], %[[D_DIM_M_IDX]] : tensor<?x?xf32>
+// CHECK: %[[D_DIM_N_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[D_DIM_N:.*]] = tensor.dim %[[C]], %[[D_DIM_N_IDX]] : tensor<?x?xf32>
+// CHECK: %[[LOAD_D_MASK:.*]] = vector.create_mask
+// CHECK-SAME:    %[[D_DIM_M]], %[[D_DIM_N]] : vector<8x16xi1>
+/// Write the result
+// CHECK: vector.mask %[[LOAD_D_MASK]]
+// CHECK-SAME:  { vector.transfer_write %[[D]], %[[C]]{{\[}}%[[D_OFFSET]], %[[D_OFFSET]]{{\]}}
+// CHECK-SAME:    : vector<8x16xf32>, tensor<?x?xf32> }
+// CHECK-SAME:    : vector<8x16xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+    %C: memref<?x?xf32>) {
+  linalg.matmul
+    ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+    outs(%C: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME:    %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: memref<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_scalable(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32> }
+// CHECK-SAME:    : vector<8x4xi1> -> vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32> }
+// CHECK-SAME:    : vector<4x[16]xi1> -> vector<4x[16]xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32> }
+// CHECK-SAME:    : vector<8x[16]xi1> -> vector<8x[16]xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK-SAME:  } : vector<8x[16]x4xi1> -> vector<8x[16]xf32>
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32> }
+// CHECK-SAME:    : vector<8x[16]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<4x8xf32>, tensor<16x4xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose(
+// CHECK-SAME:    %[[A:.*]]: tensor<4x8xf32>, %[[B:.*]]: tensor<16x4xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8xf32>, vector<4x8xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<16x4xf32>, vector<16x4xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_transpose(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_transpose(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<4x8xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<16x4xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+/// Contractions with arbitrary broadcasts are not supported in contraction interface
+/// vectorization.
+/// Dimension broadcasts are expected to be decomposed first, which removes ambiguity
+/// caused by possible variants of dimensions materialization.
+/// For example, whether the below target LHS input layout is (m, k) or (k, m).
+
+func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
+                     affine_map<(m, n, k) -> (k, n)>,
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
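The decomposition mentioned in the comment above could, for instance, materialize the broadcast explicitly before vectorization. A hedged sketch, choosing the (m, k) layout as one of the two valid materializations (`%init` and the op names below are illustrative, not part of the patch):

```mlir
// Sketch: make the broadcast explicit, then use a plain matmul that the
// contraction vectorizer accepts.
%init = tensor.empty() : tensor<8x4xf32>
%bcast = linalg.broadcast ins(%A : tensor<4xf32>)
           outs(%init : tensor<8x4xf32>) dimensions = [0]
%0 = linalg.matmul
       ins(%bcast, %B : tensor<8x4xf32>, tensor<4x16xf32>)
       outs(%C : tensor<8x16xf32>) -> tensor<8x16xf32>
```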
+
+// -----
+
+func.func @matmul_mixed_precision(%A: tensor<8x4xf16>, %B: tensor<4x16xf16>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf16>, tensor<4x16xf16>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_mixed_precision(
+// CHECK-SAME:    %[[A:.*]]: tensor<8x4xf16>, %[[B:.*]]: tensor<4x16xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf16>, vector<8x4xf16>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf16>, vector<4x16xf16>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_matmul(%A: tensor<3x8x4xf32>, %B: tensor<3x4x16xf32>,
+    %C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32> {
+  %0 = linalg.batch_matmul
+    ins(%A, %B : tensor<3x8x4xf32>, tensor<3x4x16xf32>)
+    outs(%C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32>
+  return %0 : tensor<3x8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf32>, %[[B:.*]]: tensor<3x4x16xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<3x8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf32>, vector<3x8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf32>, vector<3x4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<3x8x16xf32>, vector<3x8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<3x8x16xf32>, tensor<3x8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_reduce_matmul(%A: tensor<3x8x4xf32>, %B: tensor<3x4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.batch_reduce_matmul
+    ins(%A, %B : tensor<3x8x4xf32>, tensor<3x4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-LABEL: func.func @batch_reduce_matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf32>, %[[B:.*]]: tensor<3x4x16xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf32>, vector<3x8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf32>, vector<3x4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
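In terms of the maps above, the batch-reduce contraction accumulates over both the batch dimension d0 and the contraction dimension d3, which is exactly what `kind = #vector.kind<add>` with two reduction iterators expresses:

$$
C[d_1, d_2] \leftarrow C[d_1, d_2] + \sum_{d_0} \sum_{d_3} A[d_0, d_1, d_3] \cdot B[d_0, d_3, d_2]
$$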
+
+// -----
+
+func.func @contract(%A: tensor<4x8x2xf32>, %B: tensor<8x16x2xf32>,
+    %C: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  %0 = linalg.contract
+    indexing_maps = [affine_map<(m, n, k, kk) -> (m, k, kk)>,
+                     affine_map<(m, n, k, kk) -> (k, n, kk)>,
+                     affine_map<(m, n, k, kk) -> (m, n)>]
+    ins(%A, %B : tensor<4x8x2xf32>, tensor<8x16x2xf32>)
+    outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract(
+// CHECK-SAME:    %[[A:.*]]: tensor<4x8x2xf32>, %[[B:.*]]: tensor<8x16x2xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<4x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8x2xf32>, vector<4x8x2xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<8x16x2xf32>, vector<8x16x2xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<4x16xf32>, tensor<4x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+/// Generic can represent contractions but it does not implement the contraction interface.
+/// Thus, direct lowering to vector.contract is not supported.
+/// Vectorization still works and applies the generic rewrite logic.
+
+func.func @negative_generic(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                     affine_map<(m, n, k) -> (k, n)>,
+                     affine_map<(m, n, k) -> (m, n)>],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C : tensor<8x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @negative_generic(
+// CHECK-NOT: vector.contract
+// CHECK: vector.multi_reduction
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
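If such a generic is known to be a plain contraction, one possible route is to raise it to a named op first and only then request contraction vectorization. A hedged sketch: it assumes `transform.structured.specialize` recognizes this matmul body, and the syntax shown may need adjusting to the current API.

```mlir
// Sketch: specialize the generic into linalg.matmul, then vectorize it with
// create_named_contraction so the direct vector.contract path applies.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %1 {create_named_contraction} : !transform.any_op
    transform.yield
  }
}
```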