
Commit f239026

Author: Tobias Gysi
[mlir][linalg][python] Add min operation in OpDSL.

Add the min operation to OpDSL and introduce a min pooling operation to test
the implementation. This patch is a sibling of the max operation patch
https://reviews.llvm.org/D105203; the min operation is likewise lowered to a
compare and select pair.

Differential Revision: https://reviews.llvm.org/D105345
Parent: 7c5d654
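To make "lowered to a compare and select pair" concrete: the emitted reduction body compares the accumulator against the incoming value and keeps the smaller one (cmpf olt + select for floats, cmpi slt + select for signed integers, per the diffs below). A minimal sketch of the scalar semantics in plain Python, not the MLIR builder API the patch itself uses:

def scalar_min(acc, val):
  # cmpf olt (floats) / cmpi slt (signed ints) in the generated IR.
  cond = acc < val
  # select then picks the smaller of the two operands.
  return acc if cond else val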

8 files changed: +280, -27 lines

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 71 additions & 0 deletions

@@ -664,6 +664,77 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nhwc_min_poly
+  cpp_class_name: PoolingNhwcMinPolyOp
+  doc: |-
+    Performs min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s4, s5)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s6, s7, s3)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s8, s9)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s10, s11)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d3, d4)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1, d2, d5)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: min
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
   cpp_class_name: FillRng2DOp
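A note for readers decoding the affine maps above: the positional symbols s0 through s11 line up with the named symbols of the OpDSL definition added further down. The mapping below is read off the shape_map and attribute_map entries in this diff (an inference for clarity, not stated explicitly in the file):

# s0..s3   -> S.N, S.H, S.W, S.C   (input I: batch, height, width, channels)
# s4, s5   -> S.KH, S.KW           (kernel K: window height and width)
# s6, s7   -> S.OH, S.OW           (output O: spatial dims; N and C reuse s0, s3)
# s8, s9   -> S.SH, S.SW           (strides attribute)
# s10, s11 -> S.DH, S.DW           (dilations attribute)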

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 23 additions & 11 deletions

@@ -275,17 +275,18 @@ class RegionBuilderHelper {
   }

   Value applyfn__max(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs)) {
-      Value condition =
-          builder.create<CmpFOp>(lhs.getLoc(), CmpFPredicate::OGT, lhs, rhs);
-      return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-    }
-    if (isInteger(lhs)) {
-      Value condition =
-          builder.create<CmpIOp>(lhs.getLoc(), CmpIPredicate::sgt, lhs, rhs);
-      return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-    }
+    if (isFloatingPoint(lhs))
+      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
+    if (isInteger(lhs))
+      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  Value applyfn__min(Value lhs, Value rhs) {
+    if (isFloatingPoint(lhs))
+      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
+    if (isInteger(lhs))
+      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
     llvm_unreachable("unsupported non numeric type");
   }

@@ -322,6 +323,17 @@ class RegionBuilderHelper {
   MLIRContext *context;
   Block &block;

+  Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
+    OpBuilder builder = getBuilder();
+    Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
+    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
+  }
+  Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
+    OpBuilder builder = getBuilder();
+    Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
+    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
+  }
+
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

Lines changed: 2 additions & 0 deletions

@@ -339,6 +339,7 @@ class PrimFn:
   log = PrimFnType("log")
   mul = PrimFnType("mul")
   max = PrimFnType("max")
+  min = PrimFnType("min")
   sub = PrimFnType("sub")


@@ -364,6 +365,7 @@ class ReduceFn:
   add = PrimFn.add.reduce
   mul = PrimFn.mul.reduce
   max = PrimFn.max.reduce
+  min = PrimFn.min.reduce


 class PrimApply(TensorExpression):

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py

Lines changed: 21 additions & 5 deletions

@@ -308,17 +308,23 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
     raise NotImplementedError("Unsupported 'mul' operand: {lhs}")

   def _eval_max(self, lhs: Value, rhs: Value) -> Value:
-    i1 = IntegerType.get_signless(1)
     if _is_floating_point_type(lhs.type):
       ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      cond = std.CmpFOp(i1, ogt_attr, lhs, rhs).result
-      return std.SelectOp(lhs.type, cond, lhs, rhs).result
+      return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      cond = std.CmpIOp(i1, sgt_attr, lhs, rhs).result
-      return std.SelectOp(lhs.type, cond, lhs, rhs).result
+      return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
     raise NotImplementedError("Unsupported 'max' operand: {lhs}")

+  def _eval_min(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
+      return _emit_cmpf_and_select(lhs, rhs, olt_attr)
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
+      return _emit_cmpi_and_select(lhs, rhs, slt_attr)
+    raise NotImplementedError("Unsupported 'min' operand: {lhs}")
+

 def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
                            in_arg_defs: Sequence[OperandDefConfig],
@@ -397,3 +403,13 @@ def _get_floating_point_width(t: Type) -> int:
   if BF16Type.isinstance(t):
     return 16
   raise NotImplementedError(f"Unhandled floating point type switch {t}")
+
+
+def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
+  cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
+  return std.SelectOp(lhs.type, cond, lhs, rhs).result
+
+
+def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
+  cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
+  return std.SelectOp(lhs.type, cond, lhs, rhs).result
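The bare integers handed to IntegerAttr.get above are the numeric encodings of the standard-dialect comparison predicates, as the diff itself shows: max uses ogt = 2 (cmpf) and sgt = 4 (cmpi), while min uses olt = 4 and slt = 2. A small reference table covering just the four predicates used here (the full predicate enums have more members):

# Predicate encodings used by _eval_max / _eval_min (std dialect):
CMPF_PRED = {"ogt": 2, "olt": 4}  # ordered greater-than / less-than, floats
CMPI_PRED = {"sgt": 4, "slt": 2}  # signed greater-than / less-than, ints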

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 18 additions & 0 deletions

@@ -166,6 +166,24 @@ def pooling_nhwc_max_poly(
                 D.c]))


+@linalg_structured_op
+def pooling_nhwc_min_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs min pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                D.c]))
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),
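A minimal usage sketch of the new named op from the Python bindings, modeled on the tests added later in this commit. It assumes an MLIR build with the Python bindings enabled and that pooling_nhwc_min_poly is importable from the core named ops module shown above; the shapes follow the 1x4x16x1 example from the MLIR test file:

from mlir.ir import (Context, Location, Module, InsertionPoint, F32Type,
                     RankedTensorType)
from mlir.dialects import builtin
from mlir.dialects.linalg.opdsl.ops.core_named_ops import pooling_nhwc_min_poly

with Context() as ctx, Location.unknown():
  module = Module.create()
  f32 = F32Type.get()
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(
        RankedTensorType.get((1, 4, 16, 1), f32),  # NHWC input
        RankedTensorType.get((2, 2), f32),         # pooling window shape
        RankedTensorType.get((1, 2, 4, 1), f32))   # NHWC output
    def min_pool(input, shape, init_result):
      # Builds the min pooling op; its body reduces with compare + select.
      return pooling_nhwc_min_poly(
          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

  print(module)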

mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Lines changed: 30 additions & 0 deletions

@@ -90,6 +90,36 @@ func @generalize_pooling_nhwc_max_poly_i32(%input : tensor<1x4x16x1xi32>, %shape

 // -----

+func @generalize_pooling_nhwc_min_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+  %0 = linalg.pooling_nhwc_min_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+  return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_min_poly_f32
+// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   linalg.yield %[[MIN]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_pooling_nhwc_min_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+  %0 = linalg.pooling_nhwc_min_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+  return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_min_poly_i32
+// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT:   %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   linalg.yield %[[MIN]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
 func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
   %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
     ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
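The CHECK lines above pin down exactly the compare-and-select pair in the generalized body: once the polymorphic pooling op is generalized to linalg.generic (the file's RUN line is not part of this diff, but the test name points at the named-op generalization path in mlir-opt), the f32 variant reduces with cmpf olt + select and the i32 variant with cmpi slt + select.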

mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py

Lines changed: 38 additions & 7 deletions

@@ -43,7 +43,7 @@ def conv_poly(


 @linalg_structured_op
-def pooling_poly(
+def pooling_max_poly(
     I=TensorDef(T1, S.N, S.H, S.W, S.C),
     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
@@ -55,6 +55,19 @@ def pooling_poly(
                 D.c]))


+@linalg_structured_op
+def pooling_min_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                D.c]))
+
+
 @linalg_structured_op
 def fill_rng_poly(
     min=ScalarDef(F64),
@@ -216,7 +229,7 @@ def test_f32i32_conv(input, filter, init_result):
     return conv_poly(
         input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])

-  # CHECK-LABEL: @test_f32i32_pooling
+  # CHECK-LABEL: @test_f32i32_max_pooling
   # CHECK: linalg.generic
   # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
   # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
@@ -229,11 +242,11 @@ def test_f32i32_conv(input, filter, init_result):
   @builtin.FuncOp.from_py_func(
       RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
       RankedTensorType.get((2, 4), i32))
-  def test_f32i32_pooling(input, shape, init_result):
-    return pooling_poly(
+  def test_f32i32_max_pooling(input, shape, init_result):
+    return pooling_max_poly(
         input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

-  # CHECK-LABEL: @test_f32f32_pooling
+  # CHECK-LABEL: @test_f32f32_max_pooling
   # CHECK: linalg.generic
   # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
   # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
@@ -245,8 +258,26 @@ def test_f32i32_pooling(input, shape, init_result):
   @builtin.FuncOp.from_py_func(
       RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
       RankedTensorType.get((2, 4), f32))
-  def test_f32f32_pooling(input, shape, init_result):
-    return pooling_poly(
+  def test_f32f32_max_pooling(input, shape, init_result):
+    return pooling_max_poly(
+        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
+  # CHECK-LABEL: @test_f32i32_min_pooling
+  # CHECK: = cmpi slt,
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+      RankedTensorType.get((2, 4), i32))
+  def test_f32i32_min_pooling(input, shape, init_result):
+    return pooling_min_poly(
+        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
+  # CHECK-LABEL: @test_f32f32_min_pooling
+  # CHECK: = cmpf olt,
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+      RankedTensorType.get((2, 4), f32))
+  def test_f32f32_min_pooling(input, shape, init_result):
+    return pooling_min_poly(
         input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

   # CHECK-LABEL: @test_i32_fill_rng
