@@ -124,7 +124,7 @@ static MaskFormat getMaskFormat(Value mask) {
124
124
// / Default callback to build a region with a 'vector.yield' terminator with no
125
125
// / arguments.
126
126
void mlir::vector::buildTerminatedBody (OpBuilder &builder, Location loc) {
127
- builder. create < vector::YieldOp>( loc);
127
+ vector::YieldOp::create (builder, loc);
128
128
}
129
129
130
130
// Helper for verifying combining kinds in contractions and reductions.
@@ -596,16 +596,16 @@ struct ElideUnitDimsInMultiDimReduction
596
596
VectorType newMaskType =
597
597
VectorType::get (dstVecType.getShape (), rewriter.getI1Type (),
598
598
dstVecType.getScalableDims ());
599
- mask = rewriter. create < vector::ShapeCastOp>( loc, newMaskType, mask);
599
+ mask = vector::ShapeCastOp::create (rewriter, loc, newMaskType, mask);
600
600
}
601
- cast = rewriter. create < vector::ShapeCastOp> (
602
- loc, reductionOp.getDestType (), reductionOp.getSource ());
601
+ cast = vector::ShapeCastOp::create (
602
+ rewriter, loc, reductionOp.getDestType (), reductionOp.getSource ());
603
603
} else {
604
604
// This means we are reducing all the dimensions, and all reduction
605
605
// dimensions are of size 1. So a simple extraction would do.
606
606
if (mask)
607
- mask = rewriter. create < vector::ExtractOp>( loc, mask);
608
- cast = rewriter. create < vector::ExtractOp>( loc, reductionOp.getSource ());
607
+ mask = vector::ExtractOp::create (rewriter, loc, mask);
608
+ cast = vector::ExtractOp::create (rewriter, loc, reductionOp.getSource ());
609
609
}
610
610
611
611
Value result =
@@ -672,36 +672,36 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
672
672
switch (op) {
673
673
case arith::AtomicRMWKind::addf:
674
674
case arith::AtomicRMWKind::addi:
675
- return builder. create < vector::ReductionOp>( vector.getLoc (),
676
- CombiningKind::ADD, vector);
675
+ return vector::ReductionOp::create (builder, vector.getLoc (),
676
+ CombiningKind::ADD, vector);
677
677
case arith::AtomicRMWKind::mulf:
678
678
case arith::AtomicRMWKind::muli:
679
- return builder. create < vector::ReductionOp>( vector.getLoc (),
680
- CombiningKind::MUL, vector);
679
+ return vector::ReductionOp::create (builder, vector.getLoc (),
680
+ CombiningKind::MUL, vector);
681
681
case arith::AtomicRMWKind::minimumf:
682
- return builder. create < vector::ReductionOp>( vector.getLoc (),
683
- CombiningKind::MINIMUMF, vector);
682
+ return vector::ReductionOp::create (builder, vector.getLoc (),
683
+ CombiningKind::MINIMUMF, vector);
684
684
case arith::AtomicRMWKind::mins:
685
- return builder. create < vector::ReductionOp>( vector.getLoc (),
686
- CombiningKind::MINSI, vector);
685
+ return vector::ReductionOp::create (builder, vector.getLoc (),
686
+ CombiningKind::MINSI, vector);
687
687
case arith::AtomicRMWKind::minu:
688
- return builder. create < vector::ReductionOp>( vector.getLoc (),
689
- CombiningKind::MINUI, vector);
688
+ return vector::ReductionOp::create (builder, vector.getLoc (),
689
+ CombiningKind::MINUI, vector);
690
690
case arith::AtomicRMWKind::maximumf:
691
- return builder. create < vector::ReductionOp>( vector.getLoc (),
692
- CombiningKind::MAXIMUMF, vector);
691
+ return vector::ReductionOp::create (builder, vector.getLoc (),
692
+ CombiningKind::MAXIMUMF, vector);
693
693
case arith::AtomicRMWKind::maxs:
694
- return builder. create < vector::ReductionOp>( vector.getLoc (),
695
- CombiningKind::MAXSI, vector);
694
+ return vector::ReductionOp::create (builder, vector.getLoc (),
695
+ CombiningKind::MAXSI, vector);
696
696
case arith::AtomicRMWKind::maxu:
697
- return builder. create < vector::ReductionOp>( vector.getLoc (),
698
- CombiningKind::MAXUI, vector);
697
+ return vector::ReductionOp::create (builder, vector.getLoc (),
698
+ CombiningKind::MAXUI, vector);
699
699
case arith::AtomicRMWKind::andi:
700
- return builder. create < vector::ReductionOp>( vector.getLoc (),
701
- CombiningKind::AND, vector);
700
+ return vector::ReductionOp::create (builder, vector.getLoc (),
701
+ CombiningKind::AND, vector);
702
702
case arith::AtomicRMWKind::ori:
703
- return builder. create < vector::ReductionOp>( vector.getLoc (),
704
- CombiningKind::OR, vector);
703
+ return vector::ReductionOp::create (builder, vector.getLoc (),
704
+ CombiningKind::OR, vector);
705
705
// TODO: Add remaining reduction operations.
706
706
default :
707
707
(void )emitOptionalError (loc, " Reduction operation type not supported" );
@@ -740,8 +740,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
740
740
741
741
Location loc = reductionOp.getLoc ();
742
742
if (mask)
743
- mask = rewriter. create < ExtractOp>( loc, mask);
744
- Value result = rewriter. create < ExtractOp>( loc, reductionOp.getVector ());
743
+ mask = ExtractOp::create (rewriter, loc, mask);
744
+ Value result = ExtractOp::create (rewriter, loc, reductionOp.getVector ());
745
745
746
746
if (Value acc = reductionOp.getAcc ())
747
747
result = vector::makeArithReduction (rewriter, loc, reductionOp.getKind (),
@@ -4172,9 +4172,9 @@ class StridedSliceCreateMaskFolder final
4172
4172
// greater than the vector dim size.
4173
4173
IntegerAttr offsetAttr =
4174
4174
rewriter.getIntegerAttr (maskDimSize.getType (), sliceOffset);
4175
- Value offset = rewriter. create < arith::ConstantOp>( loc, offsetAttr);
4175
+ Value offset = arith::ConstantOp::create (rewriter, loc, offsetAttr);
4176
4176
Value sliceMaskDimSize =
4177
- rewriter. create < arith::SubIOp>( loc, maskDimSize, offset);
4177
+ arith::SubIOp::create (rewriter, loc, maskDimSize, offset);
4178
4178
sliceMaskDimSizes.push_back (sliceMaskDimSize);
4179
4179
}
4180
4180
// Add unchanged dimensions.
@@ -4289,8 +4289,8 @@ class StridedSliceBroadcast final
4289
4289
sizes[i] = 1 ;
4290
4290
}
4291
4291
}
4292
- source = rewriter. create < ExtractStridedSliceOp> (
4293
- op->getLoc (), source, offsets, sizes,
4292
+ source = ExtractStridedSliceOp::create (
4293
+ rewriter, op->getLoc (), source, offsets, sizes,
4294
4294
getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
4295
4295
}
4296
4296
rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), source);
@@ -4382,8 +4382,8 @@ class ContiguousExtractStridedSliceToExtract final
4382
4382
4383
4383
SmallVector<int64_t > offsets = getI64SubArray (op.getOffsets ());
4384
4384
auto extractOffsets = ArrayRef (offsets).take_front (numOffsets);
4385
- Value extract = rewriter. create < vector::ExtractOp>( op->getLoc (), source,
4386
- extractOffsets);
4385
+ Value extract = vector::ExtractOp::create (rewriter, op->getLoc (), source,
4386
+ extractOffsets);
4387
4387
rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(op, op.getType (), extract);
4388
4388
return success ();
4389
4389
}
@@ -4413,7 +4413,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4413
4413
4414
4414
Type elemType = llvm::cast<ShapedType>(source.getType ()).getElementType ();
4415
4415
if (!padding)
4416
- padding = builder. create < ub::PoisonOp>( result.location , elemType);
4416
+ padding = ub::PoisonOp::create (builder, result.location , elemType);
4417
4417
build (builder, result, vectorType, source, indices, permutationMapAttr,
4418
4418
*padding, /* mask=*/ Value (), inBoundsAttr);
4419
4419
}
@@ -4431,7 +4431,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4431
4431
SmallVector<bool >(vectorType.getRank (), false ));
4432
4432
Type elemType = llvm::cast<ShapedType>(source.getType ()).getElementType ();
4433
4433
if (!padding)
4434
- padding = builder. create < ub::PoisonOp>( result.location , elemType);
4434
+ padding = ub::PoisonOp::create (builder, result.location , elemType);
4435
4435
build (builder, result, vectorType, source, indices, *padding,
4436
4436
permutationMapAttr, inBoundsAttr);
4437
4437
}
@@ -4450,7 +4450,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4450
4450
SmallVector<bool >(vectorType.getRank (), false ));
4451
4451
Type elemType = llvm::cast<ShapedType>(source.getType ()).getElementType ();
4452
4452
if (!padding)
4453
- padding = builder. create < ub::PoisonOp>( result.location , elemType);
4453
+ padding = ub::PoisonOp::create (builder, result.location , elemType);
4454
4454
build (builder, result, vectorType, source, indices, permutationMapAttr,
4455
4455
*padding,
4456
4456
/* mask=*/ Value (), inBoundsAttr);
@@ -4975,7 +4975,7 @@ struct TransferReadAfterWriteToBroadcast
4975
4975
VectorType broadcastedType = VectorType::get (
4976
4976
broadcastShape, defWrite.getVectorType ().getElementType (),
4977
4977
broadcastScalableFlags);
4978
- vec = rewriter. create < vector::BroadcastOp>( loc, broadcastedType, vec);
4978
+ vec = vector::BroadcastOp::create (rewriter, loc, broadcastedType, vec);
4979
4979
SmallVector<int64_t > transposePerm (permutation.begin (), permutation.end ());
4980
4980
rewriter.replaceOpWithNewOp <vector::TransposeOp>(readOp, vec,
4981
4981
transposePerm);
@@ -5453,13 +5453,14 @@ struct SwapExtractSliceOfTransferWrite
5453
5453
// Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5454
5454
// Set all in_bounds to false and let the folder infer them.
5455
5455
SmallVector<bool > newInBounds (vectorShape.size (), false );
5456
- auto newExtractOp = rewriter.create <tensor::ExtractSliceOp>(
5457
- extractOp.getLoc (), insertOp.getSourceType (), insertOp.getDest (),
5458
- insertOp.getMixedOffsets (), insertOp.getMixedSizes (),
5459
- insertOp.getMixedStrides ());
5460
- auto newTransferWriteOp = rewriter.create <TransferWriteOp>(
5461
- transferOp.getLoc (), transferOp.getVector (), newExtractOp.getResult (),
5462
- transferOp.getIndices (), transferOp.getPermutationMapAttr (),
5456
+ auto newExtractOp = tensor::ExtractSliceOp::create (
5457
+ rewriter, extractOp.getLoc (), insertOp.getSourceType (),
5458
+ insertOp.getDest (), insertOp.getMixedOffsets (),
5459
+ insertOp.getMixedSizes (), insertOp.getMixedStrides ());
5460
+ auto newTransferWriteOp = TransferWriteOp::create (
5461
+ rewriter, transferOp.getLoc (), transferOp.getVector (),
5462
+ newExtractOp.getResult (), transferOp.getIndices (),
5463
+ transferOp.getPermutationMapAttr (),
5463
5464
rewriter.getBoolArrayAttr (newInBounds));
5464
5465
rewriter.modifyOpInPlace (insertOp, [&]() {
5465
5466
insertOp.getSourceMutable ().assign (newTransferWriteOp.getResult ());
@@ -6983,7 +6984,7 @@ void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
6983
6984
OpBuilder opBuilder (builder.getContext ());
6984
6985
Operation *maskedOp = &block.front ();
6985
6986
opBuilder.setInsertionPointToEnd (&block);
6986
- opBuilder. create < vector::YieldOp>( loc, maskedOp->getResults ());
6987
+ vector::YieldOp::create (opBuilder, loc, maskedOp->getResults ());
6987
6988
}
6988
6989
6989
6990
LogicalResult MaskOp::verify () {
@@ -7318,7 +7319,7 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7318
7319
// Create a block and move the op to that block.
7319
7320
insBlock->getOperations ().splice (
7320
7321
insBlock->begin (), maskableOp->getBlock ()->getOperations (), maskableOp);
7321
- builder. create < YieldOp>( maskableOp->getLoc (), maskableOp->getResults ());
7322
+ YieldOp::create (builder, maskableOp->getLoc (), maskableOp->getResults ());
7322
7323
}
7323
7324
7324
7325
// / Creates a vector.mask operation around a maskable operation. Returns the
@@ -7330,12 +7331,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder,
7330
7331
if (!mask)
7331
7332
return maskableOp;
7332
7333
if (passthru)
7333
- return builder. create < MaskOp>( maskableOp->getLoc (),
7334
- maskableOp->getResultTypes (), mask, passthru,
7335
- maskableOp, createMaskOpRegion);
7336
- return builder. create < MaskOp>( maskableOp->getLoc (),
7337
- maskableOp->getResultTypes (), mask, maskableOp,
7338
- createMaskOpRegion);
7334
+ return MaskOp::create (builder, maskableOp->getLoc (),
7335
+ maskableOp->getResultTypes (), mask, passthru,
7336
+ maskableOp, createMaskOpRegion);
7337
+ return MaskOp::create (builder, maskableOp->getLoc (),
7338
+ maskableOp->getResultTypes (), mask, maskableOp,
7339
+ createMaskOpRegion);
7339
7340
}
7340
7341
7341
7342
// / Creates a vector select operation that picks values from `newValue` or
@@ -7350,8 +7351,8 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7350
7351
if (!mask)
7351
7352
return newValue;
7352
7353
7353
- return builder. create < arith::SelectOp>( newValue.getLoc (), newValue.getType (),
7354
- mask, newValue, passthru);
7354
+ return arith::SelectOp::create (builder, newValue.getLoc (), newValue.getType (),
7355
+ mask, newValue, passthru);
7355
7356
}
7356
7357
7357
7358
// ===----------------------------------------------------------------------===//
0 commit comments