[Codegen] Add an op for fusing forall ops (#17279)
This fuses a single-result, single-use producer scf.forall op with a
specified consumer scf.forall op, inserting an `iree_gpu.shuffle_tensor`
op at the boundary and remapping the thread indices of the producer to
the basis of the consumer.
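
For reference, a minimal transform schedule exercising the new op (mirroring the lit test added below; the `address_space` attribute is optional and selects the memory space of the intermediate allocation):

```mlir
%loops = transform.structured.match ops{["scf.forall"]} in %root
  : (!transform.any_op) -> !transform.any_op
%producer, %consumer = transform.split_handle %loops
  : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.iree.fuse_forall %producer into %consumer
  {address_space = #gpu.address_space<workgroup>}
  : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
```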
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
index cb76de0..906872b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
@@ -65,6 +65,7 @@
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Common:VectorLayoutAnalysis",
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
index c4f0801..98d7091 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -68,6 +68,7 @@
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Common::VectorLayoutAnalysis
+ iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Interfaces::BufferizationInterfaces
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index e7ae1ef..471d40b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -24,6 +25,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
@@ -36,7 +38,9 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -47,9 +51,15 @@
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/CSE.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -419,6 +429,212 @@
transform::modifiesPayload(effects);
}
+//===---------------------------------------------------------------------===//
+// FuseForallOp
+//===---------------------------------------------------------------------===//
+
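+// Returns the total number of iterations of |loop| (the product of the
+// per-dimension trip counts), or failure if any bound or step is dynamic.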
+FailureOr<int64_t> getTripCount(scf::ForallOp loop) {
+ ArrayRef<int64_t> lbs = loop.getStaticLowerBound();
+ ArrayRef<int64_t> ubs = loop.getStaticUpperBound();
+ ArrayRef<int64_t> steps = loop.getStaticStep();
+
+ if (ShapedType::isDynamicShape(lbs) || ShapedType::isDynamicShape(ubs) ||
+ ShapedType::isDynamicShape(steps)) {
+ return failure();
+ }
+
+ int64_t tripCount = 1;
+ for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
+ tripCount *= mlir::ceilDiv((ub - lb), step);
+ }
+ return tripCount;
+}
+
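+// Succeeds if |producer| and |consumer| have the same static trip count and
+// identical mapping attributes, which must all be thread or all be warp
+// mappings.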
+LogicalResult compareWorkerCountsAndTypes(scf::ForallOp producer,
+ scf::ForallOp consumer) {
+ FailureOr<int64_t> producerTripCount = getTripCount(producer);
+ FailureOr<int64_t> consumerTripCount = getTripCount(consumer);
+ if (failed(producerTripCount) || failed(consumerTripCount) ||
+ *producerTripCount != *consumerTripCount) {
+ return failure();
+ }
+
+ auto checkMappingTypes = [&](ArrayAttr array) {
+ return llvm::all_of(array.getValue(),
+ llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
+ llvm::all_of(array.getValue(),
+ llvm::IsaPred<gpu::GPUWarpMappingAttr>);
+ };
+
+ if (producer.getMappingAttr() != consumer.getMappingAttr() ||
+ !checkMappingTypes(producer.getMappingAttr()) ||
+ !checkMappingTypes(consumer.getMappingAttr())) {
+ return failure();
+ }
+ return success();
+}
+
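+// Creates an `iree_gpu.shuffle_tensor` op bridging the producer's
+// tensor.parallel_insert_slice with the consumer's tensor.extract_slice. The
+// shared destination is a fresh allocation when the insert destination is a
+// tensor.empty, and otherwise the existing destination converted with
+// bufferization.to_memref; |addressSpace|, if set, is used for the memref
+// type.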
+Value getReplacementSlice(RewriterBase &rewriter, Location loc,
+ tensor::ParallelInsertSliceOp parallelInsert,
+ tensor::ExtractSliceOp extractSlice,
+ std::optional<Attribute> addressSpace) {
+ RankedTensorType destTensorType = parallelInsert.getDestType();
+ MemRefType allocType =
+ addressSpace ? MemRefType::get(destTensorType.getShape(),
+ destTensorType.getElementType(),
+ MemRefLayoutAttrInterface{}, *addressSpace)
+ : MemRefType::get(destTensorType.getShape(),
+ destTensorType.getElementType());
+ Value dest = Value();
+ if (auto empty = parallelInsert.getDest().getDefiningOp<tensor::EmptyOp>()) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(empty);
+ dest = rewriter.create<memref::AllocOp>(loc, allocType,
+ empty.getDynamicSizes());
+ } else {
+ dest = rewriter.create<bufferization::ToMemrefOp>(loc, allocType,
+ parallelInsert.getDest());
+ }
+ return rewriter.create<IREE::GPU::ShuffleTensorOp>(
+ loc, extractSlice.getType(), parallelInsert.getSource(),
+ parallelInsert.getOffsets(), parallelInsert.getSizes(),
+ parallelInsert.getStrides(), parallelInsert.getStaticOffsets(),
+ parallelInsert.getStaticSizes(), parallelInsert.getStaticStrides(), dest,
+ extractSlice.getOffsets(), extractSlice.getSizes(),
+ extractSlice.getStrides(), extractSlice.getStaticOffsets(),
+ extractSlice.getStaticSizes(), extractSlice.getStaticStrides());
+}
+
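+// Fuses |producer| into |consumer| at |slice| by linearizing the consumer's
+// induction variables, delinearizing the result into the producer's iteration
+// space, and inlining the producer body in place of the slice.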
+LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
+ scf::ForallOp producer,
+ scf::ForallOp consumer,
+ tensor::ExtractSliceOp slice,
+ std::optional<Attribute> addressSpace) {
+ if (producer->getNumResults() != 1) {
+ return failure();
+ }
+
+ if (failed(compareWorkerCountsAndTypes(producer, consumer))) {
+ return failure();
+ }
+
+ auto isAll = [](ArrayRef<OpFoldResult> array, int64_t cmp) {
+ return llvm::all_of(array, [cmp](OpFoldResult val) {
+ return isConstantIntValue(val, cmp);
+ });
+ };
+
+ if (!isAll(producer.getMixedStep(), 1) ||
+ !isAll(producer.getMixedLowerBound(), 0) ||
+ !isAll(consumer.getMixedStep(), 1) ||
+ !isAll(consumer.getMixedLowerBound(), 0)) {
+ return failure();
+ }
+
+ rewriter.setInsertionPoint(slice);
+
+ // Step 1. Compute the producer IDs in terms of the consumer IDs.
+
+ MLIRContext *context = rewriter.getContext();
+ Location loc = producer.getLoc();
+
+ AffineExpr d0, d1, d2;
+ bindDims(context, d0, d1, d2);
+ AffineExpr mulAdd = d0 * d1 + d2;
+ OpFoldResult linearId = rewriter.getIndexAttr(0);
+ for (auto [inductionVar, workerCount] :
+ llvm::zip_equal(getAsOpFoldResult(consumer.getInductionVars()),
+ consumer.getMixedUpperBound())) {
+ linearId = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, mulAdd, {linearId, workerCount, inductionVar});
+ }
+
+ Value linearThreadIdVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearId);
+ SmallVector<Value> ranges;
+ for (auto workerCount : producer.getStaticUpperBound()) {
+ ranges.push_back(rewriter.create<arith::ConstantIndexOp>(loc, workerCount));
+ }
+ ValueRange newIds = rewriter
+ .create<affine::AffineDelinearizeIndexOp>(
+ loc, linearThreadIdVal, ranges)
+ .getResults();
+
+ // Step 2. Inline the region of the producer.
+ SmallVector<Value> bbArgReplacements(newIds);
+ bbArgReplacements.append(producer.getOutputs().begin(),
+ producer.getOutputs().end());
+
+ scf::InParallelOp terminator = producer.getTerminator();
+ rewriter.inlineBlockBefore(producer.getBody(), slice, bbArgReplacements);
+
+ rewriter.setInsertionPointAfter(terminator);
+ auto parallelInsert =
+ cast<tensor::ParallelInsertSliceOp>(*terminator.getYieldingOps().begin());
+
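+ // Step 3. Create the shuffle op and replace the consumer slice with it, then
+ // erase the now-dead producer ops.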
+ Value replacementSlice =
+ getReplacementSlice(rewriter, loc, parallelInsert, slice, addressSpace);
+ rewriter.replaceAllUsesWith(slice, replacementSlice);
+
+ rewriter.eraseOp(parallelInsert);
+ rewriter.eraseOp(slice);
+ rewriter.eraseOp(terminator);
+ rewriter.eraseOp(producer);
+ return success();
+}
+
+DiagnosedSilenceableFailure
+transform_dialect::FuseForallOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto producers = state.getPayloadOps(getProducer());
+ auto consumers = state.getPayloadOps(getConsumer());
+
+ int64_t numProducers = llvm::range_size(producers);
+ int64_t numConsumers = llvm::range_size(consumers);
+ if (numProducers != 1 || numConsumers != 1) {
+ return mlir::emitDefiniteFailure(state.getTopLevel(),
+ "More than one producer or consumer");
+ }
+
+ auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
+ auto consumer = dyn_cast<scf::ForallOp>(*consumers.begin());
+ if (!producer || !consumer) {
+ return mlir::emitDefiniteFailure(state.getTopLevel(),
+ "Non-forall producer or consumer");
+ }
+
+ if (!producer->hasOneUse()) {
+ return mlir::emitDefiniteFailure(state.getTopLevel(),
+ "non-single use producer");
+ }
+
+ auto sliceConsumer =
+ dyn_cast<tensor::ExtractSliceOp>(*producer->user_begin());
+ if (!sliceConsumer || sliceConsumer->getParentOp() != consumer) {
+ return mlir::emitDefiniteFailure(state.getTopLevel(),
+ "producer loop sole consumer is not an "
+ "extracted slice from the consumer loop");
+ }
+
+ if (failed(fuseForallIntoSlice(rewriter, producer, consumer, sliceConsumer,
+ getAddressSpace()))) {
+ return mlir::emitDefiniteFailure(state.getTopLevel(),
+ "failed to fuse forall ops");
+ }
+
+ results.set(getOperation()->getOpResult(0), {consumer});
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::FuseForallOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getProducer(), effects);
+ transform::consumesHandle(getConsumer(), effects);
+ transform::producesHandle(getResult(), effects);
+ transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// GpuDistributeSharedMemoryCopyOp
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 13fd709..62f3000 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -307,6 +307,42 @@
}];
}
+def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall",
+ [FunctionalStyleTransformOpTrait,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Fuses a producer-consumer pair of scf.forall ops that share the same
+ mapping attributes and static trip counts. An allocation is created to
+ bridge the `tensor.parallel_insert_slice` of the producer with the
+ per-thread `tensor.extract_slice` of the consumer. If specified,
+ |address_space| is used for the intermediate allocation.
+
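+ For example (mirroring the lit test added in this change; the
+ `address_space` attribute is optional):
+
+ ```mlir
+ transform.iree.fuse_forall %producer into %consumer
+   {address_space = #gpu.address_space<workgroup>}
+   : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+ ```
+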
+ NOTE: This pattern implicitly REQUIRES that the resulting scf.forall
+ is capable of synchronizing all threads at the point of fusion (i.e.,
+ by inserting a barrier). This invalidates certain kinds of lowerings of
+ scf.forall ops, such as lowering them to sequential loops.
+
+ #### Return modes
+ Emits a definite failure if the producer or consumer is not an
+ scf.forall op, if the producer has more than one use, or if its sole
+ user is not a `tensor.extract_slice` inside the consumer loop.
+ }];
+
+ let arguments = (
+ ins TransformHandleTypeInterface:$producer,
+ TransformHandleTypeInterface:$consumer,
+ OptionalAttr<AnyAttr>:$address_space
+ );
+ let results = (outs TransformHandleTypeInterface:$result);
+
+ let assemblyFormat = [{
+ $producer `into` $consumer attr-dict
+ `:` functional-type(operands, results)
+ }];
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+}
+
def IREEApplyLoopIndependentCodeMotionOp : Op<Transform_Dialect, "iree.apply_licm",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 5b09cad..f985335 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -66,6 +66,7 @@
"tile_and_distribute_to_workgroups.mlir",
"transform_buffer_opt.mlir",
"transform_copy_operand.mlir",
+ "transform_fuse_forall.mlir",
"transform_match_partial_reduction.mlir",
"transform_ops_invalid.mlir",
"transpose_canonicalization.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index b284856..a6116a7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -62,6 +62,7 @@
"tile_and_distribute_to_workgroups.mlir"
"transform_buffer_opt.mlir"
"transform_copy_operand.mlir"
+ "transform_fuse_forall.mlir"
"transform_match_partial_reduction.mlir"
"transform_ops_invalid.mlir"
"transpose_canonicalization.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_fuse_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_fuse_forall.mlir
new file mode 100644
index 0000000..e4b991a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_fuse_forall.mlir
@@ -0,0 +1,119 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+module {
+ func.func @fuse_forall(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = tensor.empty() : tensor<128x128xf32>
+ %2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+ %4 = affine.apply #map(%arg5)
+ %extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+ %5 = linalg.copy ins(%extracted_slice : tensor<2x128xf32>) outs(%extracted_slice_0 : tensor<2x128xf32>) -> tensor<2x128xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %5 into %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<2x128xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+ %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+ %6 = affine.apply #map1(%arg5)
+ %7 = affine.apply #map1(%arg6)
+ %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+ return %3 : tensor<128x128xf32>
+ }
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+ %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.iree.fuse_forall %producer into %consumer {address_space = #gpu.address_space<workgroup>} : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK-LABEL: func @fuse_forall
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
+
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<128x128xf32, #gpu.address_space<workgroup>>
+// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
+// CHECK-DAG: %[[OUTID0:.+]] = affine.apply #[[$MAP]](%[[IDX]])
+// CHECK-DAG: %[[OUTID1:.+]] = affine.apply #[[$MAP]](%[[IDY]])
+// CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
+// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINEARID]] into (%c64, %c1) : index, index
+// CHECK: %[[INID0:.+]] = affine.apply #[[$MAP2]](%[[IDS]]#0)
+// CHECK: %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+// CHECK: %[[INSLICE1:.+]] = tensor.extract_slice %[[EMPTY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %[[COPY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1]
+// CHECK-SAME: to %[[ALLOC]] [%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1]
+// CHECK-SAME: : tensor<2x128xf32> -> memref<128x128xf32, #gpu.address_space<workgroup>> -> tensor<16x16xf32>
+// CHECK: %[[OUTSLICE:.+]] = tensor.extract_slice %[[INIT]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+// CHECK: %[[MM:.+]] = linalg.matmul ins(%[[SHUFFLE]], %[[SHUFFLE]] : tensor<16x16xf32>, tensor<16x16xf32>)
+// CHECK-SAME: outs(%[[OUTSLICE]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[MM]] into %[[INIT]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+// CHECK: }
+// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+module {
+ func.func @fuse_forall(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
+ %4 = affine.apply #map(%arg5)
+ %extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+ %5 = linalg.copy ins(%extracted_slice : tensor<2x128xf32>) outs(%extracted_slice_0 : tensor<2x128xf32>) -> tensor<2x128xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %5 into %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<2x128xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
+ %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
+ %6 = affine.apply #map1(%arg5)
+ %7 = affine.apply #map1(%arg6)
+ %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
+ return %3 : tensor<128x128xf32>
+ }
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+ %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK-LABEL: func @fuse_forall
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<128x128xf32>
+
+// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[ARG1]]) -> (tensor<128x128xf32>) {
+// CHECK: %[[ALLOC:.+]] = bufferization.to_memref %[[ARG1]]
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %{{.*}} to %[[ALLOC]]
+// CHECK-SAME: : tensor<2x128xf32> -> memref<128x128xf32> -> tensor<16x16xf32>
+// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}