
Commit c08e411

[mlir][linalg] Introduce transpose semantic to 'linalg.matmul'.
The main goal of this patch is to extend the semantics of the 'linalg.matmul' named op with per-operand transpose semantics, while also laying out a path for moving op definitions from OpDSL to TableGen. Hence, the op is now implemented in TableGen.

The transpose semantics are as follows. By default, 'linalg.matmul' behaves exactly as before. A transpose can be applied to each input operand individually by explicitly specifying the optional permutation attribute for that operand ('permutationA' for the first input, 'permutationB' for the second) as needed. By default, no transpose is applied to either input operand.

Example:
```
%val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                     outs(%arg2: memref<3x7xf32>)
                     permutationA = [1, 0]
                     permutationB = [0, 1]
```
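For context, the default matmul uses the indexing maps (d0, d2), (d2, d1) and (d0, d1) for A, B and C. The sketch below assumes the syntax added by this patch and shows how `permutationA = [1, 0]` reorders the results of A's map so the first operand is consumed transposed; the value names and shapes are illustrative, not taken from the patch's tests.

```mlir
// Hypothetical IR, assuming this patch: only the first input is transposed.
// Default maps: A -> (d0, d2), B -> (d2, d1), C -> (d0, d1), with d0 = M, d1 = N, d2 = K.
// permutationA = [1, 0] permutes A's map to (d2, d0), i.e. A is read as a KxM operand.
// permutationB is omitted, so it keeps its default identity value [0, 1].
linalg.matmul ins(%a, %b : memref<5x3xf32>, memref<5x7xf32>)
              outs(%c : memref<3x7xf32>)
              permutationA = [1, 0]
```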
1 parent 27b6080 commit c08e411

File tree

8 files changed: +347, −165 lines

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 72 deletions
```diff
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
       - !ScalarExpression
         scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
```

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 1 addition & 0 deletions
```diff
@@ -271,4 +271,5 @@ def Linalg_WinogradOutputTransformOp :
   let hasVerifier = 1;
 }
 
+
 #endif // LINALG_OPS
```

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 100 additions & 0 deletions
````diff
@@ -534,6 +534,106 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", !listconcat([AttrSizedOperandSegments],
+    /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+
+  let summary = [{Performs a matrix multiplication of two 2D inputs without transpose.}];
+  let description = [{Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Per input operand transpose can be performed by specifying the required permutation
+    attributes (namely 'permutationA' for 1st input and 'permutationB' for 2nd input) for
+    each operand explicitly. By default, no transpose is mandated for each input operand.
+
+    Example:
+    ```
+    %val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                         outs(%arg2: memref<3x7xf32>)
+                         permutationA = [1, 0]
+                         permutationB = [0, 1]
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{0, 1}">, [DenseArrayCount<2>]>:$permutationA,
+    ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{0, 1}">, [DenseArrayCount<2>]>:$permutationB,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+                          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs, "DenseI64ArrayAttr":$permutationA, "DenseI64ArrayAttr":$permutationB, "Attribute":$cast,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("permutationA", permutationA);
+        $_state.addAttribute("permutationB", permutationB);
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+                          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Auto-generated.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    ArrayAttr getIndexingMaps();
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    static unsigned getNumRegionArgs();
+    std::string getLibraryCallName();
+    bool hasDynamicIndexingMaps();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
````
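Since both permutation attributes are declared as `DefaultValuedOptionalAttr` with value `{0, 1}`, a `linalg.matmul` spelled without any permutation keyword keeps its existing meaning. A minimal sketch of that default form, with illustrative operand names:

```mlir
// No permutationA/permutationB: both default to the identity [0, 1],
// so this is the usual C += A * B contraction (A: MxK, B: KxN, C: MxN).
linalg.matmul ins(%lhs, %rhs : memref<3x5xf32>, memref<5x7xf32>)
              outs(%acc : memref<3x7xf32>)
```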

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 151 additions & 1 deletion
```diff
@@ -303,6 +303,26 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  if (parser.parseOptionalKeyword("permutationA").succeeded()) {
+    if (parser.parseEqual())
+      return failure();
+
+    result.attributes.set("permutationA",
+                          DenseI64ArrayAttr::parse(parser, Type{}));
+  }
+
+  if (parser.parseOptionalKeyword("permutationB").succeeded()) {
+    if (parser.parseEqual())
+      return failure();
+
+    result.attributes.set("permutationB",
+                          DenseI64ArrayAttr::parse(parser, Type{}));
+  }
+
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -334,7 +354,8 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
       /*elidedAttrs=*/{"operandSegmentSizes",
                        // See generated code in
                        // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                       "linalg.memoized_indexing_maps", "permutationA",
+                       "permutationB"});
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -2980,3 +3001,132 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+                                          utils::IteratorType::parallel,
+                                          utils::IteratorType::reduction};
+}
+
+ArrayAttr MatmulOp::getIndexingMaps() {
+  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
+  if (cached)
+    return cached;
+
+  MLIRContext *context = getContext();
+  SmallVector<AffineMap> maps;
+
+  unsigned numResults;
+  SmallVector<AffineExpr, 3> dimReplacements;
+  AffineMap originalMap =
+      llvm::cast<AffineMapAttr>(
+          mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d2)>", context))
+          .getValue();
+  numResults = originalMap.getNumResults();
+  for (unsigned i = 0; i < numResults; i++) {
+    AffineExpr expr = originalMap.getResult(getPermutationA()[i]);
+    dimReplacements.push_back(expr);
+  }
+
+  AffineMap newMap =
+      AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
+                     dimReplacements, context);
+  maps.push_back(newMap);
+  maps.back() =
+      simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+
+  originalMap =
+      llvm::cast<AffineMapAttr>(
+          mlir::parseAttribute("affine_map<(d0, d1, d2)->(d2, d1)>", context))
+          .getValue();
+  numResults = originalMap.getNumResults();
+  dimReplacements.clear();
+  for (unsigned i = 0; i < numResults; i++) {
+    AffineExpr expr = originalMap.getResult(getPermutationB()[i]);
+    dimReplacements.push_back(expr);
+  }
+
+  newMap = AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
+                          dimReplacements, context);
+  maps.push_back(newMap);
+  maps.back() =
+      simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+
+  maps.push_back(
+      llvm::cast<AffineMapAttr>(
+          mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d1)>", context))
+          .getValue());
+  maps.back() =
+      simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+  cached = Builder(context).getAffineMapArrayAttr(maps);
+  getOperation()->setAttr(memoizeAttr, cached);
+  return cached;
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  assert(3 > 0 && block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+                                MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+  if (!getPermutationA().empty())
+    printDenseI64ArrayAttr(p, getPermutationAAttrName(), getPermutationA());
+
+  if (!getPermutationB().empty())
+    printDenseI64ArrayAttr(p, getPermutationBAttrName(), getPermutationB());
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void MatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
```
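getIndexingMaps() above derives each input's indexing map by reordering the results of the default map according to the corresponding permutation, then memoizes the result in the `linalg.memoized_indexing_maps` attribute. The sketch below shows what that works out to when both operands are transposed; the names and shapes are illustrative, and the affine maps are derived from the logic above rather than copied from the patch's tests.

```mlir
// With permutationA = [1, 0] and permutationB = [1, 0]:
//   A's default map (d0, d2) becomes (d2, d0)  -> A is read as KxM
//   B's default map (d2, d1) becomes (d1, d2)  -> B is read as NxK
//   C's map stays    (d0, d1)                  -> C remains MxN
// i.e. the memoized maps are equivalent to:
//   affine_map<(d0, d1, d2) -> (d2, d0)>
//   affine_map<(d0, d1, d2) -> (d1, d2)>
//   affine_map<(d0, d1, d2) -> (d0, d1)>
linalg.matmul ins(%a, %b : memref<5x3xf32>, memref<7x5xf32>)
              outs(%c : memref<3x7xf32>)
              permutationA = [1, 0]
              permutationB = [1, 0]
```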

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 0 additions & 17 deletions
```diff
@@ -383,23 +383,6 @@ def select(
     O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
 
 
-@linalg_structured_op
-def matmul(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
 @linalg_structured_op
 def quantized_matmul(
     A=TensorDef(T1, S.M, S.K),
```
