[Codegen] Upgrade Transforms and Utils to free create functions. NFC. (#21882)
The builder.create<OpTy>(...) methods are deprecated:
https://mlir.llvm.org/deprecation/. See
https://discourse.llvm.org/t/psa-opty-create-now-with-100-more-tab-complete/87339.
The main benefit of the free functions is better tab completion with
LSP/IDEs.
I'm splitting the upgrade into chunks, going directory by directory through the project.
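For reference, the mechanical pattern this series applies is sketched below. It is a minimal illustration, not taken from the diff that follows; arith::ConstantIndexOp and the helper name are only examples, and it assumes the usual MLIR headers plus an OpBuilder already positioned at the intended insertion point.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

// Hypothetical helper, for illustration only.
static mlir::Value makeZeroIndex(mlir::OpBuilder &builder, mlir::Location loc) {
  // Before: deprecated member template on the builder.
  //   return builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
  // After: static free create function on the op class; the builder moves to
  // the first argument and the remaining arguments keep the same order.
  return mlir::arith::ConstantIndexOp::create(builder, loc, 0);
}

Only the call syntax changes; the created ops are identical, which is why the series is NFC.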
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp b/compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp
index 81b5f57..d14e5a0 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp
@@ -44,11 +44,11 @@
Block *block = op.getBody();
ValueRange initArgs = op.getInitArgs();
Value count =
- rewriter.create<arith::CmpIOp>(op->getLoc(), arith::CmpIPredicate::sgt,
- op.getUpperBound(), op.getLowerBound());
+ arith::CmpIOp::create(rewriter, op->getLoc(), arith::CmpIPredicate::sgt,
+ op.getUpperBound(), op.getLowerBound());
auto ifOp =
- rewriter.create<scf::IfOp>(op->getLoc(), op.getResultTypes(), count,
- /*withElseRegion=*/initArgs.size() != 0);
+ scf::IfOp::create(rewriter, op->getLoc(), op.getResultTypes(), count,
+ /*withElseRegion=*/initArgs.size() != 0);
Operation *terminator = block->getTerminator();
rewriter.inlineBlockBefore(block, &ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin(), blockArgs);
@@ -56,7 +56,7 @@
rewriter.eraseOp(terminator);
} else {
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
- rewriter.create<scf::YieldOp>(ifOp.getLoc(), initArgs);
+ scf::YieldOp::create(rewriter, ifOp.getLoc(), initArgs);
}
rewriter.replaceOp(op, ifOp);
}
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index d2b7104..5dc219e 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -131,11 +131,11 @@
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToStart(&funcOp.getFunctionBody().front());
Value allocation =
- builder.create<AllocLikeOpType>(loc, allocLikeType, alignmentAttr);
+ AllocLikeOpType::create(builder, loc, allocLikeType, alignmentAttr);
if (std::is_same<AllocLikeOpType, memref::AllocOp>::value) {
builder.setInsertionPoint(
funcOp.getFunctionBody().front().getTerminator());
- builder.create<memref::DeallocOp>(loc, allocation);
+ memref::DeallocOp::create(builder, loc, allocation);
}
return allocation;
}
@@ -164,7 +164,7 @@
}
if (!vscale)
- vscale = builder.create<vector::VectorScaleOp>(loc);
+ vscale = vector::VectorScaleOp::create(builder, loc);
return affine::materializeComputedBound(
builder, loc, ub->map, {std::make_pair(vscale, std::nullopt)});
}
@@ -217,26 +217,26 @@
dispatchIndexOpFoldResults(allocSizes, dynamicSizes, staticShape);
auto allocationType = allocLikeType.clone(staticShape);
- allocation = builder.create<AllocLikeOpType>(loc, allocationType,
- dynamicSizes, alignmentAttr);
+ allocation = AllocLikeOpType::create(builder, loc, allocationType,
+ dynamicSizes, alignmentAttr);
}
SmallVector<OpFoldResult> offsets(allocLikeType.getRank(),
builder.getIndexAttr(0));
SmallVector<OpFoldResult> strides(allocLikeType.getRank(),
builder.getIndexAttr(1));
- Value subviewOp = builder.create<memref::SubViewOp>(loc, allocation, offsets,
- subviewSizes, strides);
+ Value subviewOp = memref::SubViewOp::create(builder, loc, allocation, offsets,
+ subviewSizes, strides);
// Cast it back to the original types to prevent consumer op's verification
// error. It could happen when the consumer op is a memref.subview op.
if (subviewOp.getType() != allocLikeType) {
- subviewOp = builder.create<memref::CastOp>(loc, allocLikeType, subviewOp);
+ subviewOp = memref::CastOp::create(builder, loc, allocLikeType, subviewOp);
}
if (std::is_same<AllocLikeOpType, memref::AllocOp>::value) {
builder.setInsertionPoint(funcOp.getFunctionBody().front().getTerminator());
- builder.create<memref::DeallocOp>(loc, allocation);
+ memref::DeallocOp::create(builder, loc, allocation);
}
return subviewOp;
@@ -402,7 +402,7 @@
// time might make these go away.
if (isa<IREE::Codegen::QueryTileSizesOp>(op)) {
Value constVal =
- rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 16);
+ arith::ConstantIndexOp::create(rewriter, op->getLoc(), 16);
for (auto result : op->getResults()) {
map.map(result, constVal);
}
@@ -633,9 +633,10 @@
workgroupMapping->getValue(), mappingAttr->getValue());
auto newMappingAttr = rewriter.getArrayAttr(newMapping);
- auto newForallOp = rewriter.create<scf::ForallOp>(
- forallOp.getLoc(), newLbs, newUbs, newSteps, /*outputs=*/ValueRange{},
- newMappingAttr, [](OpBuilder &, Location, ValueRange) {});
+ auto newForallOp = scf::ForallOp::create(
+ rewriter, forallOp.getLoc(), newLbs, newUbs, newSteps,
+ /*outputs=*/ValueRange{}, newMappingAttr,
+ [](OpBuilder &, Location, ValueRange) {});
Block *oldBlock = forallOp.getBody();
Block *newForallBody = newForallOp.getBody();
SmallVector<Value> newInductionVars = newForallOp.getInductionVars();
@@ -852,16 +853,16 @@
MemRefType allocType = MemRefType::get({maxAlloc}, builder.getI8Type(),
AffineMap(), memorySpace);
Value packedAlloc =
- builder.create<memref::AllocOp>(funcOp.getLoc(), allocType);
+ memref::AllocOp::create(builder, funcOp.getLoc(), allocType);
for (size_t i = 0; i < aliasGroups.size(); i++) {
int64_t offset = 0;
for (Operation *alloc : aliasGroups[i]) {
Location loc = alloc->getLoc();
builder.setInsertionPoint(alloc);
- Value offsetValue = builder.create<arith::ConstantIndexOp>(loc, offset);
- Value newAlloc = builder.create<memref::ViewOp>(
- packedAlloc.getLoc(), alloc->getResultTypes()[0], packedAlloc,
- offsetValue, ArrayRef<Value>({}));
+ Value offsetValue = arith::ConstantIndexOp::create(builder, loc, offset);
+ Value newAlloc = memref::ViewOp::create(
+ builder, packedAlloc.getLoc(), alloc->getResultTypes()[0],
+ packedAlloc, offsetValue, ArrayRef<Value>({}));
offset += getAllocSize(alloc, dataLayout);
alloc->replaceAllUsesWith(ArrayRef<Value>({newAlloc}));
alloc->erase();
@@ -1135,9 +1136,10 @@
// Step 3. Create the ForallOp.
Location loc = forallOp.getLoc();
- scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
- loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep(), loop.getInitArgs(), forallOp.getMappingAttr());
+ scf::ForallOp newForallOp = scf::ForallOp::create(
+ rewriter, loc, forallOp.getMixedLowerBound(),
+ forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
+ loop.getInitArgs(), forallOp.getMappingAttr());
{
// RAII guard, inserting within forallOp, before terminator.
@@ -1156,10 +1158,10 @@
}
// Step 4. Create a new for loop with new inits for the result of the
// extracted slices.
- auto newLoop = rewriter.create<scf::ForOp>(
- loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
- loop.getStep(), newInits,
- [](OpBuilder &, Location, Value, ValueRange) {});
+ auto newLoop =
+ scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
+ loop.getUpperBound(), loop.getStep(), newInits,
+ [](OpBuilder &, Location, Value, ValueRange) {});
{
// Step 5. Inline the body of the original forall into the new for loop.
@@ -1195,7 +1197,7 @@
newYields.push_back(parallelSlice.getSource());
}
rewriter.setInsertionPointToEnd(newLoop.getBody());
- rewriter.create<scf::YieldOp>(loop.getLoc(), newYields);
+ scf::YieldOp::create(rewriter, loop.getLoc(), newYields);
}
// Move all producers for the indices of the slices outside of the body
@@ -1215,11 +1217,11 @@
for (auto [parallelSlice, source, dest] :
llvm::zip_equal(terminators, newLoop.getResults(),
newForallOp.getRegionIterArgs())) {
- rewriter.create<tensor::ParallelInsertSliceOp>(
- parallelSlice.getLoc(), source, dest, parallelSlice.getOffsets(),
- parallelSlice.getSizes(), parallelSlice.getStrides(),
- parallelSlice.getStaticOffsets(), parallelSlice.getStaticSizes(),
- parallelSlice.getStaticStrides());
+ tensor::ParallelInsertSliceOp::create(
+ rewriter, parallelSlice.getLoc(), source, dest,
+ parallelSlice.getOffsets(), parallelSlice.getSizes(),
+ parallelSlice.getStrides(), parallelSlice.getStaticOffsets(),
+ parallelSlice.getStaticSizes(), parallelSlice.getStaticStrides());
}
}
@@ -1278,8 +1280,8 @@
}
Location loc = padOp.getLoc();
- auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, tensor::getMixedSizes(rewriter, loc, padOp),
+ auto emptyOp = tensor::EmptyOp::create(
+ rewriter, loc, tensor::getMixedSizes(rewriter, loc, padOp),
resultType.getElementType());
rewriter.replaceOpWithNewOp<linalg::FillOp>(padOp, padValue,
emptyOp.getResult());
diff --git a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
index 30ef872..4944e49 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
@@ -63,13 +63,13 @@
// necessary vscale operation and the corresponding static_size * vscale
// values.
SmallVector<OpFoldResult> result(staticTileSizes.size());
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc);
for (size_t i = 0; i < result.size(); i++) {
if (materializeEncodingInfo.scalableTiles.value()[i]) {
auto staticTileSize =
- rewriter.create<arith::ConstantIndexOp>(loc, staticTileSizes[i]);
+ arith::ConstantIndexOp::create(rewriter, loc, staticTileSizes[i]);
auto scalableInnerTileSize =
- rewriter.create<arith::MulIOp>(loc, staticTileSize, vscale);
+ arith::MulIOp::create(rewriter, loc, staticTileSize, vscale);
result[i] = scalableInnerTileSize.getResult();
} else {
result[i] = rewriter.getI64IntegerAttr(staticTileSizes[i]);
@@ -85,8 +85,8 @@
return failure();
}
SmallVector<Type> resultTypes(tensorType.getRank(), rewriter.getIndexType());
- auto op = rewriter.create<IREE::Codegen::QueryTileSizesOp>(
- loc, resultTypes, TypeAttr::get(tensorType));
+ auto op = IREE::Codegen::QueryTileSizesOp::create(rewriter, loc, resultTypes,
+ TypeAttr::get(tensorType));
SmallVector<Value> innerTileSizeValues = op.getResults();
SmallVector<OpFoldResult> result(staticTileSizes.size());
@@ -95,7 +95,7 @@
result[i] = innerTileSizeValues[i];
} else if (tensorType.isDynamicDim(i)) {
result[i] =
- rewriter.create<arith::ConstantIndexOp>(loc, staticTileSizes[i])
+ arith::ConstantIndexOp::create(rewriter, loc, staticTileSizes[i])
.getResult();
} else {
result[i] = rewriter.getI64IntegerAttr(staticTileSizes[i]);
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index 03b12c6..d542d06 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -64,9 +64,9 @@
mlir::Type indexType = builder.getIndexType();
for (unsigned i = 0; i < numDims; ++i) {
procInfo[numDims - 1 - i] = {
- builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]),
- builder.create<mlir::arith::ConstantOp>(
- loc, builder.getIndexAttr(workgroupSize[i])),
+ mlir::gpu::ThreadIdOp::create(builder, loc, indexType, dimAttr[i]),
+ mlir::arith::ConstantOp::create(builder, loc,
+ builder.getIndexAttr(workgroupSize[i])),
linalg::DistributionMethod::Cyclic};
}
return procInfo;
@@ -83,7 +83,7 @@
mlir::Type indexType = builder.getIndexType();
for (unsigned i = 0; i < numDims; ++i) {
mlir::Value subgroupId =
- builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]);
+ mlir::gpu::ThreadIdOp::create(builder, loc, indexType, dimAttr[i]);
if (i == 0) {
subgroupId =
builder
@@ -93,8 +93,8 @@
}
procInfo[numDims - 1 - i] = {
subgroupId,
- builder.create<mlir::arith::ConstantOp>(
- loc, builder.getIndexAttr(numSubgroups[i])),
+ mlir::arith::ConstantOp::create(builder, loc,
+ builder.getIndexAttr(numSubgroups[i])),
linalg::DistributionMethod::Cyclic};
}
return procInfo;
@@ -250,7 +250,7 @@
shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{},
gpu::AddressSpaceAttr::get(builder.getContext(),
gpu::GPUDialect::getWorkgroupAddressSpace()));
- Value buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
+ Value buffer = memref::AllocOp::create(builder, funcOp.getLoc(), type);
return buffer;
}
@@ -259,7 +259,7 @@
}
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) {
- Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst);
+ Operation *copyOp = memref::CopyOp::create(b, src.getLoc(), src, dst);
setMarker(copyOp, getCopyToWorkgroupMemoryMarker());
return success();
}
@@ -304,8 +304,8 @@
SmallVector<utils::IteratorType> iterTypes(op.getNumLoops(),
utils::IteratorType::parallel);
OpBuilder builder(op);
- auto newOp = builder.create<linalg::GenericOp>(
- loc, newOperands, outOperand->get(), maps, iterTypes);
+ auto newOp = linalg::GenericOp::create(builder, loc, newOperands,
+ outOperand->get(), maps, iterTypes);
newOp.getRegion().getBlocks().splice(newOp.getRegion().begin(),
op.getRegion().getBlocks());
@@ -365,12 +365,12 @@
Operation *prevOp = copyOp->getPrevNode();
if (!prevOp || !hasMarker(prevOp, getCopyToWorkgroupMemoryMarker())) {
builder.setInsertionPoint(copyOp);
- builder.create<gpu::BarrierOp>(copyOp->getLoc());
+ gpu::BarrierOp::create(builder, copyOp->getLoc());
}
Operation *nextOp = copyOp->getNextNode();
if (!nextOp || !hasMarker(nextOp, getCopyToWorkgroupMemoryMarker())) {
builder.setInsertionPointAfter(copyOp);
- builder.create<gpu::BarrierOp>(copyOp->getLoc());
+ gpu::BarrierOp::create(builder, copyOp->getLoc());
}
}
});
@@ -386,7 +386,7 @@
Value input) {
VectorType vectorTypeBroadcast = VectorType::get({1}, input.getType());
Value vectorInput =
- builder.create<vector::BroadcastOp>(loc, vectorTypeBroadcast, input);
+ vector::BroadcastOp::create(builder, loc, vectorTypeBroadcast, input);
return vectorInput;
}
@@ -403,8 +403,9 @@
});
VectorType packed32Type = VectorType::get({1}, builder.getI32Type());
Value packedInputVec =
- builder.create<vector::BitCastOp>(loc, packed32Type, input);
- Value packedInput = builder.create<vector::ExtractOp>(loc, packedInputVec, 0);
+ vector::BitCastOp::create(builder, loc, packed32Type, input);
+ Value packedInput =
+ vector::ExtractOp::create(builder, loc, packedInputVec, 0);
return packedInput;
}
@@ -420,7 +421,7 @@
});
Value packedVector = promoteElementToVector(loc, builder, packedInput);
Value unpackedVector =
- builder.create<vector::BitCastOp>(loc, targetVecType, packedVector);
+ vector::BitCastOp::create(builder, loc, targetVecType, packedVector);
return unpackedVector;
}
@@ -450,11 +451,11 @@
// SPIRV currently doesn't have a lowering for clustered reduction,
// so if possible avoid adding problematic attribute until it is supported.
if (numLaneToReduce == warpSize) {
- return builder.create<gpu::SubgroupReduceOp>(loc, input, gpuReduceKind,
- /*uniform=*/false);
+ return gpu::SubgroupReduceOp::create(builder, loc, input, gpuReduceKind,
+ /*uniform=*/false);
}
- return builder.create<gpu::SubgroupReduceOp>(
- loc, input, gpuReduceKind, /*uniform=*/false, numLaneToReduce);
+ return gpu::SubgroupReduceOp::create(builder, loc, input, gpuReduceKind,
+ /*uniform=*/false, numLaneToReduce);
}
// Otherwise, perform the shuffles over the supported scalar type. For inputs
@@ -464,8 +465,8 @@
origInputType](Value packedVal) -> Value {
if (!needsPacking)
return packedVal;
- auto asInt = builder.create<arith::TruncIOp>(loc, equivIntType, packedVal);
- return builder.create<arith::BitcastOp>(loc, origInputType, asInt);
+ auto asInt = arith::TruncIOp::create(builder, loc, equivIntType, packedVal);
+ return arith::BitcastOp::create(builder, loc, origInputType, asInt);
};
auto pack = [loc, &builder, needsPacking, equivIntType,
@@ -473,8 +474,8 @@
if (!needsPacking)
return unpackedVal;
auto asInt =
- builder.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
- return builder.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
+ arith::BitcastOp::create(builder, loc, equivIntType, unpackedVal);
+ return arith::ExtUIOp::create(builder, loc, shuffleIntType, asInt);
};
// Lane value always stays in the original type. We use it to perform arith
@@ -557,7 +558,7 @@
}
assert(identityAttr && "Unknown identity value for the reduction");
Value identity =
- builder.create<arith::ConstantOp>(loc, identityType, identityAttr);
+ arith::ConstantOp::create(builder, loc, identityType, identityAttr);
return identity;
}
@@ -597,7 +598,7 @@
"Group reduction only support for sizes aligned on warp size for now.");
// First reduce on a single thread to get per lane reduction value.
- Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+ Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
laneVal = warpReduction(loc, builder, laneVal, kind, warpSize, warpSize,
expandSubgroupReduce);
// Simple case -- emit `gpu.subgroup_reduce` directly.
@@ -618,37 +619,38 @@
MemRefType memrefType =
MemRefType::get(numWarp, laneVal.getType(), MemRefLayoutAttrInterface{},
addressSpaceAttr);
- Value alloc = builder.create<memref::AllocOp>(loc, memrefType);
- Value threadX = builder.create<gpu::ThreadIdOp>(loc, builder.getIndexType(),
- gpu::Dimension::x);
- Value cstWarpSize = builder.create<arith::ConstantIndexOp>(loc, warpSize);
- Value warpId = builder.create<arith::DivUIOp>(loc, threadX, cstWarpSize);
- Value laneId = builder.create<arith::RemUIOp>(loc, threadX, cstWarpSize);
- Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value lane0 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- laneId, zero);
+ Value alloc = memref::AllocOp::create(builder, loc, memrefType);
+ Value threadX = gpu::ThreadIdOp::create(
+ builder, loc, builder.getIndexType(), gpu::Dimension::x);
+ Value cstWarpSize = arith::ConstantIndexOp::create(builder, loc, warpSize);
+ Value warpId = arith::DivUIOp::create(builder, loc, threadX, cstWarpSize);
+ Value laneId = arith::RemUIOp::create(builder, loc, threadX, cstWarpSize);
+ Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value lane0 = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
+ laneId, zero);
// Store the reduction for each warp.
SmallVector<Value> indices = {warpId};
- builder.create<scf::IfOp>(loc, lane0, [&](OpBuilder &b, Location l) {
- b.create<memref::StoreOp>(l, laneVal, alloc, indices);
- b.create<scf::YieldOp>(l);
+ scf::IfOp::create(builder, loc, lane0, [&](OpBuilder &b, Location l) {
+ memref::StoreOp::create(b, l, laneVal, alloc, indices);
+ scf::YieldOp::create(b, l);
});
- builder.create<gpu::BarrierOp>(loc);
+ gpu::BarrierOp::create(builder, loc);
// Further reduce the outputs from each warps with a single warp reduce.
- Value memrefSize = builder.create<arith::ConstantIndexOp>(loc, numWarp - 1);
+ Value memrefSize =
+ arith::ConstantIndexOp::create(builder, loc, numWarp - 1);
Value laneIdInBounds =
- builder.create<arith::MinUIOp>(loc, laneId, memrefSize);
- Value loadVal = builder.create<memref::LoadOp>(loc, alloc, laneIdInBounds);
- Value cstNumWarp = builder.create<arith::ConstantIndexOp>(loc, numWarp);
+ arith::MinUIOp::create(builder, loc, laneId, memrefSize);
+ Value loadVal = memref::LoadOp::create(builder, loc, alloc, laneIdInBounds);
+ Value cstNumWarp = arith::ConstantIndexOp::create(builder, loc, numWarp);
if (!llvm::isPowerOf2_32(numWarp)) {
// Pad with identity element if numel < warpSize for valid warp reduction.
- Value useIdentityElement = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, laneId, cstNumWarp);
+ Value useIdentityElement = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sge, laneId, cstNumWarp);
numWarp = llvm::PowerOf2Ceil(numWarp);
Value identity =
getCombiningIdentityValue(loc, builder, kind, loadVal.getType());
- loadVal = builder.create<arith::SelectOp>(loc, useIdentityElement,
- identity, loadVal);
+ loadVal = arith::SelectOp::create(builder, loc, useIdentityElement,
+ identity, loadVal);
}
laneVal = warpReduction(loc, builder, loadVal, kind, warpSize, numWarp,
/*expandSubgroupReduce=*/true);
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 013deaa..7c39950 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -448,8 +448,8 @@
FunctionType::get(context, operandTypes, customOp->getResultTypes());
std::string dummyFuncName =
std::string("__") + funcOp.getName().str() + "_config_setting__";
- auto dummyFuncOp = rewriter.create<func::FuncOp>(
- customOp.getLoc(), dummyFuncName, dummyFuncType);
+ auto dummyFuncOp = func::FuncOp::create(rewriter, customOp.getLoc(),
+ dummyFuncName, dummyFuncType);
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
if (targetAttr) {
dummyFuncOp->setAttr(IREE::HAL::ExecutableTargetAttr::name, targetAttr);
@@ -475,8 +475,8 @@
}
auto clonedCustomOp = cast<IREE::LinalgExt::CustomOp>(
rewriter.clone(*customOp.getOperation(), map));
- rewriter.create<func::ReturnOp>(customOp.getLoc(),
- clonedCustomOp->getResults());
+ func::ReturnOp::create(rewriter, customOp.getLoc(),
+ clonedCustomOp->getResults());
CustomOpConfigListener customOpConfigListener(customOp, clonedCustomOp);
// 4. Inline the cloned custom op.
@@ -963,10 +963,10 @@
llvm::zip(fixedTileSizes, fixedTileScalableFlags),
[&](auto pair) -> OpFoldResult {
auto [t, isScalable] = pair;
- Value size = b.create<arith::ConstantIndexOp>(loc, t);
+ Value size = arith::ConstantIndexOp::create(b, loc, t);
if (isScalable) {
- Value vscale = b.create<vector::VectorScaleOp>(loc);
- size = b.create<arith::MulIOp>(loc, size, vscale);
+ Value vscale = vector::VectorScaleOp::create(b, loc);
+ size = arith::MulIOp::create(b, loc, size, vscale);
}
return size;
});
@@ -992,14 +992,14 @@
AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
utils::IteratorType::parallel);
- return b.create<linalg::GenericOp>(
- loc,
+ return linalg::GenericOp::create(
+ b, loc,
/*inputs=*/from,
/*outputs=*/to,
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
- b.create<linalg::YieldOp>(loc, args.front());
+ linalg::YieldOp::create(b, loc, args.front());
},
attributes);
}
@@ -1048,8 +1048,8 @@
}
if (splitDim) {
std::reverse(splitNumTiles.begin(), splitNumTiles.end());
- auto delinearized = builder.create<affine::AffineDelinearizeIndexOp>(
- loc, *splitDim, splitNumTiles, /*hasOuterBound=*/true);
+ auto delinearized = affine::AffineDelinearizeIndexOp::create(
+ builder, loc, *splitDim, splitNumTiles, /*hasOuterBound=*/true);
for (auto [i, id, numTiles] :
llvm::enumerate(delinearized.getResults(), splitNumTiles)) {
// We iterate the delinearize results from slowest up to fastest, and
@@ -1225,14 +1225,14 @@
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subspanOp);
// Just change the result type of the InterfaceBindingSubspanOp.
- Value buffer = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
- subspanOp->getLoc(), memRefType, subspanOp.getLayout(),
+ Value buffer = IREE::HAL::InterfaceBindingSubspanOp::create(
+ rewriter, subspanOp->getLoc(), memRefType, subspanOp.getLayout(),
subspanOp.getBinding(), subspanOp.getByteOffset(),
subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
subspanOp.getDescriptorFlagsAttr());
if (useRocdlBuffers) {
- buffer = rewriter.create<amdgpu::FatRawBufferCastOp>(
- subspanOp->getLoc(), buffer, /*validBytes=*/Value{},
+ buffer = amdgpu::FatRawBufferCastOp::create(
+ rewriter, subspanOp->getLoc(), buffer, /*validBytes=*/Value{},
/*cacheSwizzleStride=*/Value{}, /*boundsCheck=*/true,
/*resetOffset=*/true);
}
@@ -1333,7 +1333,7 @@
{byteOffset, rewriter.getIndexAttr(typeBitWidth)});
} else {
OpFoldResult elementByteSize =
- rewriter.create<IREE::Util::SizeOfOp>(loc, elementType).getResult();
+ IREE::Util::SizeOfOp::create(rewriter, loc, elementType).getResult();
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
return affine::makeComposedFoldedAffineApply(rewriter, loc, s0.floorDiv(s1),
@@ -1381,7 +1381,7 @@
currentResultType.getShape(), currentResultType.getElementType(),
replacementType.getLayout(), replacementType.getMemorySpace());
auto newCastOp =
- rewriter.create<memref::CastOp>(loc, newResultType, replacement);
+ memref::CastOp::create(rewriter, loc, newResultType, replacement);
LDBG() << "\t\tNew user : " << *newCastOp;
return SmallVector<Value>(newCastOp->result_begin(),
newCastOp->result_end());
@@ -1400,8 +1400,8 @@
currResultType.getShape(), newSourceType, offsets, sizes,
strides))
: nullptr);
- auto newSubviewOp = rewriter.create<memref::SubViewOp>(
- loc, newResultType, replacement, offsets, sizes, strides);
+ auto newSubviewOp = memref::SubViewOp::create(
+ rewriter, loc, newResultType, replacement, offsets, sizes, strides);
LDBG() << "\t\tNew user : " << *newSubviewOp;
return llvm::to_vector_of<Value>(newSubviewOp->getResults());
@@ -1419,8 +1419,8 @@
return std::nullopt;
}
- auto newExpandOp = rewriter.create<memref::ExpandShapeOp>(
- loc, *newResultType, replacement, expandOp.getReassociation(),
+ auto newExpandOp = memref::ExpandShapeOp::create(
+ rewriter, loc, *newResultType, replacement, expandOp.getReassociation(),
expandOp.getOutputShape(), expandOp.getStaticOutputShape());
LDBG() << "\t\tNew user : " << *newExpandOp;
return llvm::to_vector_of<Value>(newExpandOp->getResults());
@@ -1434,8 +1434,9 @@
return std::nullopt;
}
- auto newCollapseOp = rewriter.create<memref::CollapseShapeOp>(
- loc, *newResultType, replacement, collapseOp.getReassociation());
+ auto newCollapseOp = memref::CollapseShapeOp::create(
+ rewriter, loc, *newResultType, replacement,
+ collapseOp.getReassociation());
LDBG() << "\t\tNew user : " << *newCollapseOp;
return llvm::to_vector_of<Value>(newCollapseOp->getResults());
}