Fix incorrect API usage in RewritePatterns (round 1) (#12466)
Incorrect API usage was detected by D144552.
diff --git a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
index eaeea88..4b83df5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
@@ -86,7 +86,8 @@
rewriter.mergeBlocks(source, dest, dest->getArguments());
// Replace the yield op by one that returns only the used values.
auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
- yieldOp.getOperation()->setOperands(results);
+ rewriter.updateRootInPlace(
+ yieldOp, [&]() { yieldOp.getOperation()->setOperands(results); });
}
LogicalResult matchAndRewrite(scf::ForOp forOp,
@@ -130,8 +131,9 @@
rewriter.clone(*resultOps[index], mapping)->getResult(0);
Operation* oldOp =
newLoop.getRegionIterArgs()[index].use_begin()->getOwner();
- SmallVector<Value, 1> arg(1, newLoop.getRegionIterArgs()[index]);
- oldOp->replaceAllUsesWith(arg);
+ assert(oldOp->getNumResults() == 1 && "expected single result");
+ rewriter.replaceAllUsesWith(oldOp->getResult(0),
+ newLoop.getRegionIterArgs()[index]);
}
rewriter.replaceOp(forOp, repResults);
return success();
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp b/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp
index 1d6cb26..7900c10 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp
@@ -162,7 +162,7 @@
if (!cst) return failure();
rewriter.replaceOpWithNewOp<arith::ConstantOp>(minOp,
rewriter.getIndexAttr(*cst));
- return failure();
+ return success();
}
};
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index c5b22c9..1abd377 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -186,7 +186,7 @@
auto dynamicDims = op.getResultDynamicDims(result.getResultNumber());
auto emptyOp = rewriter.create<IREE::Flow::TensorEmptyOp>(
result.getLoc(), result.getType(), dynamicDims);
- result.replaceAllUsesWith(emptyOp);
+ rewriter.replaceAllUsesWith(result, emptyOp);
didReplaceAny = true;
}
}
@@ -226,7 +226,8 @@
// The dimension values may be derived values that are redundant with captured
// dimensions and by redirecting to the captured values we can simplify things.
// Returns true if the dims were changed.
-static bool updateTensorOpDims(Operation *op, Value tensorValue,
+static bool updateTensorOpDims(RewriterBase &rewriter, Operation *op,
+ Value tensorValue,
MutableOperandRange mutableDimValues) {
auto dynamicDimsOr = IREE::Util::findDynamicDims(tensorValue, op->getBlock(),
Block::iterator(op));
@@ -237,7 +238,8 @@
auto oldValues = llvm::to_vector<4>(oldValueRange);
for (unsigned i = 0; i < dynamicDims.size(); ++i) {
if (oldValues[i] != dynamicDims[i]) {
- mutableDimValues.slice(i, 1).assign(dynamicDims[i]);
+ rewriter.updateRootInPlace(
+ op, [&]() { mutableDimValues.slice(i, 1).assign(dynamicDims[i]); });
anyChanged = true;
}
}
@@ -249,7 +251,7 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchTensorLoadOp loadOp,
PatternRewriter &rewriter) const override {
- return success(updateTensorOpDims(loadOp, loadOp.getSource(),
+ return success(updateTensorOpDims(rewriter, loadOp, loadOp.getSource(),
loadOp.getSourceDimsMutable()));
}
};
@@ -400,7 +402,7 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchTensorStoreOp storeOp,
PatternRewriter &rewriter) const override {
- return success(updateTensorOpDims(storeOp, storeOp.getTarget(),
+ return success(updateTensorOpDims(rewriter, storeOp, storeOp.getTarget(),
storeOp.getTargetDimsMutable()));
}
};
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp
index 78e42a5..da53892 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp
@@ -78,10 +78,10 @@
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
if (auto sizeOp =
dyn_cast<IREE::Stream::ResourceSizeOp>(use.getOwner())) {
- sizeOp.getResult().replaceAllUsesWith(sizeValue);
- rewriter.eraseOp(sizeOp);
+ rewriter.replaceOp(sizeOp, sizeValue);
} else {
- use.set(resourceValue);
+ rewriter.updateRootInPlace(use.getOwner(),
+ [&]() { use.set(resourceValue); });
}
}
rewriter.eraseOp(castOp);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 2f87069..1aef723 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -259,7 +259,9 @@
IREE::Util::TiedOpInterface::findTiedBaseValue(result.value());
if (auto blockArg = baseValue.template dyn_cast<BlockArgument>()) {
unsigned operandIndex = blockArg.getArgNumber();
- op.setTiedResultOperandIndex(result.index(), operandIndex);
+ rewriter.updateRootInPlace(op, [&]() {
+ op.setTiedResultOperandIndex(result.index(), operandIndex);
+ });
didModify = true;
}
}
@@ -603,9 +605,7 @@
for (auto sliceOffset : op.getPackedOffsets()) {
auto addOp =
rewriter.create<arith::AddIOp>(op.getLoc(), baseOffset, sliceOffset);
- SmallPtrSet<Operation *, 1> exclusions;
- exclusions.insert(addOp);
- sliceOffset.replaceAllUsesExcept(addOp.getResult(), exclusions);
+ rewriter.replaceAllUsesExcept(sliceOffset, addOp.getResult(), addOp);
}
return success();
@@ -665,9 +665,10 @@
dynamicSliceSizes, op.getAffinityAttr());
// Remap existing values to the new values.
- op.getTotalLength().replaceAllUsesWith(newOp.getTotalLength());
+ rewriter.replaceAllUsesWith(op.getTotalLength(), newOp.getTotalLength());
for (size_t i = 0; i < newOp.getPackedOffsets().size(); ++i) {
- slices[i].packedOffset.replaceAllUsesWith(newOp.getPackedOffsets()[i]);
+ rewriter.replaceAllUsesWith(slices[i].packedOffset,
+ newOp.getPackedOffsets()[i]);
}
rewriter.eraseOp(op);
@@ -1630,8 +1631,7 @@
if (!bitcastOp) return failure();
rewriter.updateRootInPlace(
loadOp, [&]() { loadedValue.setType(bitcastOp.getType()); });
- bitcastOp.getResult().replaceAllUsesWith(loadedValue);
- rewriter.eraseOp(bitcastOp);
+ rewriter.replaceOp(bitcastOp, loadedValue);
return success();
}
};
@@ -1737,7 +1737,7 @@
capture.subviewOp.getLoc(), arg, capture.subviewOp.getSourceSize(),
capture.subviewOp.getSourceOffset(),
capture.subviewOp.getResultSize());
- arg.replaceAllUsesExcept(newOp.getResult(), newOp);
+ rewriter.replaceAllUsesExcept(arg, newOp.getResult(), newOp);
}
rewriter.finalizeRootUpdate(op);
@@ -2151,7 +2151,7 @@
capture.subviewOp.getLoc(), arg, capture.subviewOp.getSourceSize(),
capture.subviewOp.getSourceOffset(),
capture.subviewOp.getResultSize());
- arg.replaceAllUsesExcept(newOp.getResult(), newOp);
+ rewriter.replaceAllUsesExcept(arg, newOp.getResult(), newOp);
}
rewriter.finalizeRootUpdate(op);
@@ -2530,7 +2530,7 @@
auto newOp = rewriter.create<IREE::Stream::ResourceSubviewOp>(
subviewOp.getLoc(), result, subviewOp.getSourceSize(),
subviewOp.getSourceOffset(), subviewOp.getResultSize());
- result.replaceAllUsesExcept(newOp.getResult(), newOp);
+ rewriter.replaceAllUsesExcept(result, newOp.getResult(), newOp);
// Update our bound size to the subview source size (not the subrange).
op.getResourceOperandSizesMutable()
@@ -2631,7 +2631,7 @@
unsigned resultIdx = 0;
for (auto coveredOp : coveredOps) {
for (auto result : coveredOp.getResults()) {
- result.replaceAllUsesWith(newOp.getResults()[resultIdx++]);
+ rewriter.replaceAllUsesWith(result, newOp.getResults()[resultIdx++]);
}
rewriter.eraseOp(coveredOp);
}
@@ -2681,7 +2681,7 @@
for (auto &replacement : replacements) {
auto oldResult = replacement.first;
auto newResult = newOp.getResults()[replacement.second];
- oldResult.replaceAllUsesWith(newResult);
+ rewriter.replaceAllUsesWith(oldResult, newResult);
}
rewriter.eraseOp(op);
return success();
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
index cdf9507..c378e35 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
@@ -207,7 +207,7 @@
// Replace all of the uses inside of the closure.
BlockArgument blockArg = entryBlock.getArgument(opArg.index());
- blockArg.replaceAllUsesWith(newValue);
+ rewriter.replaceAllUsesWith(blockArg, newValue);
}
}
}
@@ -282,8 +282,8 @@
// Dropped.
} else {
// Replaced.
- entryBlock.getArgument(replacement.index())
- .replaceAllUsesWith(*replacement.value());
+ rewriter.replaceAllUsesWith(entryBlock.getArgument(replacement.index()),
+ *replacement.value());
}
}
@@ -305,13 +305,13 @@
assert(oldResults.size() == newResults.size() &&
"expected non-closure results to match");
for (auto [oldResult, newResult] : llvm::zip_equal(oldResults, newResults)) {
- oldResult.replaceAllUsesWith(newResult);
+ rewriter.replaceAllUsesWith(oldResult, newResult);
}
// Replace original uses of the closure results.
for (auto [oldResult, newResult] :
llvm::zip_equal(preservedResults, newOp.getClosureResults())) {
- oldResult.replaceAllUsesWith(newResult);
+ rewriter.replaceAllUsesWith(oldResult, newResult);
}
// Erase the original op.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp
index 29f94fa..ce480d2 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp
@@ -260,7 +260,8 @@
if (!uniformValue.getDefiningOp() ||
dominance.dominates(uniformValue.getDefiningOp()->getBlock(),
&block)) {
- block.getArgument(argIndex).replaceAllUsesWith(uniformValue);
+ rewriter.replaceAllUsesWith(block.getArgument(argIndex),
+ uniformValue);
elidedArgs.set(argIndex);
}
}