Merge pull request #8837 from matthias-springer/fix_bufferization_inparallel2

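Fix bufferization of in_parallel: insert result-buffer copies before the in_parallel op (not before its terminator), and make each parallel_insert_slice copy into a subview of the corresponding new result buffer.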
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
index 9b3c2fb..ceb5de9 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
@@ -78,10 +78,6 @@
                           BufferizationState &state) const {
     OpBuilder::InsertionGuard g(b);
     auto inParallelOp = cast<InParallelOp>(op);
-    Block *body = &inParallelOp.region().front();
-    Operation *oldTerminator = body->getTerminator();
-    assert(isa<PerformConcurrentlyOp>(oldTerminator) &&
-           "unexpected terminator");
 
     // Gather new results of the InParallelOp.
     SmallVector<Value> newResults;
@@ -93,22 +89,10 @@
       // Insert copies right before the PerformConcurrentlyOp terminator. They
       // should not be inside terminator (which would be the default insertion
       // point).
-      Value buffer = *state.getBuffer(
-          b, *insertDestOperands.front(), /*forceInPlace=*/false,
-          /*customCopyInsertionPoint=*/oldTerminator);
+      Value buffer = *state.getBuffer(b, *insertDestOperands.front(),
+                                      /*forceInPlace=*/false,
+                                      /*customCopyInsertionPoint=*/op);
       newResults.push_back(buffer);
-      Value destTensor = insertDestOperands.front()->get();
-
-      // Replace all uses of the insert dest tensor inside the InParallelOp
-      // with the result buffer.
-      OpBuilder::InsertionGuard g(b);
-      b.setInsertionPointToStart(body);
-      Value toTensorOp =
-          b.create<bufferization::ToTensorOp>(inParallelOp.getLoc(), buffer);
-      for (OpOperand &use : destTensor.getUses())
-        if (body->findAncestorOpInBlock(*use.getOwner()))
-          // This is a use inside the InParallelOp.
-          use.set(toTensorOp);
     }
 
     // Create new InParallelOp without any results.
@@ -127,20 +111,17 @@
     auto performConcurrentlyOp =
         cast<PerformConcurrentlyOp>(newInParallelOp.getBody()->getTerminator());
     b.setInsertionPoint(performConcurrentlyOp);
+    unsigned resultCounter = 0;
     WalkResult walkResult =
         performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
           Location loc = insertOp.getLoc();
           Type srcType = getMemRefType(
               insertOp.source().getType().cast<RankedTensorType>(),
               state.getOptions());
-          Type destType =
-              getMemRefType(insertOp.dest().getType().cast<RankedTensorType>(),
-                            state.getOptions());
           // ParallelInsertSliceOp bufferizes to a copy.
           auto srcMemref = b.create<bufferization::ToMemrefOp>(
               loc, srcType, insertOp.source());
-          auto destMemref = b.create<bufferization::ToMemrefOp>(
-              loc, destType, insertOp.dest());
+          Value destMemref = newResults[resultCounter++];
           Value subview = b.create<memref::SubViewOp>(
               loc, destMemref, insertOp.getMixedOffsets(),
               insertOp.getMixedSizes(), insertOp.getMixedStrides());
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
index 4e5043c..3b685a3 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
@@ -54,6 +54,8 @@
 
   // The parallel_insert_slice_op bufferizes out-of-place, so we need an allocation.
   // CHECK: %[[alloc1:.*]] = memref.alloc
+  // CHECK: linalg.generic {{.*}} ins(%[[arg2]]{{.*}}outs(%[[alloc1]]
+
   // CHECK: iree_linalg_ext.in_parallel %[[idx2]]  -> ()
   %2 = iree_linalg_ext.in_parallel %idx2  -> (tensor<?xf32>) {
     ^bb0(%arg3: index):  // no predecessors
@@ -64,12 +66,8 @@
       // CHECK: linalg.fill ins(%{{.*}}) outs(%[[alloc2]] : memref<?xf32
       %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
 
-      // parallel_insert_slice buffer was already allocated but not copied yet.
-      //
-      // CHECK: linalg.generic {{.*}} ins(%[[arg2]]{{.*}}outs(%[[alloc1]]
-
       // Now the copy of the actual insert_slice.
-      // CHECK: %[[subview1:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
+      // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
       //
       // CHECK: linalg.generic {{.*}} ins(%[[alloc2]]{{.*}}outs(%[[subview1]]
       // CHECK: memref.dealloc %[[alloc2]]