Return error instead of asserting in bufferization copy callback. (#8528)
Current callback just asserts that the copy callback can insert a
valid copy. Return a nullptr, and throw an error when copy cannot be
generated. Also update all call sites to handle the error.
diff --git a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index 2278c81..69ad361 100644
--- a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -126,8 +126,8 @@
}
static LogicalResult defaultMemCpyFn(OpBuilder &builder, Location loc,
Value from, Value to) {
- createLinalgCopyOp(builder, loc, from, to);
- return success();
+ Operation *copyOp = createLinalgCopyOp(builder, loc, from, to);
+ return success(static_cast<bool>(copyOp));
}
std::unique_ptr<OperationPass<ModuleOp>> createIREEComprehensiveBufferizePass(
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index 391ee1a..2a82d82 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -714,7 +714,9 @@
if (outBuffer && !plan.isEquivalent(outTensor, resultTensor)) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
if (!linalgOp || linalgOp.payloadUsesValueFromOperand(outOperand)) {
- createLinalgCopyOp(b, loc, outBuffer, resultBuffer);
+ if (!createLinalgCopyOp(b, loc, outBuffer, resultBuffer)) {
+ return failure();
+ }
}
}
newOutputBuffers.push_back(resultBuffer);
@@ -778,8 +780,9 @@
b, storeOp.getLoc(), storeFrom.getType().cast<ShapedType>().getRank(),
storeTo, storeOp.getMixedOffsets(), storeOp.getMixedSizes(),
storeOp.getMixedStrides());
- createLinalgCopyOp(b, storeOp->getLoc(), storeFrom, subview);
- return success();
+ Operation *copyOp =
+ createLinalgCopyOp(b, storeOp->getLoc(), storeFrom, subview);
+ return success(static_cast<bool>(copyOp));
}
/// Converts a `tensor.insert_slice` operation to buffers by
@@ -798,7 +801,9 @@
Value dest = op.dest();
if (!plan.isEquivalent(dest, result)) {
Value destBuffer = bvm.lookup(dest);
- createLinalgCopyOp(b, loc, destBuffer, resultBuffer);
+ if (!createLinalgCopyOp(b, loc, destBuffer, resultBuffer)) {
+ return failure();
+ }
}
Value source = op.source();
@@ -814,8 +819,8 @@
SmallVector<OpFoldResult> strides = op.getMixedStrides();
Value subViewOp = createSubviewOp(b, loc, sourceType.getRank(), resultBuffer,
offsets, sizes, strides);
- createLinalgCopyOp(b, loc, sourceBuffer, subViewOp);
- return success();
+ Operation *copyOp = createLinalgCopyOp(b, loc, sourceBuffer, subViewOp);
+ return success(static_cast<bool>(copyOp));
}
/// Converts a `tensor.insert` operations into a `memref.store`.
@@ -827,7 +832,9 @@
Value resultBuffer = bvm.lookup(result);
if (!plan.isEquivalent(op.dest(), result)) {
Value destBuffer = bvm.lookup(op.dest());
- createLinalgCopyOp(b, loc, destBuffer, resultBuffer);
+ if (!createLinalgCopyOp(b, loc, destBuffer, resultBuffer)) {
+ return failure();
+ }
}
b.create<memref::StoreOp>(loc, op.scalar(), resultBuffer, op.indices());
@@ -850,7 +857,9 @@
// initial value and can avoid the copy.
!op.source().getDefiningOp<linalg::InitTensorOp>()) {
Value destBuffer = bvm.lookup(op.source());
- createLinalgCopyOp(b, loc, destBuffer, resultBuffer);
+ if (!createLinalgCopyOp(b, loc, destBuffer, resultBuffer)) {
+ return failure();
+ }
}
// Create a new vector.transfer_write operation without a result value.
@@ -875,7 +884,9 @@
bvm.map(yieldOperand, resultBuffer);
if (!plan.isEquivalent(arg.value(), initOperand.get())) {
Value initBuffer = bvm.lookup(initOperand.get());
- createLinalgCopyOp(b, loc, initBuffer, resultBuffer);
+ if (!createLinalgCopyOp(b, loc, initBuffer, resultBuffer)) {
+ return failure();
+ }
}
}
return success();
@@ -901,17 +912,20 @@
/// If the alias of the buffer for an input oeprand cannot be used for the
/// "tied" results, need to do an explicit copy of the memory pointed to by the
/// aliased buffer into the buffer assigned to the result.
-static void copyFromAliasingBufferToResultBuffer(
+static LogicalResult copyFromAliasingBufferToResultBuffer(
OpBuilder &b, Location loc, ArrayRef<Value> tiedOperands,
ArrayRef<Value> tiedResults, ArrayRef<Value> aliasingBuffers,
BlockAndValueMapping &bvm, BufferizationPlan &plan) {
for (auto result : enumerate(tiedResults)) {
Value operand = tiedOperands[result.index()];
if (!plan.isEquivalent(result.value(), operand)) {
- createLinalgCopyOp(b, loc, aliasingBuffers[result.index()],
- bvm.lookup(result.value()));
+ if (!createLinalgCopyOp(b, loc, aliasingBuffers[result.index()],
+ bvm.lookup(result.value()))) {
+ return failure();
+ }
}
}
+ return success();
}
/// Returns the static/dynamic mixed sizes of the memref.
@@ -967,8 +981,8 @@
tensorPadOp.getMixedLowPad(),
sizeMixedValues, strides);
// Copy to the interior region.
- createLinalgCopyOp(b, loc, inputMemref, resultSubView);
- return success();
+ Operation *copyOp = createLinalgCopyOp(b, loc, inputMemref, resultSubView);
+ return success(static_cast<bool>(copyOp));
}
namespace {
@@ -1060,10 +1074,9 @@
allocationFn))) {
return failure();
}
- copyFromAliasingBufferToResultBuffer(
+ return copyFromAliasingBufferToResultBuffer(
b, aliasingOp->getLoc(), aliasingOp->getOperand(0),
aliasingOp->getResult(0), aliasingBuffers, bvm, plan);
- return success();
})
.Case<tensor::PadOp>([&](tensor::PadOp tensorPadOp) {
if (failed(getOrAllocateResultBuffers(b, tensorPadOp, bvm, plan,
diff --git a/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp b/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp
index 84b7491..e2b9806 100644
--- a/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp
+++ b/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp
@@ -24,6 +24,7 @@
Operation *linalgCopy =
createLinalgCopyOp(rewriter, copyOp.getLoc(), copyOp.source(),
copyOp.target(), copyOp->getAttrs());
+ if (!linalgCopy) return failure();
rewriter.replaceOp(copyOp, linalgCopy->getResults());
return success();
}
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp
index 0dec12c..4651efc 100644
--- a/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -492,11 +492,15 @@
/// memref::CopyOp.
Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from, Value to,
ArrayRef<NamedAttribute> attributes) {
- auto memrefTypeFrom = from.getType().cast<MemRefType>();
- auto memrefTypeTo = to.getType().cast<MemRefType>();
- (void)memrefTypeFrom;
- assert(memrefTypeFrom && memrefTypeTo &&
- memrefTypeFrom.getRank() == memrefTypeTo.getRank());
+ auto memrefTypeFrom = from.getType().dyn_cast<MemRefType>();
+ auto memrefTypeTo = to.getType().dyn_cast<MemRefType>();
+ if (!memrefTypeFrom || !memrefTypeTo ||
+ memrefTypeFrom.getRank() != memrefTypeTo.getRank()) {
+ mlir::emitError(
+ loc, "unable to generate copy op within bufferization from type ")
+ << memrefTypeFrom << " to " << memrefTypeTo;
+ return nullptr;
+ }
AffineMap id =
AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),