NFC: Remove crufty insertion point search. (#6937)
To start creating the memrefs for the output (when bufferizing a
`flow.dispatch.tensor.store`), the result buffer has to be allocated before any
of its uses. Replace the crufty insertion-point search by simply cloning the
operations that compute the subview offsets, sizes, and strides; the clones get
CSE-ed away later anyway.
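
For context, here is a minimal sketch of the idea (the helper name
`cloneIndexComputation` is made up for illustration; in the patch the
equivalent logic is a lambda inside LinalgBufferizePass.cpp, shown in the diff
below): every SSA value feeding the subview's offsets, sizes, and strides is
recomputed at the builder's current insertion point by cloning its backward
slice, instead of searching for a legal insertion point for the replacement op.

```cpp
// Simplified sketch, not the exact code added in this patch.
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: returns a value equivalent to `val` that is defined
// before the builder's current insertion point, cloning producers as needed.
static Value cloneIndexComputation(OpBuilder &b, Value val,
                                   Operation *parentOp,
                                   BlockAndValueMapping &map) {
  // Block arguments dominate the whole region; use them directly.
  if (val.isa<BlockArgument>()) return val;
  // Reuse a clone made for an earlier offset/size/stride, if any.
  if (Value mapped = map.lookupOrNull(val)) return mapped;
  // Collect the transitive producers of `val` within the same parent op. The
  // slice comes back topologically sorted, so cloning it in order preserves
  // use-def ordering.
  llvm::SetVector<Operation *> slice;
  getBackwardSlice(val, &slice, [&](Operation *sliceOp) {
    return sliceOp->getParentOp() == parentOp;
  });
  for (Operation *sliceOp : slice)
    if (!map.contains(sliceOp->getResult(0))) b.clone(*sliceOp, map);
  if (Operation *definingOp = val.getDefiningOp()) b.clone(*definingOp, map);
  return map.lookup(val);
}
```

Cloning may duplicate some index arithmetic, but the cloned ops are pure and
get folded back together by the later CSE pass, which is why the simpler
approach is NFC in practice.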
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index efff79d..e5b4f34 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -57,6 +57,7 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:AffineUtils",
+ "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:GPUDialect",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index 686fba3..4e78086 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -38,6 +38,7 @@
LLVMSupport
MLIRAffine
MLIRAffineUtils
+ MLIRAnalysis
MLIRGPUOps
MLIRIR
MLIRLLVMCommonConversion
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index f440db2..eb4a79f 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -49,7 +49,9 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -789,31 +791,6 @@
//
// ===----------------------------------------------------------------------===//
-/// For a given store-like `op` that is to be replaced, find the insertion point
-/// in the same block earliest possible when
-/// - the replacement op uses values in `usedValues`, so has to be inserted
-/// after the ops that define these.
-/// - The op needs to be inserted before `insertBefore` (which is in the same
-/// block). Return nullptr all other times.
-static Operation *getInsertionPointForReplacementStoreOp(
- Operation *op, Operation *insertBefore, ArrayRef<Value> usedValues) {
- if (op->getBlock() != insertBefore->getBlock()) return nullptr;
- Operation *insertAfter = nullptr;
- for (auto value : usedValues) {
- Operation *definingOp = value.getDefiningOp();
- if (!definingOp || definingOp->getBlock() != insertBefore->getBlock())
- continue;
- if (!insertAfter || insertAfter->isBeforeInBlock(definingOp))
- insertAfter = definingOp;
- }
- // All defining ops are outside of the block, so just insert at the start of
- // the block.
- if (!insertAfter) return &(op->getBlock()->front());
- if (insertAfter->isBeforeInBlock(insertBefore))
- return insertAfter->getNextNode();
- return nullptr;
-}
-
/// Returns the subview into the buffer that is supposed to be populated with
/// the `value` of the `flow.dispatch.tensor.store` operation. This can be used
/// to compute the results in place.
@@ -828,18 +805,54 @@
[&](auto storeOp) { return storeOp.dest(); })
.Default([](Operation *) { return nullptr; });
if (!target) return nullptr;
- operandsOfSubviewOp.push_back(bvm.lookup(target));
- operandsOfSubviewOp.append(op.offsets().begin(), op.offsets().end());
- operandsOfSubviewOp.append(op.sizes().begin(), op.sizes().end());
- operandsOfSubviewOp.append(op.strides().begin(), op.strides().end());
- Operation *insertBefore = &(*b.getInsertionPoint());
- Operation *insertionPoint = getInsertionPointForReplacementStoreOp(
- op.getOperation(), insertBefore, operandsOfSubviewOp);
- if (!insertionPoint) return nullptr;
- OpBuilder::InsertionGuard g(b);
- Value subview =
- createSubviewOp(b, op.getLoc(), bvm.lookup(target), op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides());
+
+ // Clone the offset, size and stride values. They will be CSE-ed later.
+ Operation *parentOp = storeOp->getParentOp();
+ BlockAndValueMapping indexValMap;
+ llvm::SetVector<Operation *> slice;
+ auto cloneIndexValues = [&](ArrayRef<OpFoldResult> ofrs) {
+ SmallVector<OpFoldResult> clonedVals;
+ for (auto ofr : ofrs) {
+ // Just copy the attributes.
+ if (auto attr = ofr.dyn_cast<Attribute>()) {
+ clonedVals.push_back(attr);
+ continue;
+ }
+ Value val = ofr.get<Value>();
+ // If it is a block argument use the same value.
+ if (val.isa<BlockArgument>()) {
+ clonedVals.push_back(val);
+ continue;
+ }
+ // The slice of ops needed for index computation need to be cloned to
+ // avoid use-def violations. If the value has been cloned already, reuse
+ // that.
+ if (auto lookupVal = indexValMap.lookupOrNull(val)) {
+ clonedVals.push_back(lookupVal);
+ continue;
+ }
+ slice.clear();
+ getBackwardSlice(val, &slice, [&](Operation *sliceOp) {
+ return sliceOp->getParentOp() == parentOp;
+ });
+ for (auto sliceOp : slice) {
+ if (!indexValMap.contains(sliceOp->getResult(0))) {
+ b.clone(*sliceOp, indexValMap);
+ }
+ }
+ if (Operation *definingOp = val.getDefiningOp()) {
+ b.clone(*definingOp, indexValMap);
+ }
+ clonedVals.push_back(indexValMap.lookup(val));
+ }
+ return clonedVals;
+ };
+ SmallVector<OpFoldResult> subViewOffsets, subViewSizes, subViewStrides;
+ subViewOffsets = cloneIndexValues(op.getMixedOffsets());
+ subViewSizes = cloneIndexValues(op.getMixedSizes());
+ subViewStrides = cloneIndexValues(op.getMixedStrides());
+ Value subview = createSubviewOp(b, op.getLoc(), bvm.lookup(target),
+ subViewOffsets, subViewSizes, subViewStrides);
return subview;
}