[Codegen] Add an op for fusing forall ops (#17279)

This fuses a single-result, single-use producer scf.forall op with
a specified consumer, inserting an `iree_gpu.shuffle_tensor` at the
boundary and remapping the thread indices of the producer to the basis
of the consumer.
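
For example (as exercised in the new test; the address_space attribute
is optional):

  transform.iree.fuse_forall %producer into %consumer
      {address_space = #gpu.address_space<workgroup>}
    : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
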
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
index cb76de0..906872b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
@@ -65,6 +65,7 @@
         "//compiler/src/iree/compiler/Codegen/Common",
         "//compiler/src/iree/compiler/Codegen/Common:VectorLayoutAnalysis",
         "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
         "//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
index c4f0801..98d7091 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -68,6 +68,7 @@
     iree::compiler::Codegen::Common
     iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Codegen::Common::VectorLayoutAnalysis
+    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
     iree::compiler::Codegen::Interfaces::BufferizationInterfaces
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index e7ae1ef..471d40b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -24,6 +25,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
@@ -36,7 +38,9 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -47,9 +51,15 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/CSE.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -419,6 +429,212 @@
   transform::modifiesPayload(effects);
 }
 
+//===---------------------------------------------------------------------===//
+// FuseForallOp
+//===---------------------------------------------------------------------===//
+
+static FailureOr<int64_t> getTripCount(scf::ForallOp loop) {
+  ArrayRef<int64_t> lbs = loop.getStaticLowerBound();
+  ArrayRef<int64_t> ubs = loop.getStaticUpperBound();
+  ArrayRef<int64_t> steps = loop.getStaticStep();
+
+  if (ShapedType::isDynamicShape(lbs) || ShapedType::isDynamicShape(ubs) ||
+      ShapedType::isDynamicShape(steps)) {
+    return failure();
+  }
+
+  int64_t tripCount = 1;
+  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
+    tripCount *= mlir::ceilDiv((ub - lb), step);
+  }
+  return tripCount;
+}
+
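+// Returns success if |producer| and |consumer| have equal static trip counts
+// and identical mapping attributes, where the mapping must consist entirely
+// of GPU thread or entirely of GPU warp mapping attributes.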
+static LogicalResult compareWorkerCountsAndTypes(scf::ForallOp producer,
+                                                 scf::ForallOp consumer) {
+  FailureOr<int64_t> producerTripCount = getTripCount(producer);
+  FailureOr<int64_t> consumerTripCount = getTripCount(consumer);
+  if (failed(producerTripCount) || failed(consumerTripCount) ||
+      *producerTripCount != *consumerTripCount) {
+    return failure();
+  }
+
+  auto checkMappingTypes = [&](ArrayAttr array) {
+    return llvm::all_of(array.getValue(),
+                        llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
+           llvm::all_of(array.getValue(),
+                        llvm::IsaPred<gpu::GPUWarpMappingAttr>);
+  };
+
+  if (producer.getMappingAttr() != consumer.getMappingAttr() ||
+      !checkMappingTypes(producer.getMappingAttr()) ||
+      !checkMappingTypes(consumer.getMappingAttr())) {
+    return failure();
+  }
+  return success();
+}
+
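+// Creates an `iree_gpu.shuffle_tensor` op bridging the producer's
+// parallel_insert_slice with the consumer's extract_slice. The intermediate
+// memref is a fresh allocation (in |addressSpace| if provided) when the
+// insertion destination is a tensor.empty, and a bufferization.to_memref of
+// the destination otherwise.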
+static Value getReplacementSlice(RewriterBase &rewriter, Location loc,
+                                 tensor::ParallelInsertSliceOp parallelInsert,
+                                 tensor::ExtractSliceOp extractSlice,
+                                 std::optional<Attribute> addressSpace) {
+  RankedTensorType destTensorType = parallelInsert.getDestType();
+  MemRefType allocType =
+      addressSpace ? MemRefType::get(destTensorType.getShape(),
+                                     destTensorType.getElementType(),
+                                     MemRefLayoutAttrInterface{}, *addressSpace)
+                   : MemRefType::get(destTensorType.getShape(),
+                                     destTensorType.getElementType());
+  Value dest = Value();
+  if (auto empty = parallelInsert.getDest().getDefiningOp<tensor::EmptyOp>()) {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(empty);
+    dest = rewriter.create<memref::AllocOp>(loc, allocType,
+                                            empty.getDynamicSizes());
+  } else {
+    dest = rewriter.create<bufferization::ToMemrefOp>(loc, allocType,
+                                                      parallelInsert.getDest());
+  }
+  return rewriter.create<IREE::GPU::ShuffleTensorOp>(
+      loc, extractSlice.getType(), parallelInsert.getSource(),
+      parallelInsert.getOffsets(), parallelInsert.getSizes(),
+      parallelInsert.getStrides(), parallelInsert.getStaticOffsets(),
+      parallelInsert.getStaticSizes(), parallelInsert.getStaticStrides(), dest,
+      extractSlice.getOffsets(), extractSlice.getSizes(),
+      extractSlice.getStrides(), extractSlice.getStaticOffsets(),
+      extractSlice.getStaticSizes(), extractSlice.getStaticStrides());
+}
+
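+// Fuses |producer| into |consumer| at |slice| by remapping the producer's
+// induction variables to a delinearization of the consumer's linearized
+// worker ID, inlining the producer body, and routing the data through an
+// intermediate allocation via `iree_gpu.shuffle_tensor`.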
+static LogicalResult
+fuseForallIntoSlice(RewriterBase &rewriter, scf::ForallOp producer,
+                    scf::ForallOp consumer, tensor::ExtractSliceOp slice,
+                    std::optional<Attribute> addressSpace) {
+  if (producer->getNumResults() != 1) {
+    return failure();
+  }
+
+  if (failed(compareWorkerCountsAndTypes(producer, consumer))) {
+    return failure();
+  }
+
+  auto isAll = [](ArrayRef<OpFoldResult> array, int64_t cmp) {
+    return llvm::all_of(array, [cmp](OpFoldResult val) {
+      return isConstantIntValue(val, cmp);
+    });
+  };
+
+  if (!isAll(producer.getMixedStep(), 1) ||
+      !isAll(producer.getMixedLowerBound(), 0) ||
+      !isAll(consumer.getMixedStep(), 1) ||
+      !isAll(consumer.getMixedLowerBound(), 0)) {
+    return failure();
+  }
+
+  rewriter.setInsertionPoint(slice);
+
+  // Step 1. Compute the producer IDs in terms of the consumer IDs.
+
+  MLIRContext *context = rewriter.getContext();
+  Location loc = producer.getLoc();
+
+  AffineExpr d0, d1, d2;
+  bindDims(context, d0, d1, d2);
+  AffineExpr mulAdd = d0 * d1 + d2;
+  OpFoldResult linearId = rewriter.getIndexAttr(0);
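+  // Fold the consumer induction variables into a single linear ID; for ids
+  // (i0, i1, i2) with worker counts (n0, n1, n2) this yields
+  // (i0 * n1 + i1) * n2 + i2.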
+  for (auto [inductionVar, workerCount] :
+       llvm::zip_equal(getAsOpFoldResult(consumer.getInductionVars()),
+                       consumer.getMixedUpperBound())) {
+    linearId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, mulAdd, {linearId, workerCount, inductionVar});
+  }
+
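+  // Delinearize the linear ID over the producer's worker counts to obtain
+  // replacement values for the producer's induction variables.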
+  Value linearThreadIdVal =
+      getValueOrCreateConstantIndexOp(rewriter, loc, linearId);
+  SmallVector<Value> ranges;
+  for (auto workerCount : producer.getStaticUpperBound()) {
+    ranges.push_back(rewriter.create<arith::ConstantIndexOp>(loc, workerCount));
+  }
+  ValueRange newIds = rewriter
+                          .create<affine::AffineDelinearizeIndexOp>(
+                              loc, linearThreadIdVal, ranges)
+                          .getResults();
+
+  // Step 2. Inline the region of the producer.
+  SmallVector<Value> bbArgReplacements(newIds);
+  bbArgReplacements.append(producer.getOutputs().begin(),
+                           producer.getOutputs().end());
+
+  scf::InParallelOp terminator = producer.getTerminator();
+  rewriter.inlineBlockBefore(producer.getBody(), slice, bbArgReplacements);
+
+  rewriter.setInsertionPointAfter(terminator);
+  auto parallelInsert =
+      cast<tensor::ParallelInsertSliceOp>(*terminator.getYieldingOps().begin());
+
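+  // Step 3. Replace the extracted slice with a shuffle through the
+  // intermediate allocation and erase the now-unused producer ops.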
+  Value replacementSlice =
+      getReplacementSlice(rewriter, loc, parallelInsert, slice, addressSpace);
+  rewriter.replaceAllUsesWith(slice, replacementSlice);
+
+  rewriter.eraseOp(parallelInsert);
+  rewriter.eraseOp(slice);
+  rewriter.eraseOp(terminator);
+  rewriter.eraseOp(producer);
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform_dialect::FuseForallOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  auto producers = state.getPayloadOps(getProducer());
+  auto consumers = state.getPayloadOps(getConsumer());
+
+  int64_t numProducers = llvm::range_size(producers);
+  int64_t numConsumers = llvm::range_size(consumers);
+  if (numProducers != 1 || numConsumers != 1) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "expected exactly one producer and one "
+                                     "consumer payload op");
+  }
+
+  auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
+  auto consumer = dyn_cast<scf::ForallOp>(*consumers.begin());
+  if (!producer || !consumer) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "Non-forall producer or consumer");
+  }
+
+  if (!producer->hasOneUse()) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "non-single use producer");
+  }
+
+  auto sliceConsumer =
+      dyn_cast<tensor::ExtractSliceOp>(*producer->user_begin());
+  if (!sliceConsumer || sliceConsumer->getParentOp() != consumer) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "producer loop sole consumer is not an "
+                                     "extracted slice from the consumer loop");
+  }
+
+  if (failed(fuseForallIntoSlice(rewriter, producer, consumer, sliceConsumer,
+                                 getAddressSpace()))) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "failed to fuse forall ops");
+  }
+
+  results.set(getOperation()->getOpResult(0), {consumer});
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::FuseForallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getProducer(), effects);
+  transform::consumesHandle(getConsumer(), effects);
+  transform::producesHandle(getResult(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // GpuDistributeSharedMemoryCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 13fd709..62f3000 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -307,6 +307,42 @@
   }];
 }
 
+def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall",
+    [FunctionalStyleTransformOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses a producer-consumer pair of scf.forall ops that share the same
+    iterator mapping types and trip counts. An allocation is created to
+    bridge the `tensor.parallel_insert_slice` of the producer with the
+    per-thread `tensor.extract_slice` of the consumer. If specified,
+    |address_space| is used for the intermediate allocation.
+
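+    For example (schematically, following the included test):
+
+      %0 = scf.forall ... shared_outs(%out = %empty) -> (tensor<128x128xf32>)
+      %1 = scf.forall ... {
+        %slice = tensor.extract_slice %0 [...] : tensor<128x128xf32> to tensor<16x16xf32>
+        ...
+      }
+
+    becomes a single scf.forall in which the extracted slice is replaced by
+    an `iree_gpu.shuffle_tensor` op reading through the intermediate
+    allocation.
+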
+    NOTE: This pattern implicitly REQUIRES that the resulting scf.forall
+    is capable of synchronizing all threads at the point of fusion (i.e.
+    inserting a barrier). This invalidates certain kinds of lowerings of
+    scf.forall ops, such as lowering to sequential loops.
+
+    #### Return modes
+    Emits a definite failure if the producer or consumer handle does not map
+    to exactly one payload op, if either payload op is not an scf.forall, if
+    the producer has more than one use, or if the fusion preconditions are
+    not met.
+  }];
+
+  let arguments = (
+      ins TransformHandleTypeInterface:$producer,
+          TransformHandleTypeInterface:$consumer,
+          OptionalAttr<AnyAttr>:$address_space
+  );
+  let results = (outs TransformHandleTypeInterface:$result);
+
+  let assemblyFormat = [{
+    $producer `into` $consumer attr-dict
+    `:` functional-type(operands, results)
+  }];
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+}
+
 def IREEApplyLoopIndependentCodeMotionOp : Op<Transform_Dialect, "iree.apply_licm",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformEachOpTrait,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 5b09cad..f985335 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -66,6 +66,7 @@
             "tile_and_distribute_to_workgroups.mlir",
             "transform_buffer_opt.mlir",
             "transform_copy_operand.mlir",
+            "transform_fuse_forall.mlir",
             "transform_match_partial_reduction.mlir",
             "transform_ops_invalid.mlir",
             "transpose_canonicalization.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index b284856..a6116a7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -62,6 +62,7 @@
     "tile_and_distribute_to_workgroups.mlir"
     "transform_buffer_opt.mlir"
     "transform_copy_operand.mlir"
+    "transform_fuse_forall.mlir"
     "transform_match_partial_reduction.mlir"
     "transform_ops_invalid.mlir"
     "transpose_canonicalization.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_fuse_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_fuse_forall.mlir
new file mode 100644
index 0000000..e4b991a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_fuse_forall.mlir
@@ -0,0 +1,119 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+module {
+  func.func @fuse_forall(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
+    %0 = tensor.empty() : tensor<128x128xf32>
+    %2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+      %4 = affine.apply #map(%arg5)
+      %extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+      %5 = linalg.copy ins(%extracted_slice : tensor<2x128xf32>) outs(%extracted_slice_0 : tensor<2x128xf32>) -> tensor<2x128xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %5 into %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<2x128xf32> into tensor<128x128xf32>
+      }
+    } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+    %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+      %6 = affine.apply #map1(%arg5)
+      %7 = affine.apply #map1(%arg6)
+      %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+      %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+      }
+    } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+    return %3 : tensor<128x128xf32>
+  }
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+    %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.iree.fuse_forall %producer into %consumer {address_space = #gpu.address_space<workgroup>} : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK-LABEL: func @fuse_forall
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
+
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
+//   CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<128x128xf32, #gpu.address_space<workgroup>>
+//       CHECK:   scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
+//   CHECK-DAG:     %[[OUTID0:.+]] = affine.apply #[[$MAP]](%[[IDX]])
+//   CHECK-DAG:     %[[OUTID1:.+]] = affine.apply #[[$MAP]](%[[IDY]])
+//       CHECK:     %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
+//       CHECK:     %[[IDS:.+]]:2 = affine.delinearize_index %[[LINEARID]] into (%c64, %c1) : index, index
+//       CHECK:     %[[INID0:.+]] = affine.apply #[[$MAP2]](%[[IDS]]#0)
+//       CHECK:     %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+//       CHECK:     %[[INSLICE1:.+]] = tensor.extract_slice %[[EMPTY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+//       CHECK:     %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
+//       CHECK:     %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %[[COPY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1]
+//  CHECK-SAME:       to %[[ALLOC]] [%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1]
+//  CHECK-SAME:       : tensor<2x128xf32> -> memref<128x128xf32, #gpu.address_space<workgroup>> -> tensor<16x16xf32>
+//       CHECK:     %[[OUTSLICE:.+]] = tensor.extract_slice %[[INIT]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+//       CHECK:     %[[MM:.+]] = linalg.matmul ins(%[[SHUFFLE]], %[[SHUFFLE]] : tensor<16x16xf32>, tensor<16x16xf32>)
+//  CHECK-SAME:       outs(%[[OUTSLICE]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+//       CHECK:     scf.forall.in_parallel {
+//       CHECK:       tensor.parallel_insert_slice %[[MM]] into %[[INIT]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+//       CHECK:     }
+//       CHECK:   } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+module {
+  func.func @fuse_forall(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+    %2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
+      %4 = affine.apply #map(%arg5)
+      %extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+      %5 = linalg.copy ins(%extracted_slice : tensor<2x128xf32>) outs(%extracted_slice_0 : tensor<2x128xf32>) -> tensor<2x128xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %5 into %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<2x128xf32> into tensor<128x128xf32>
+      }
+    } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
+    %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
+      %6 = affine.apply #map1(%arg5)
+      %7 = affine.apply #map1(%arg6)
+      %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+      %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+      }
+    } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
+    return %3 : tensor<128x128xf32>
+  }
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+    %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK-LABEL: func @fuse_forall
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
+//  CHECK-SAME:   %[[ARG1:[A-Za-z0-9]+]]: tensor<128x128xf32>
+
+//       CHECK:   scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[ARG1]]) -> (tensor<128x128xf32>) {
+//       CHECK:     %[[ALLOC:.+]] = bufferization.to_memref %[[ARG1]]
+//       CHECK:     %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %{{.*}} to %[[ALLOC]]
+//  CHECK-SAME:       : tensor<2x128xf32> -> memref<128x128xf32> -> tensor<16x16xf32>
+//       CHECK:   } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}