Skip to content

Commit 4525d52

Browse files
committed
Revert "[mlir][Vector] NFC - Compress vector to outerproduct lowering."
This reverts commit db188ad. Breaks the GCC tests, likely because of some order of evaluation difference between clang and gcc.
1 parent c7c5a1c commit 4525d52

File tree

1 file changed

+89
-115
lines changed

1 file changed

+89
-115
lines changed

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 89 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1816,72 +1816,6 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
18161816
return success();
18171817
}
18181818

1819-
namespace {
1820-
struct IteratorType {
1821-
IteratorType(StringRef strRef) : strRef(strRef) {}
1822-
bool isOfType(Attribute attr) const {
1823-
auto sAttr = attr.dyn_cast<StringAttr>();
1824-
return sAttr && sAttr.getValue() == strRef;
1825-
}
1826-
StringRef strRef;
1827-
};
1828-
struct Par : public IteratorType {
1829-
Par() : IteratorType(getParallelIteratorTypeName()) {}
1830-
};
1831-
struct Red : public IteratorType {
1832-
Red() : IteratorType(getReductionIteratorTypeName()) {}
1833-
};
1834-
1835-
// Unroll outer-products along reduction.
1836-
struct UnrolledOuterProductEmitter {
1837-
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1838-
1839-
UnrolledOuterProductEmitter(PatternRewriter &rewriter,
1840-
vector::ContractionOp op)
1841-
: rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
1842-
iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
1843-
1844-
Value t(Value v) {
1845-
static constexpr std::array<int64_t, 2> perm = {1, 0};
1846-
return rewriter.create<vector::TransposeOp>(loc, v, perm);
1847-
}
1848-
1849-
bool iters(ArrayRef<IteratorType> its) {
1850-
if (its.size() != iterators.size())
1851-
return false;
1852-
for (int i = 0, e = its.size(); i != e; ++i) {
1853-
if (!its[i].isOfType(iterators[i]))
1854-
return false;
1855-
}
1856-
return true;
1857-
}
1858-
1859-
bool layout(MapList l) {
1860-
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1861-
return maps == infer(l);
1862-
}
1863-
1864-
LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
1865-
assert(reductionSize > 0);
1866-
for (int64_t k = 0; k < reductionSize; ++k) {
1867-
Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
1868-
Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
1869-
res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1870-
res, kind);
1871-
}
1872-
rewriter.replaceOp(op, res);
1873-
return success();
1874-
}
1875-
1876-
PatternRewriter &rewriter;
1877-
Location loc;
1878-
vector::CombiningKind kind;
1879-
ArrayAttr iterators;
1880-
SmallVector<AffineMap, 4> maps;
1881-
Operation *op;
1882-
};
1883-
} // namespace
1884-
18851819
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
18861820
/// semantics to a reduction_size-unrolled sequence:
18871821
/// ```
@@ -1910,64 +1844,104 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
19101844
if (failed(filter(op)))
19111845
return failure();
19121846

1847+
Location loc = op.getLoc();
1848+
int64_t reductionSize = 0;
19131849
VectorType lhsType = op.getLhsType();
19141850
Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
19151851

