@@ -1816,6 +1816,72 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1816
1816
return success ();
1817
1817
}
1818
1818
1819
+ namespace {
1820
+ struct IteratorType {
1821
+ IteratorType (StringRef strRef) : strRef(strRef) {}
1822
+ bool isOfType (Attribute attr) const {
1823
+ auto sAttr = attr.dyn_cast <StringAttr>();
1824
+ return sAttr && sAttr .getValue () == strRef;
1825
+ }
1826
+ StringRef strRef;
1827
+ };
1828
+ struct Par : public IteratorType {
1829
+ Par () : IteratorType(getParallelIteratorTypeName()) {}
1830
+ };
1831
+ struct Red : public IteratorType {
1832
+ Red () : IteratorType(getReductionIteratorTypeName()) {}
1833
+ };
1834
+
1835
+ // Unroll outer-products along reduction.
1836
+ struct UnrolledOuterProductEmitter {
1837
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1838
+
1839
+ UnrolledOuterProductEmitter (PatternRewriter &rewriter,
1840
+ vector::ContractionOp op)
1841
+ : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
1842
+ iterators (op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
1843
+
1844
+ Value t (Value v) {
1845
+ static constexpr std::array<int64_t , 2 > perm = {1 , 0 };
1846
+ return rewriter.create <vector::TransposeOp>(loc, v, perm);
1847
+ }
1848
+
1849
+ bool iters (ArrayRef<IteratorType> its) {
1850
+ if (its.size () != iterators.size ())
1851
+ return false ;
1852
+ for (int i = 0 , e = its.size (); i != e; ++i) {
1853
+ if (!its[i].isOfType (iterators[i]))
1854
+ return false ;
1855
+ }
1856
+ return true ;
1857
+ }
1858
+
1859
+ bool layout (MapList l) {
1860
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList (m); };
1861
+ return maps == infer (l);
1862
+ }
1863
+
1864
+ LogicalResult outer_prod (Value lhs, Value rhs, Value res, int reductionSize) {
1865
+ assert (reductionSize > 0 );
1866
+ for (int64_t k = 0 ; k < reductionSize; ++k) {
1867
+ Value a = rewriter.create <vector::ExtractOp>(loc, lhs, k);
1868
+ Value b = rewriter.create <vector::ExtractOp>(loc, rhs, k);
1869
+ res = rewriter.create <vector::OuterProductOp>(loc, res.getType (), a, b,
1870
+ res, kind);
1871
+ }
1872
+ rewriter.replaceOp (op, res);
1873
+ return success ();
1874
+ }
1875
+
1876
+ PatternRewriter &rewriter;
1877
+ Location loc;
1878
+ vector::CombiningKind kind;
1879
+ ArrayAttr iterators;
1880
+ SmallVector<AffineMap, 4 > maps;
1881
+ Operation *op;
1882
+ };
1883
+ } // namespace
1884
+
1819
1885
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1820
1886
/// semantics to a reduction_size-unrolled sequence:
1821
1887
/// ```
@@ -1844,104 +1910,64 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1844
1910
if (failed (filter (op)))
1845
1911
return failure ();
1846
1912
1847
- Location loc = op.getLoc ();
1848
- int64_t reductionSize = 0 ;
1849
1913
VectorType lhsType = op.getLhsType ();
1850
1914
Value lhs = op.lhs (), rhs = op.rhs (), res = op.acc ();
1851
1915
1852
1916
// Set up the parallel/reduction structure in right form.
1853
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1854
- auto infer = [](MapList m) { return AffineMap::inferFromExprList (m); };
1855
1917
AffineExpr m, n, k;
1856
1918
bindDims (rewriter.getContext (), m, n, k);
1857
- static constexpr std::array<int64_t , 2 > perm = {1 , 0 };
1858
- auto iteratorTypes = op.iterator_types ().getValue ();
1859
- SmallVector<AffineMap, 4 > maps = op.getIndexingMaps ();
1860
- if (isParallelIterator (iteratorTypes[0 ]) &&
1861
- isParallelIterator (iteratorTypes[1 ]) &&
1862
- isReductionIterator (iteratorTypes[2 ])) {
1863
- //
1864
- // Two outer parallel, one inner reduction (matmat flavor).
1865
- //
1866
- if (maps == infer ({{m, k}, {k, n}, {m, n}})) {
1867
- // This is the classical row-major matmul. Just permute the lhs.
1868
- reductionSize = lhsType.getDimSize (1 );
1869
- lhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1870
- } else if (maps == infer ({{m, k}, {n, k}, {m, n}})) {
1871
- // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1872
- reductionSize = lhsType.getDimSize (1 );
1873
- lhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1874
- rhs = rewriter.create <vector::TransposeOp>(loc, rhs, perm);
1875
- } else if (maps == infer ({{k, m}, {k, n}, {m, n}})) {
1876
- // No need to permute anything.
1877
- reductionSize = lhsType.getDimSize (0 );
1878
- } else if (maps == infer ({{k, m}, {n, k}, {m, n}})) {
1879
- // Just permute the rhs.
1880
- reductionSize = lhsType.getDimSize (0 );
1881
- rhs = rewriter.create <vector::TransposeOp>(loc, rhs, perm);
1882
- } else if (maps == infer ({{m, k}, {k, n}, {n, m}})) {
1883
- // This is the classical row-major matmul. Just permute the lhs.
1884
- reductionSize = lhsType.getDimSize (1 );
1885
- Value tmp = rhs;
1886
- rhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1887
- lhs = tmp;
1888
- } else if (maps == infer ({{m, k}, {n, k}, {n, m}})) {
1889
- // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1890
- reductionSize = lhsType.getDimSize (1 );
1891
- Value tmp = rhs;
1892
- rhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1893
- lhs = rewriter.create <vector::TransposeOp>(loc, tmp, perm);
1894
- } else if (maps == infer ({{k, m}, {k, n}, {n, m}})) {
1895
- // No need to permute anything, but still swap lhs and rhs.
1896
- reductionSize = lhsType.getDimSize (0 );
1897
- std::swap (lhs, rhs);
1898
- } else if (maps == infer ({{k, m}, {n, k}, {n, m}})) {
1899
- // Just permute the rhs.
1900
- reductionSize = lhsType.getDimSize (0 );
1901
- Value tmp = lhs;
1902
- lhs = rewriter.create <vector::TransposeOp>(loc, rhs, perm);
1903
- rhs = tmp;
1904
- } else {
1905
- return failure ();
1906
- }
1907
- } else if (isParallelIterator (iteratorTypes[0 ]) &&
1908
- isReductionIterator (iteratorTypes[1 ])) {
1909
- //
1910
- // One outer parallel, one inner reduction (matvec flavor)
1911
- //
1912
- if (maps == infer ({{m, n}, {n}, {m}})) {
1913
- // Case mat-vec: transpose.
1914
- reductionSize = lhsType.getDimSize (1 );
1915
- lhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1916
- } else if (maps == infer ({{n, m}, {n}, {m}})) {
1917
- // Case mat-trans-vec: ready to go.
1918
- reductionSize = lhsType.getDimSize (0 );
1919
- } else if (maps == infer ({{n}, {m, n}, {m}})) {
1920
- // Case vec-mat: swap and transpose.
1921
- reductionSize = lhsType.getDimSize (0 );
1922
- std::swap (lhs, rhs);
1923
- lhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1924
- } else if (maps == infer ({{n}, {n, m}, {m}})) {
1925
- // Case vec-mat-trans: swap and ready to go.
1926
- reductionSize = lhsType.getDimSize (0 );
1927
- std::swap (lhs, rhs);
1928
- } else {
1929
- return failure ();
1930
- }
1931
- } else {
1919
+
1920
+ //
1921
+ // Two outer parallel, one inner reduction (matmat flavor).
1922
+ //
1923
+ UnrolledOuterProductEmitter e (rewriter, op);
1924
+ if (e.iters ({Par (), Par (), Red ()})) {
1925
+ // Classical row-major matmul: Just permute the lhs.
1926
+ if (e.layout ({{m, k}, {k, n}, {m, n}}))
1927
+ return e.outer_prod (e.t (lhs), rhs, res, lhsType.getDimSize (1 ));
1928
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1929
+ if (e.layout ({{m, k}, {n, k}, {m, n}}))
1930
+ return e.outer_prod (e.t (lhs), e.t (rhs), res, lhsType.getDimSize (1 ));
1931
+ // No need to permute anything.
1932
+ if (e.layout ({{k, m}, {k, n}, {m, n}}))
1933
+ return e.outer_prod (lhs, rhs, res, lhsType.getDimSize (0 ));
1934
+ // Just permute the rhs.
1935
+ if (e.layout ({{k, m}, {n, k}, {m, n}}))
1936
+ return e.outer_prod (lhs, e.t (rhs), res, lhsType.getDimSize (0 ));
1937
+ // Transposed output: swap RHS and LHS.
1938
+ // Classical row-major matmul: permute the lhs.
1939
+ if (e.layout ({{m, k}, {k, n}, {n, m}}))
1940
+ return e.outer_prod (rhs, e.t (lhs), res, lhsType.getDimSize (1 ));
1941
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1942
+ if (e.layout ({{m, k}, {n, k}, {n, m}}))
1943
+ return e.outer_prod (e.t (rhs), e.t (lhs), res, lhsType.getDimSize (1 ));
1944
+ if (e.layout ({{k, m}, {k, n}, {n, m}}))
1945
+ return e.outer_prod (rhs, lhs, res, lhsType.getDimSize (0 ));
1946
+ if (e.layout ({{k, m}, {n, k}, {n, m}}))
1947
+ return e.outer_prod (e.t (rhs), lhs, res, lhsType.getDimSize (0 ));
1932
1948
return failure ();
1933
1949
}
1934
- assert (reductionSize > 0 );
1935
-
1936
- // Unroll outer-products along reduction.
1937
- for (int64_t k = 0 ; k < reductionSize; ++k) {
1938
- Value a = rewriter.create <vector::ExtractOp>(op.getLoc (), lhs, k);
1939
- Value b = rewriter.create <vector::ExtractOp>(op.getLoc (), rhs, k);
1940
- res = rewriter.create <vector::OuterProductOp>(op.getLoc (), res.getType (), a,
1941
- b, res, op.kind ());
1950
+
1951
+ //
1952
+ // One outer parallel, one inner reduction (matvec flavor)
1953
+ //
1954
+ if (e.iters ({Par (), Red ()})) {
1955
+ // Case mat-vec: transpose.
1956
+ if (e.layout ({{m, n}, {n}, {m}}))
1957
+ return e.outer_prod (e.t (lhs), rhs, res, lhsType.getDimSize (1 ));
1958
+ // Case mat-trans-vec: ready to go.
1959
+ if (e.layout ({{n, m}, {n}, {m}}))
1960
+ return e.outer_prod (lhs, rhs, res, lhsType.getDimSize (0 ));
1961
+ // Case vec-mat: swap and transpose.
1962
+ if (e.layout ({{n}, {m, n}, {m}}))
1963
+ return e.outer_prod (e.t (rhs), lhs, res, lhsType.getDimSize (0 ));
1964
+ // Case vec-mat-trans: swap and ready to go.
1965
+ if (e.layout ({{n}, {n, m}, {m}}))
1966
+ return e.outer_prod (rhs, lhs, res, lhsType.getDimSize (0 ));
1967
+ return failure ();
1942
1968
}
1943
- rewriter. replaceOp (op, res);
1944
- return success ();
1969
+
1970
+ return failure ();
1945
1971
}
1946
1972
1947
1973
LogicalResult
0 commit comments