Skip to content

Commit a9d1fea

Browse files
authored
Fix the condition for peeling the first iteration (#86350)
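This PR fixes the condition that decides whether the first iteration of a loop should be peeled. The iteration count is now computed with a ceiling division (ceilDiv) instead of a floor division (floorDiv), so that a loop whose range is not an exact multiple of the step (e.g. lb = 2, ub = 8, step = 4, which runs two iterations) still has its first iteration peeled as needed.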
1 parent 7b3e943 commit a9d1fea

File tree

2 files changed

+44
-5
lines changed

2 files changed

+44
-5
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
220220
auto stepInt = getConstantIntValue(forOp.getStep());
221221

222222
// Peeling is not needed if there is at most one iteration.
223-
if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) / *stepInt <= 1)
223+
if (lbInt && ubInt && stepInt && ceil(float(*ubInt - *lbInt) / *stepInt) <= 1)
224224
return failure();
225225

226226
AffineExpr lbSymbol, stepSymbol;

mlir/test/Dialect/SCF/for-loop-peeling-front.mlir

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
// CHECK: %[[INIT:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
1414
// CHECK: scf.yield %[[INIT]]
1515
// CHECK: }
16-
// CHECK: %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C4]] to %[[C17]]
17-
// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[FIRST]]) -> (i32) {
18-
// CHECK: %[[MIN2:.*]] = affine.min #[[MAP]](%[[C17]], %[[IV]])[%[[C4]]]
16+
// CHECK: %[[RESULT:.*]] = scf.for %[[IV2:.*]] = %[[C4]] to %[[C17]]
17+
// CHECK-SAME: step %[[C4]] iter_args(%[[ACC2:.*]] = %[[FIRST]]) -> (i32) {
18+
// CHECK: %[[MIN2:.*]] = affine.min #[[MAP]](%[[C17]], %[[IV2]])[%[[C4]]]
1919
// CHECK: %[[CAST2:.*]] = arith.index_cast %[[MIN2]] : index to i32
20-
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST2]] : i32
20+
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC2]], %[[CAST2]] : i32
2121
// CHECK: scf.yield %[[ADD]]
2222
// CHECK: }
2323
// CHECK: return %[[RESULT]]
@@ -110,6 +110,45 @@ func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
110110

111111
// -----
112112

113+
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
114+
// CHECK: func @two_iteration_example(
115+
// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
116+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
117+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
118+
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
119+
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
120+
// CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[C2]] to %[[C6]]
121+
// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
122+
// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C6]], %[[IV]])[%[[C4]]]
123+
// CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
124+
// CHECK: %[[INIT:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
125+
// CHECK: scf.yield %[[INIT]]
126+
// CHECK: }
127+
// CHECK: %[[RESULT:.*]] = scf.for %[[IV2:.*]] = %[[C6]] to %[[C8]]
128+
// CHECK-SAME: step %[[C4]] iter_args(%[[ACC2:.*]] = %[[FIRST]]) -> (i32) {
129+
// CHECK: %[[MIN2:.*]] = affine.min #[[MAP]](%[[C8]], %[[IV2]])[%[[C4]]]
130+
// CHECK: %[[CAST2:.*]] = arith.index_cast %[[MIN2]] : index to i32
131+
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC2]], %[[CAST2]] : i32
132+
// CHECK: scf.yield %[[ADD]]
133+
// CHECK: }
134+
// CHECK: return %[[RESULT]]
135+
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
136+
func.func @two_iteration_example() -> i32 {
137+
%c0_i32 = arith.constant 0 : i32
138+
%lb = arith.constant 2 : index
139+
%step = arith.constant 4 : index
140+
%ub = arith.constant 8 : index
141+
%r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0_i32) -> i32 {
142+
%s = affine.min #map(%ub, %iv)[%step]
143+
%casted = arith.index_cast %s : index to i32
144+
%0 = arith.addi %arg, %casted : i32
145+
scf.yield %0 : i32
146+
}
147+
return %r : i32
148+
}
149+
150+
// -----
151+
113152
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
114153
// CHECK: func @no_peeling_front(
115154
// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32

0 commit comments

Comments
 (0)