19161852
// Set up the parallel/reduction structure in right form.
1853+
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1854+
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
19171855
AffineExpr m, n, k;
19181856
bindDims(rewriter.getContext(), m, n, k);
1919-
1920-
//
1921-
// Two outer parallel, one inner reduction (matmat flavor).
1922-
//
1923-
UnrolledOuterProductEmitter e(rewriter, op);
1924-
if (e.iters({Par(), Par(), Red()})) {
1925-
// Classical row-major matmul: Just permute the lhs.
1926-
if (e.layout({{m, k}, {k, n}, {m, n}}))
1927-
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
1928-
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
1929-
if (e.layout({{m, k}, {n, k}, {m, n}}))
1930-
return e.outer_prod(e.t(lhs), e.t(rhs), res, lhsType.getDimSize(1));
1931-
// No need to permute anything.
1932-
if (e.layout({{k, m}, {k, n}, {m, n}}))
1933-
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
1934-
// Just permute the rhs.
1935-
if (e.layout({{k, m}, {n, k}, {m, n}}))
1936-
return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
1937-
// Transposed output: swap RHS and LHS.
1938-
// Classical row-major matmul: permute the lhs.
1939-
if (e.layout({{m, k}, {k, n}, {n, m}}))
1940-
return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
1941-
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
1942-
if (e.layout({{m, k}, {n, k}, {n, m}}))
1943-
return e.outer_prod(e.t(rhs), e.t(lhs), res, lhsType.getDimSize(1));
1944-
if (e.layout({{k, m}, {k, n}, {n, m}}))
1945-
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
1946-
if (e.layout({{k, m}, {n, k}, {n, m}}))
1947-
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
1857+
static constexpr std::array<int64_t, 2> perm = {1, 0};
1858+
auto iteratorTypes = op.iterator_types().getValue();
1859+
SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1860+
if (isParallelIterator(iteratorTypes[0]) &&
1861+
isParallelIterator(iteratorTypes[1]) &&
1862+
isReductionIterator(iteratorTypes[2])) {
1863+
//
1864+
// Two outer parallel, one inner reduction (matmat flavor).
1865+
//
1866+
if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1867+
// This is the classical row-major matmul. Just permute the lhs.
1868+
reductionSize = lhsType.getDimSize(1);
1869+
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1870+
} else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1871+
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
1872+
reductionSize = lhsType.getDimSize(1);
1873+
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1874+
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1875+
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1876+
// No need to permute anything.
1877+
reductionSize = lhsType.getDimSize(0);
1878+
} else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1879+
// Just permute the rhs.
1880+
reductionSize = lhsType.getDimSize(0);
1881+
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1882+
} else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1883+
// This is the classical row-major matmul. Just permute the lhs.
1884+
reductionSize = lhsType.getDimSize(1);
1885+
Value tmp = rhs;
1886+
rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1887+
lhs = tmp;
1888+
} else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1889+
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
1890+
reductionSize = lhsType.getDimSize(1);
1891+
Value tmp = rhs;
1892+
rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1893+
lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1894+
} else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1895+
// No need to permute anything, but still swap lhs and rhs.
1896+
reductionSize = lhsType.getDimSize(0);
1897+
std::swap(lhs, rhs);
1898+
} else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1899+
// Just permute the rhs.
1900+
reductionSize = lhsType.getDimSize(0);
1901+
Value tmp = lhs;
1902+
lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1903+
rhs = tmp;
1904+
} else {
1905+
return failure();
1906+
}
1907+
} else if (isParallelIterator(iteratorTypes[0]) &&
1908+
isReductionIterator(iteratorTypes[1])) {
1909+
//
1910+
// One outer parallel, one inner reduction (matvec flavor)
1911+
//
1912+
if (maps == infer({{m, n}, {n}, {m}})) {
1913+
// Case mat-vec: transpose.
1914+
reductionSize = lhsType.getDimSize(1);
1915+
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1916+
} else if (maps == infer({{n, m}, {n}, {m}})) {
1917+
// Case mat-trans-vec: ready to go.
1918+
reductionSize = lhsType.getDimSize(0);
1919+
} else if (maps == infer({{n}, {m, n}, {m}})) {
1920+
// Case vec-mat: swap and transpose.
1921+
reductionSize = lhsType.getDimSize(0);
1922+
std::swap(lhs, rhs);
1923+
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1924+
} else if (maps == infer({{n}, {n, m}, {m}})) {
1925+
// Case vec-mat-trans: swap and ready to go.
1926+
reductionSize = lhsType.getDimSize(0);
1927+
std::swap(lhs, rhs);
1928+
} else {
1929+
return failure();
1930+
}
1931+
} else {
19481932
return failure();
19491933
}
1950-
1951-
//
1952-
// One outer parallel, one inner reduction (matvec flavor)
1953-
//
1954-
if (e.iters({Par(), Red()})) {
1955-
// Case mat-vec: transpose.
1956-
if (e.layout({{m, n}, {n}, {m}}))
1957-
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
1958-
// Case mat-trans-vec: ready to go.
1959-
if (e.layout({{n, m}, {n}, {m}}))
1960-
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
1961-
// Case vec-mat: swap and transpose.
1962-
if (e.layout({{n}, {m, n}, {m}}))
1963-
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
1964-
// Case vec-mat-trans: swap and ready to go.
1965-
if (e.layout({{n}, {n, m}, {m}}))
1966-
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
1967-
return failure();
1934+
assert(reductionSize > 0);
1935+
1936+
// Unroll outer-products along reduction.
1937+
for (int64_t k = 0; k < reductionSize; ++k) {
1938+
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
1939+
Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
1940+
res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
1941+
b, res, op.kind());
19681942
}
1969-
1970-
return failure();
1943+
rewriter.replaceOp(op, res);
1944+
return success();
19711945
}
19721946

19731947
LogicalResult

0 commit comments

Comments (0)