Merge pull request #8837 from matthias-springer/fix_bufferization_inparallel2
Fix bufferization of in_parallel
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
index 9b3c2fb..ceb5de9 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
@@ -78,10 +78,6 @@
BufferizationState &state) const {
OpBuilder::InsertionGuard g(b);
auto inParallelOp = cast<InParallelOp>(op);
- Block *body = &inParallelOp.region().front();
- Operation *oldTerminator = body->getTerminator();
- assert(isa<PerformConcurrentlyOp>(oldTerminator) &&
- "unexpected terminator");
// Gather new results of the InParallelOp.
SmallVector<Value> newResults;
@@ -93,22 +89,10 @@
// Insert copies right before the InParallelOp itself. They should not be
// inside the PerformConcurrentlyOp terminator (which would be the default
// insertion point).
- Value buffer = *state.getBuffer(
- b, *insertDestOperands.front(), /*forceInPlace=*/false,
- /*customCopyInsertionPoint=*/oldTerminator);
+ Value buffer = *state.getBuffer(b, *insertDestOperands.front(),
+ /*forceInPlace=*/false,
+ /*customCopyInsertionPoint=*/op);
newResults.push_back(buffer);
- Value destTensor = insertDestOperands.front()->get();
-
- // Replace all uses of the insert dest tensor inside the InParallelOp
- // with the result buffer.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPointToStart(body);
- Value toTensorOp =
- b.create<bufferization::ToTensorOp>(inParallelOp.getLoc(), buffer);
- for (OpOperand &use : destTensor.getUses())
- if (body->findAncestorOpInBlock(*use.getOwner()))
- // This is a use inside the InParallelOp.
- use.set(toTensorOp);
}
// Create new InParallelOp without any results.
@@ -127,20 +111,17 @@
auto performConcurrentlyOp =
cast<PerformConcurrentlyOp>(newInParallelOp.getBody()->getTerminator());
b.setInsertionPoint(performConcurrentlyOp);
+ unsigned resultCounter = 0;
WalkResult walkResult =
performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
Location loc = insertOp.getLoc();
Type srcType = getMemRefType(
insertOp.source().getType().cast<RankedTensorType>(),
state.getOptions());
- Type destType =
- getMemRefType(insertOp.dest().getType().cast<RankedTensorType>(),
- state.getOptions());
// ParallelInsertSliceOp bufferizes to a copy.
auto srcMemref = b.create<bufferization::ToMemrefOp>(
loc, srcType, insertOp.source());
- auto destMemref = b.create<bufferization::ToMemrefOp>(
- loc, destType, insertOp.dest());
+ Value destMemref = newResults[resultCounter++];
Value subview = b.create<memref::SubViewOp>(
loc, destMemref, insertOp.getMixedOffsets(),
insertOp.getMixedSizes(), insertOp.getMixedStrides());
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
index 4e5043c..3b685a3 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
@@ -54,6 +54,8 @@
// The parallel_insert_slice_op bufferizes out-of-place, so we need an allocation.
// CHECK: %[[alloc1:.*]] = memref.alloc
+ // CHECK: linalg.generic {{.*}} ins(%[[arg2]]{{.*}}outs(%[[alloc1]]
+
// CHECK: iree_linalg_ext.in_parallel %[[idx2]] -> ()
%2 = iree_linalg_ext.in_parallel %idx2 -> (tensor<?xf32>) {
^bb0(%arg3: index): // no predecessors
@@ -64,12 +66,8 @@
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[alloc2]] : memref<?xf32
%8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
- // parallel_insert_slice buffer was already allocated but not copied yet.
- //
- // CHECK: linalg.generic {{.*}} ins(%[[arg2]]{{.*}}outs(%[[alloc1]]
-
// Now the copy of the actual insert_slice.
- // CHECK: %[[subview1:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
+ // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
//
// CHECK: linalg.generic {{.*}} ins(%[[alloc2]]{{.*}}outs(%[[subview1]]
// CHECK: memref.dealloc %[[alloc2]]