
Commit db188ad

[mlir][Vector] NFC - Compress vector to outerproduct lowering.
The implementation has become too unwieldy and the cognitive overhead outweighs the benefit. Compress the implementation in preparation for additional lowering paths.

Differential Revision: https://reviews.llvm.org/D105359
Parent: 21e9261

1 file changed: 115 additions, 89 deletions

mlir/lib/Dialect/Vector/VectorTransforms.cpp

@@ -1816,6 +1816,72 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
   return success();
 }
 
+namespace {
+struct IteratorType {
+  IteratorType(StringRef strRef) : strRef(strRef) {}
+  bool isOfType(Attribute attr) const {
+    auto sAttr = attr.dyn_cast<StringAttr>();
+    return sAttr && sAttr.getValue() == strRef;
+  }
+  StringRef strRef;
+};
+struct Par : public IteratorType {
+  Par() : IteratorType(getParallelIteratorTypeName()) {}
+};
+struct Red : public IteratorType {
+  Red() : IteratorType(getReductionIteratorTypeName()) {}
+};
+
+// Unroll outer-products along reduction.
+struct UnrolledOuterProductEmitter {
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+
+  UnrolledOuterProductEmitter(PatternRewriter &rewriter,
+                              vector::ContractionOp op)
+      : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
+        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+
+  Value t(Value v) {
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    return rewriter.create<vector::TransposeOp>(loc, v, perm);
+  }
+
+  bool iters(ArrayRef<IteratorType> its) {
+    if (its.size() != iterators.size())
+      return false;
+    for (int i = 0, e = its.size(); i != e; ++i) {
+      if (!its[i].isOfType(iterators[i]))
+        return false;
+    }
+    return true;
+  }
+
+  bool layout(MapList l) {
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    return maps == infer(l);
+  }
+
+  LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
+    assert(reductionSize > 0);
+    for (int64_t k = 0; k < reductionSize; ++k) {
+      Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+      Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
+                                                    res, kind);
+    }
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+  PatternRewriter &rewriter;
+  Location loc;
+  vector::CombiningKind kind;
+  ArrayAttr iterators;
+  SmallVector<AffineMap, 4> maps;
+  Operation *op;
+};
+} // namespace
+
 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to a reduction_size-unrolled sequence:
 /// ```
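For orientation between the two hunks, it may help to see what the new `UnrolledOuterProductEmitter::outer_prod` actually emits. A minimal sketch of the classical row-major matmul case (`{m, k}, {k, n}, {m, n}`); the concrete shapes `vector<2x4xf32>` and `vector<4x3xf32>` are illustrative assumptions, not taken from the commit:

```mlir
// Input: a vector.contract with row-major matmul semantics,
// iterators [parallel, parallel, reduction]. Shapes are illustrative.
%res = vector.contract {
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (k, n)>,
                     affine_map<(m, n, k) -> (m, n)>],
    iterator_types = ["parallel", "parallel", "reduction"]}
  %a, %b, %c : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>

// After the pattern: e.t(lhs) permutes the lhs, then outer_prod unrolls
// the reduction dimension (size 4) into a chain of outer products.
%at = vector.transpose %a, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
%a0 = vector.extract %at[0] : vector<4x2xf32>
%b0 = vector.extract %b[0] : vector<4x3xf32>
%r0 = vector.outerproduct %a0, %b0, %c : vector<2xf32>, vector<3xf32>
%a1 = vector.extract %at[1] : vector<4x2xf32>
%b1 = vector.extract %b[1] : vector<4x3xf32>
%r1 = vector.outerproduct %a1, %b1, %r0 : vector<2xf32>, vector<3xf32>
// ... likewise for k = 2 and k = 3; the final result replaces %res.
```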
@@ -1844,104 +1910,64 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();
 
-  Location loc = op.getLoc();
-  int64_t reductionSize = 0;
   VectorType lhsType = op.getLhsType();
   Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
 
   // Set up the parallel/reduction structure in right form.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(rewriter.getContext(), m, n, k);
-  static constexpr std::array<int64_t, 2> perm = {1, 0};
-  auto iteratorTypes = op.iterator_types().getValue();
-  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  if (isParallelIterator(iteratorTypes[0]) &&
-      isParallelIterator(iteratorTypes[1]) &&
-      isReductionIterator(iteratorTypes[2])) {
-    //
-    // Two outer parallel, one inner reduction (matmat flavor).
-    //
-    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-      // This is the classical row-major matmul. Just permute the lhs.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
-      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-      // No need to permute anything.
-      reductionSize = lhsType.getDimSize(0);
-    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
-      // Just permute the rhs.
-      reductionSize = lhsType.getDimSize(0);
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-      // This is the classical row-major matmul. Just permute the lhs.
-      reductionSize = lhsType.getDimSize(1);
-      Value tmp = rhs;
-      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      lhs = tmp;
-    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-      reductionSize = lhsType.getDimSize(1);
-      Value tmp = rhs;
-      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
-    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-      // No need to permute anything, but still swap lhs and rhs.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-      // Just permute the rhs.
-      reductionSize = lhsType.getDimSize(0);
-      Value tmp = lhs;
-      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-      rhs = tmp;
-    } else {
-      return failure();
-    }
-  } else if (isParallelIterator(iteratorTypes[0]) &&
-             isReductionIterator(iteratorTypes[1])) {
-    //
-    // One outer parallel, one inner reduction (matvec flavor)
-    //
-    if (maps == infer({{m, n}, {n}, {m}})) {
-      // Case mat-vec: transpose.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{n, m}, {n}, {m}})) {
-      // Case mat-trans-vec: ready to go.
-      reductionSize = lhsType.getDimSize(0);
-    } else if (maps == infer({{n}, {m, n}, {m}})) {
-      // Case vec-mat: swap and transpose.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{n}, {n, m}, {m}})) {
-      // Case vec-mat-trans: swap and ready to go.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-    } else {
-      return failure();
-    }
-  } else {
+
+  //
+  // Two outer parallel, one inner reduction (matmat flavor).
+  //
+  UnrolledOuterProductEmitter e(rewriter, op);
+  if (e.iters({Par(), Par(), Red()})) {
+    // Classical row-major matmul: Just permute the lhs.
+    if (e.layout({{m, k}, {k, n}, {m, n}}))
+      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    if (e.layout({{m, k}, {n, k}, {m, n}}))
+      return e.outer_prod(e.t(lhs), e.t(rhs), res, lhsType.getDimSize(1));
+    // No need to permute anything.
+    if (e.layout({{k, m}, {k, n}, {m, n}}))
+      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    // Just permute the rhs.
+    if (e.layout({{k, m}, {n, k}, {m, n}}))
+      return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
+    // Transposed output: swap RHS and LHS.
+    // Classical row-major matmul: permute the lhs.
+    if (e.layout({{m, k}, {k, n}, {n, m}}))
+      return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    if (e.layout({{m, k}, {n, k}, {n, m}}))
+      return e.outer_prod(e.t(rhs), e.t(lhs), res, lhsType.getDimSize(1));
+    if (e.layout({{k, m}, {k, n}, {n, m}}))
+      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    if (e.layout({{k, m}, {n, k}, {n, m}}))
+      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
     return failure();
   }
-  assert(reductionSize > 0);
-
-  // Unroll outer-products along reduction.
-  for (int64_t k = 0; k < reductionSize; ++k) {
-    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
-    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
-    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
-                                                  b, res, op.kind());
+
+  //
+  // One outer parallel, one inner reduction (matvec flavor)
+  //
+  if (e.iters({Par(), Red()})) {
+    // Case mat-vec: transpose.
+    if (e.layout({{m, n}, {n}, {m}}))
+      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    // Case mat-trans-vec: ready to go.
+    if (e.layout({{n, m}, {n}, {m}}))
+      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    // Case vec-mat: swap and transpose.
+    if (e.layout({{n}, {m, n}, {m}}))
+      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+    // Case vec-mat-trans: swap and ready to go.
+    if (e.layout({{n}, {n, m}, {m}}))
+      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    return failure();
   }
-  rewriter.replaceOp(op, res);
-  return success();
+
+  return failure();
 }
 
 LogicalResult
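The matvec flavor follows the same recipe. A sketch of the `{m, n}, {n}, {m}` mat-vec case under the same illustrative-shape assumption (`vector<2x3xf32>` times `vector<3xf32>`); each element extracted from the vector operand is a scalar, so `vector.outerproduct` takes its AXPY form:

```mlir
// Input: mat-vec contraction, iterators [parallel, reduction].
// Shapes are illustrative, not from the commit.
%res = vector.contract {
    indexing_maps = [affine_map<(m, n) -> (m, n)>,
                     affine_map<(m, n) -> (n)>,
                     affine_map<(m, n) -> (m)>],
    iterator_types = ["parallel", "reduction"]}
  %A, %x, %acc : vector<2x3xf32>, vector<3xf32> into vector<2xf32>

// After the pattern: transpose the matrix (e.t(lhs)), then unroll the
// reduction (size 3) into scalar-times-vector outer products (AXPY).
%At = vector.transpose %A, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
%A0 = vector.extract %At[0] : vector<3x2xf32>
%x0 = vector.extract %x[0] : vector<3xf32>
%r0 = vector.outerproduct %A0, %x0, %acc : vector<2xf32>, f32
// ... likewise for reduction indices 1 and 2.
```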
