@@ -1718,11 +1718,9 @@ static bool isBroadcastLike(Operation *op) {
1718
1718
return false ;
1719
1719
1720
1720
// Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1721
- // Condition 1: dst has hight rank.
1722
- // Condition 2: src shape is a suffix of dst shape.
1723
- //
1724
1721
// Note that checking that dst shape has a prefix of 1s is not sufficient,
1725
- // for example (2,3) -> (1,3,2) is not broadcast-like.
1722
+ // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition
1723
+ // is that the source shape is a suffix of the destination shape.
1726
1724
VectorType srcType = shapeCast.getSourceVectorType ();
1727
1725
ArrayRef<int64_t > srcShape = srcType.getShape ();
1728
1726
uint64_t srcRank = srcType.getRank ();
@@ -1734,16 +1732,16 @@ static bool isBroadcastLike(Operation *op) {
1734
1732
// /
1735
1733
// / Example:
1736
1734
// /
1737
- // / broadcast extract
1738
- // / (3, 4) --------> (2, 3, 4) ------> (4)
1735
+ // / broadcast extract [1][2]
1736
+ // / (3, 4) --------> (2, 3, 4) ---------------- > (4)
1739
1737
// /
1740
1738
// / becomes
1741
- // / extract
1742
- // / (3,4) ---------------------------> (4)
1739
+ // / extract [1]
1740
+ // / (3,4) ------------------------------------- > (4)
1743
1741
// /
1744
1742
// /
1745
- // / The variable names used in this implementation use names which correspond to
1746
- // / the above shapes as,
1743
+ // / The variable names used in this implementation correspond to the above
1744
+ // / shapes as,
1747
1745
// /
1748
1746
// / - (3, 4) is `input` shape.
1749
1747
// / - (2, 3, 4) is `broadcast` shape.
@@ -1775,14 +1773,15 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1775
1773
if (extractRank > inputRank)
1776
1774
return Value ();
1777
1775
1778
- // Proof by contradiction that, at this point, input is a vector.
1779
- // Suppose input is a scalar.
1780
- // ==> inputRank is 0.
1781
- // ==> extractRank is 0 (because extractRank <= inputRank).
1782
- // ==> extract is scalar (because rank-0 extraction is always scalar).
1783
- // ==> input and extract are scalar, so same type.
1784
- // ==> returned early (check same type).
1785
- // Contradiction!
1776
+ // The above condition guarantees that input is a vector:
1777
+ //
1778
+ // If input is a scalar:
1779
+ // 1) inputRank is 0, so
1780
+ // 2) extractRank is 0 (because extractRank <= inputRank), so
1781
+ // 3) extract is scalar (because rank-0 extraction is always scalar), s0
1782
+ // 4) input and extract are scalar, so same type.
1783
+ // But then we should have returned earlier when the types were compared for
1784
+ // equivalence. So input is not a scalar at this point.
1786
1785
assert (inputType && " input must be a vector type because of previous checks" );
1787
1786
ArrayRef<int64_t > inputShape = inputType.getShape ();
1788
1787
0 commit comments