@@ -1816,72 +1816,6 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1816
1816
return success ();
1817
1817
}
1818
1818
1819
- namespace {
1820
- struct IteratorType { // Wraps an iterator-type name; isOfType() tests an attribute against it.
1821
- IteratorType (StringRef strRef) : strRef(strRef) {}
1822
- bool isOfType (Attribute attr) const { // True iff attr is a StringAttr whose value equals strRef.
1823
- auto sAttr = attr.dyn_cast <StringAttr>();
1824
- return sAttr && sAttr .getValue () == strRef; // The sAttr check guards the failed-cast case.
1825
- }
1826
- StringRef strRef; // Name this matcher compares against.
1827
- };
1828
- struct Par : public IteratorType { // Matcher for the name from getParallelIteratorTypeName().
1829
- Par () : IteratorType(getParallelIteratorTypeName()) {}
1830
- };
1831
- struct Red : public IteratorType { // Matcher for the name from getReductionIteratorTypeName().
1832
- Red () : IteratorType(getReductionIteratorTypeName()) {}
1833
- };
1834
-
1835
- // Unroll outer-products along reduction. Bundles the rewriter with the contraction's location, combining kind, iterator types and indexing maps so the helpers below can match a layout and emit the unrolled ops.
1836
- struct UnrolledOuterProductEmitter {
1837
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1838
-
1839
- UnrolledOuterProductEmitter (PatternRewriter &rewriter,
1840
- vector::ContractionOp op)
1841
- : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
1842
- iterators (op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
1843
-
1844
- Value t (Value v) { // Creates a vector::TransposeOp of v with permutation {1, 0} and returns its result.
1845
- static constexpr std::array<int64_t , 2 > perm = {1 , 0 };
1846
- return rewriter.create <vector::TransposeOp>(loc, v, perm);
1847
- }
1848
-
1849
- bool iters (ArrayRef<IteratorType> its) { // True iff the op's iterator types match `its` one-for-one, in order.
1850
- if (its.size () != iterators.size ()) // A differing count can never match.
1851
- return false ;
1852
- for (int i = 0 , e = its.size (); i != e; ++i) { // NOTE(review): the size is narrowed to int here.
1853
- if (!its[i].isOfType (iterators[i]))
1854
- return false ;
1855
- }
1856
- return true ;
1857
- }
1858
-
1859
- bool layout (MapList l) { // True iff the op's indexing maps equal the maps inferred from l.
1860
- auto infer = [](MapList m) { return AffineMap::inferFromExprList (m); };
1861
- return maps == infer (l);
1862
- }
1863
-
1864
- LogicalResult outer_prod (Value lhs, Value rhs, Value res, int reductionSize) { // Emits reductionSize chained OuterProductOps, replaces op with the last result, and returns success.
1865
- assert (reductionSize > 0 );
1866
- for (int64_t k = 0 ; k < reductionSize; ++k) {
1867
- Value a = rewriter.create <vector::ExtractOp>(loc, lhs, k); // Position k of lhs.
1868
- Value b = rewriter.create <vector::ExtractOp>(loc, rhs, k); // Position k of rhs.
1869
- res = rewriter.create <vector::OuterProductOp>(loc, res.getType (), a, b, // res is fed back in as the accumulator operand.
1870
- res, kind);
1871
- }
1872
- rewriter.replaceOp (op, res);
1873
- return success ();
1874
- }
1875
-
1876
- PatternRewriter &rewriter; // Used to create the new ops and to replace `op`.
1877
- Location loc; // Taken from op.getLoc(); attached to every op created here.
1878
- vector::CombiningKind kind; // Taken from op.kind(); forwarded to each OuterProductOp.
1879
- ArrayAttr iterators; // Taken from op.iterator_types(); checked by iters().
1880
- SmallVector<AffineMap, 4 > maps; // Taken from op.getIndexingMaps(); checked by layout().
1881
- Operation *op; // The contraction being rewritten; replaced in outer_prod().
1882
- };
1883
- } // namespace
1884
-
1885
1819
// / Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1886
1820
// / semantics to a reduction_size-unrolled sequence:
1887
1821
// / ```
@@ -1910,64 +1844,104 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1910
1844
if (failed (filter (op)))
1911
1845
return failure ();
1912
1846
1847
+ Location loc = op.getLoc ();
1848
+ int64_t reductionSize = 0 ;
1913
1849
VectorType lhsType = op.getLhsType ();
1914
1850
Value lhs = op.lhs (), rhs = op.rhs (), res = op.acc ();
1915
1851
1916
1852
// Set up the parallel/reduction structure in right form.
1853
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1854
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList (m); };
1917
1855
AffineExpr m, n, k;
1918
1856
bindDims (rewriter.getContext (), m, n, k);
1919
-
1920
- //
1921
- // Two outer parallel, one inner reduction (matmat flavor).
1922
- //
1923
- UnrolledOuterProductEmitter e (rewriter, op);
1924
- if (e.iters ({Par (), Par (), Red ()})) {
1925
- // Classical row-major matmul: Just permute the lhs.
1926
- if (e.layout ({{m, k}, {k, n}, {m, n}}))
1927
- return e.outer_prod (e.t (lhs), rhs, res, lhsType.getDimSize (1 ));
1928
- // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1929
- if (e.layout ({{m, k}, {n, k}, {m, n}}))
1930
- return e.outer_prod (e.t (lhs), e.t (rhs), res, lhsType.getDimSize (1 ));
1931
- // No need to permute anything.
1932
- if (e.layout ({{k, m}, {k, n}, {m, n}}))
1933
- return e.outer_prod (lhs, rhs, res, lhsType.getDimSize (0 ));
1934
- // Just permute the rhs.
1935
- if (e.layout ({{k, m}, {n, k}, {m, n}}))
1936
- return e.outer_prod (lhs, e.t (rhs), res, lhsType.getDimSize (0 ));
1937
- // Transposed output: swap RHS and LHS.
1938
- // Classical row-major matmul: permute the lhs.
1939
- if (e.layout ({{m, k}, {k, n}, {n, m}}))
1940
- return e.outer_prod (rhs, e.t (lhs), res, lhsType.getDimSize (1 ));
1941
- // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1942
- if (e.layout ({{m, k}, {n, k}, {n, m}}))
1943
- return e.outer_prod (e.t (rhs), e.t (lhs), res, lhsType.getDimSize (1 ));
1944
- if (e.layout ({{k, m}, {k, n}, {n, m}}))
1945
- return e.outer_prod (rhs, lhs, res, lhsType.getDimSize (0 ));
1946
- if (e.layout ({{k, m}, {n, k}, {n, m}}))
1947
- return e.outer_prod (e.t (rhs), lhs, res, lhsType.getDimSize (0 ));
1857
+ static constexpr std::array<int64_t , 2 > perm = {1 , 0 };
1858
+ auto iteratorTypes = op.iterator_types ().getValue ();
1859
+ SmallVector<AffineMap, 4 > maps = op.getIndexingMaps ();
1860
+ if (isParallelIterator (iteratorTypes[0 ]) &&
1861
+ isParallelIterator (iteratorTypes[1 ]) &&
1862
+ isReductionIterator (iteratorTypes[2 ])) {
1863
+ //
1864
+ // Two outer parallel, one inner reduction (matmat flavor).
1865
+ //
1866
+ if (maps == infer ({{m, k}, {k, n}, {m, n}})) {
1867
+ // This is the classical row-major matmul. Just permute the lhs.
1868
+ reductionSize = lhsType.getDimSize (1 );
1869
+ lhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1870
+ } else if (maps == infer ({{m, k}, {n, k}, {m, n}})) {
1871
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1872
+ reductionSize = lhsType.getDimSize (1 );
1873
+ lhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1874
+ rhs = rewriter.create <vector::TransposeOp>(loc, rhs, perm);
1875
+ } else if (maps == infer ({{k, m}, {k, n}, {m, n}})) {
1876
+ // No need to permute anything.
1877
+ reductionSize = lhsType.getDimSize (0 );
1878
+ } else if (maps == infer ({{k, m}, {n, k}, {m, n}})) {
1879
+ // Just permute the rhs.
1880
+ reductionSize = lhsType.getDimSize (0 );
1881
+ rhs = rewriter.create <vector::TransposeOp>(loc, rhs, perm);
1882
+ } else if (maps == infer ({{m, k}, {k, n}, {n, m}})) {
1883
+ // Transposed output of the classical row-major matmul: swap lhs and rhs, permuting the original lhs.
1884
+ reductionSize = lhsType.getDimSize (1 );
1885
+ Value tmp = rhs;
1886
+ rhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1887
+ lhs = tmp;
1888
+ } else if (maps == infer ({{m, k}, {n, k}, {n, m}})) {
1889
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1890
+ reductionSize = lhsType.getDimSize (1 );
1891
+ Value tmp = rhs;
1892
+ rhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1893
+ lhs = rewriter.create <vector::TransposeOp>(loc, tmp, perm);
1894
+ } else if (maps == infer ({{k, m}, {k, n}, {n, m}})) {
1895
+ // No need to permute anything, but still swap lhs and rhs.
1896
+ reductionSize = lhsType.getDimSize (0 );
1897
+ std::swap (lhs, rhs);
1898
+ } else if (maps == infer ({{k, m}, {n, k}, {n, m}})) {
1899
+ // Transposed output: swap lhs and rhs, permuting the original rhs.
1900
+ reductionSize = lhsType.getDimSize (0 );
1901
+ Value tmp = lhs;
1902
+ lhs = rewriter.create <vector::TransposeOp>(loc, rhs, perm);
1903
+ rhs = tmp;
1904
+ } else {
1905
+ return failure ();
1906
+ }
1907
+ } else if (isParallelIterator (iteratorTypes[0 ]) &&
1908
+ isReductionIterator (iteratorTypes[1 ])) {
1909
+ //
1910
+ // One outer parallel, one inner reduction (matvec flavor)
1911
+ //
1912
+ if (maps == infer ({{m, n}, {n}, {m}})) {
1913
+ // Case mat-vec: transpose.
1914
+ reductionSize = lhsType.getDimSize (1 );
1915
+ lhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1916
+ } else if (maps == infer ({{n, m}, {n}, {m}})) {
1917
+ // Case mat-trans-vec: ready to go.
1918
+ reductionSize = lhsType.getDimSize (0 );
1919
+ } else if (maps == infer ({{n}, {m, n}, {m}})) {
1920
+ // Case vec-mat: swap and transpose.
1921
+ reductionSize = lhsType.getDimSize (0 );
1922
+ std::swap (lhs, rhs);
1923
+ lhs = rewriter.create <vector::TransposeOp>(loc, lhs, perm);
1924
+ } else if (maps == infer ({{n}, {n, m}, {m}})) {
1925
+ // Case vec-mat-trans: swap and ready to go.
1926
+ reductionSize = lhsType.getDimSize (0 );
1927
+ std::swap (lhs, rhs);
1928
+ } else {
1929
+ return failure ();
1930
+ }
1931
+ } else {
1948
1932
return failure ();
1949
1933
}
1950
-
1951
- //
1952
- // One outer parallel, one inner reduction (matvec flavor)
1953
- //
1954
- if (e.iters ({Par (), Red ()})) {
1955
- // Case mat-vec: transpose.
1956
- if (e.layout ({{m, n}, {n}, {m}}))
1957
- return e.outer_prod (e.t (lhs), rhs, res, lhsType.getDimSize (1 ));
1958
- // Case mat-trans-vec: ready to go.
1959
- if (e.layout ({{n, m}, {n}, {m}}))
1960
- return e.outer_prod (lhs, rhs, res, lhsType.getDimSize (0 ));
1961
- // Case vec-mat: swap and transpose.
1962
- if (e.layout ({{n}, {m, n}, {m}}))
1963
- return e.outer_prod (e.t (rhs), lhs, res, lhsType.getDimSize (0 ));
1964
- // Case vec-mat-trans: swap and ready to go.
1965
- if (e.layout ({{n}, {n, m}, {m}}))
1966
- return e.outer_prod (rhs, lhs, res, lhsType.getDimSize (0 ));
1967
- return failure ();
1934
+ assert (reductionSize > 0 );
1935
+
1936
+ // Unroll outer-products along reduction.
1937
+ for (int64_t k = 0 ; k < reductionSize; ++k) {
1938
+ Value a = rewriter.create <vector::ExtractOp>(op.getLoc (), lhs, k);
1939
+ Value b = rewriter.create <vector::ExtractOp>(op.getLoc (), rhs, k);
1940
+ res = rewriter.create <vector::OuterProductOp>(op.getLoc (), res.getType (), a,
1941
+ b, res, op.kind ());
1968
1942
}
1969
-
1970
- return failure ();
1943
+ rewriter. replaceOp (op, res);
1944
+ return success ();
1971
1945
}
1972
1946
1973
1947
LogicalResult
0 commit comments