[Codegen] Add iree_gpu.shuffle_tensor op (#17257)

This adds an operation for shuffling tensor data when fusing two parallel
loops with tensor results. The `iree_gpu.shuffle_tensor` op can be thought
of as a `tensor.insert_slice` of the source tensor into an intermediate
allocation, followed by a `tensor.extract_slice` from that same allocation.
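
For intuition, here is a hedged sketch of that decomposition at the tensor
level, using the shapes from the new `iree_gpu_ops.mlir` test; the
`bufferization.to_tensor` view of the allocation is purely illustrative,
since the barrier semantics of the shuffle are not representable here:

```mlir
// Illustrative only: treat the shared allocation as a tensor value.
%view = bufferization.to_tensor %alloc : memref<6x6xf32>
// Each worker inserts its private source slice into the intermediate...
%full = tensor.insert_slice %src into %view[0, 0] [2, 3] [1, 1]
    : tensor<2x3xf32> into tensor<6x6xf32>
// ...and reads back a (generally different) result slice afterwards.
%result = tensor.extract_slice %full[0, 0] [3, 2] [1, 1]
    : tensor<6x6xf32> to tensor<3x2xf32>
```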

This could be broken down directly into a sequence of vector transfer ops
and a `memref.alloc`; however, having a dedicated operation allows a more
gradual lowering and better control when considering future transformations
like pipelining, by not having to jump straight to memrefs + barriers.
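
For reference, a hypothetical sketch of that direct lowering is shown below;
the workgroup address space, vector shapes, and SSA names are assumptions
for illustration only:

```mlir
// Hypothetical direct lowering: write the source slice, synchronize, then
// read the result slice. %src_vec, %c0, and %pad are assumed defined.
%alloc = memref.alloc() : memref<6x6xf32, #gpu.address_space<workgroup>>
vector.transfer_write %src_vec, %alloc[%c0, %c0]
    : vector<2x3xf32>, memref<6x6xf32, #gpu.address_space<workgroup>>
gpu.barrier
%res_vec = vector.transfer_read %alloc[%c0, %c0], %pad
    : memref<6x6xf32, #gpu.address_space<workgroup>>, vector<3x2xf32>
```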

This just adds the op structure. Patterns for lowering and generating
this operation will be added in subsequent patches.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel
index 558c806..4946502 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel
@@ -26,11 +26,13 @@
             "IREEGPUAttrs.td",
             "IREEGPUDialect.td",
             "IREEGPUInterfaces.td",
+            "IREEGPUOps.td",
         ],
         include = ["*.td"],
     ),
     deps = [
         "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:SideEffectInterfacesTdFiles",
     ],
 )
 
@@ -40,11 +42,13 @@
         "IREEGPUAttrs.cpp",
         "IREEGPUDialect.cpp",
         "IREEGPUInterfaces.cpp",
+        "IREEGPUOps.cpp",
     ],
     hdrs = [
         "IREEGPUAttrs.h",
         "IREEGPUDialect.h",
         "IREEGPUInterfaces.h",
+        "IREEGPUOps.h",
     ],
     textual_hdrs = [
         "IREEGPUAttrs.cpp.inc",
@@ -53,11 +57,14 @@
         "IREEGPUDialect.h.inc",
         "IREEGPUInterfaces.cpp.inc",
         "IREEGPUInterfaces.h.inc",
+        "IREEGPUOps.cpp.inc",
+        "IREEGPUOps.h.inc",
     ],
     deps = [
         ":IREEGPUAttrs",
         ":IREEGPUDialectGen",
         ":IREEGPUInterfaces",
+        ":IREEGPUOpsGen",
         "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
         "//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
         "@llvm-project//llvm:Support",
@@ -65,7 +72,9 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:SideEffectInterfaces",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:VectorDialect",
     ],
 )
@@ -128,3 +137,22 @@
     td_file = "IREEGPUInterfaces.td",
     deps = [":td_files"],
 )
+
+iree_gentbl_cc_library(
+    name = "IREEGPUOpsGen",
+    tbl_outs = [
+        (
+            ["--gen-op-decls"],
+            "IREEGPUOps.h.inc",
+        ),
+        (
+            ["--gen-op-defs"],
+            "IREEGPUOps.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "IREEGPUOps.td",
+    deps = [
+        ":td_files",
+    ],
+)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt
index ee53a41..0a98f44 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt
@@ -17,6 +17,7 @@
     "IREEGPUAttrs.h"
     "IREEGPUDialect.h"
     "IREEGPUInterfaces.h"
+    "IREEGPUOps.h"
   TEXTUAL_HDRS
     "IREEGPUAttrs.cpp.inc"
     "IREEGPUAttrs.h.inc"
@@ -24,20 +25,26 @@
     "IREEGPUDialect.h.inc"
     "IREEGPUInterfaces.cpp.inc"
     "IREEGPUInterfaces.h.inc"
+    "IREEGPUOps.cpp.inc"
+    "IREEGPUOps.h.inc"
   SRCS
     "IREEGPUAttrs.cpp"
     "IREEGPUDialect.cpp"
     "IREEGPUInterfaces.cpp"
+    "IREEGPUOps.cpp"
   DEPS
     ::IREEGPUAttrs
     ::IREEGPUDialectGen
     ::IREEGPUInterfaces
+    ::IREEGPUOpsGen
     IREEVectorExtDialect
     LLVMSupport
     MLIRIR
     MLIRLinalgDialect
     MLIRParser
+    MLIRSideEffectInterfaces
     MLIRSupport
+    MLIRTensorDialect
     MLIRVectorDialect
     iree::compiler::Codegen::Utils::VectorOpUtils
   PUBLIC
@@ -75,4 +82,14 @@
     --gen-attr-interface-defs IREEGPUInterfaces.cpp.inc
 )
 
+iree_tablegen_library(
+  NAME
+    IREEGPUOpsGen
+  TD_FILE
+    "IREEGPUOps.td"
+  OUTS
+    --gen-op-decls IREEGPUOps.h.inc
+    --gen-op-defs IREEGPUOps.cpp.inc
+)
+
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.cpp
index 2a4f146..3d22ae1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.cpp
@@ -7,9 +7,17 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
 
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.cpp.inc"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 
 namespace mlir::iree_compiler::IREE::GPU {
 
-void IREEGPUDialect::initialize() { registerAttributes(); }
+void IREEGPUDialect::initialize() {
+  registerAttributes();
+
+  addOperations<
+#define GET_OP_LIST
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp.inc"
+      >();
+}
 
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.td
index 924af7f..281e27e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.td
@@ -21,12 +21,10 @@
     A dialect representing attributes used by GPU focused IREE code generation.
   }];
   let description = [{
-    This dialect is primarily meant to hold attributes that carry additional
-    target specific information expanded based on executable target information.
-
-    This information is only used by codegen to normalize the higher level
-    target details across backends and devices. Late lowerings to SPIR-V/LLVM
-    still rely on the information designed for those targets.
+    This dialect provides operations and attributes to aid in code generation
+    for GPU targets. The functionality in this dialect can be hardware specific,
+    but is intended to be independent of the lowering target. Late lowerings to
+    SPIR-V/LLVM are handled separately.
   }];
   let useDefaultAttributePrinterParser = 1;
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
new file mode 100644
index 0000000..f80f304
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
@@ -0,0 +1,63 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Support/LLVM.h"
+
+// clang-format off
+#define GET_OP_CLASSES
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp.inc" // IWYU pragma: keep
+// clang-format on
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+//===----------------------------------------------------------------------===//
+// ShuffleTensorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ShuffleTensorOp::verify() {
+  // Get the equivalent tensor type for the alloc to verify against.
+  MemRefType allocType = getSharedAllocType();
+  Type allocElementType = allocType.getElementType();
+  RankedTensorType allocTensorType =
+      RankedTensorType::get(allocType.getShape(), allocElementType);
+
+  // Verify source type against inferred type. Slice insertion and extraction
+  // use the same verification logic.
+  RankedTensorType expectedType = tensor::ExtractSliceOp::inferResultType(
+      allocTensorType, getMixedSourceOffsets(), getMixedSourceSizes(),
+      getMixedSourceStrides());
+  SliceVerificationResult result =
+      isRankReducedType(expectedType, getSourceType());
+  if (result != SliceVerificationResult::Success) {
+    return emitError("Invalid source slice type");
+  }
+
+  // Do the same for the resulting tensor type
+  expectedType = tensor::ExtractSliceOp::inferResultType(
+      allocTensorType, getMixedResultOffsets(), getMixedResultSizes(),
+      getMixedResultStrides());
+  result = isRankReducedType(expectedType, getType());
+  if (result != SliceVerificationResult::Success) {
+    return emitError("Invalid result slice type");
+  }
+
+  if (allocElementType != getSourceType().getElementType() ||
+      allocElementType != getType().getElementType()) {
+    return emitError(
+        "Element type mismatch between source, allocation, and result");
+  }
+
+  // TODO: Verification of the allocation size in the static case.
+  return success();
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h
new file mode 100644
index 0000000..0fa26ce
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h
@@ -0,0 +1,23 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_DIALECT_IREEGPUOPS_H_
+#define IREE_COMPILER_CODEGEN_DIALECT_IREEGPUOPS_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+// clang-format off
+#define GET_OP_CLASSES
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h.inc" // IWYU pragma: export
+// clang-format on
+
+#endif // #ifndef IREE_COMPILER_CODEGEN_DIALECT_IREEGPUOPS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
new file mode 100644
index 0000000..32a2a3b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -0,0 +1,177 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_CODEGEN_DIALECT_IREEGPUOPS
+#define IREE_CODEGEN_DIALECT_IREEGPUOPS
+
+include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShuffleTensorOp
+//===----------------------------------------------------------------------===//
+
+def IREEGPU_ShuffleTensorOp : Op<IREEGPU_Dialect, "shuffle_tensor", [
+    Pure,
+    AttrSizedOperandSegments
+    ]> {
+  let summary = "Shuffles a private tensor across a shared allocation";
+  let description = [{
+    This op is designed to represent a shuffle of private tensor data
+    collectively held across a set of workers. This operation naturally arises
+    when combining the regions of producer-consumer `scf.forall` operations
+    that share a mapping type and worker count.
+
+    For example, consider the following pair of parallel loops.
+    ```mlir
+      %0 = scf.forall (%idy, %idx) in (2, 32) shared_outs(%init = %empty) -> (tensor<4x128xf32>) {
+        %in = ...
+        %2 = affine.apply #affine_map<(d0) -> (d0 * 2)> (%idy)
+        %3 = affine.apply #affine_map<(d0) -> (d0 * 4)> (%idx)
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %in into %init[%2, %3] [2, 4] [1, 1]
+            : tensor<2x4xf32> into tensor<4x128xf32>
+        }
+      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+      %1 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<128x128xf32>) {
+        %4 = affine.apply #affine_map<(d0) -> (d0 * 16)> (%idx)
+        %extracted_slice = tensor.extract_slice %0[0, %4] [4, 16] [1, 1]
+          : tensor<4x128xf32> to tensor<4x16xf32>
+        ...
+      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+    ```
+
+    Because these loops share the same worker type and total count, the bodies
+    of the two loops can be merged, with a barrier and a shuffle inserted at
+    the current boundary between them.
+
+    ```mlir
+      %alloc = bufferization.to_memref %empty
+      %0 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<4x128xf32>) {
+        %ids = affine.delinearize_index %idy * 8 + %idx to (2, 32) : index
+        %in = ...
+        %2 = affine.apply #affine_map<(d0) -> (d0 * 2)> (%ids#0)
+        %3 = affine.apply #affine_map<(d0) -> (d0 * 4)> (%ids#1)
+        %4 = affine.apply #affine_map<(d0) -> (d0 * 16)> (%idx)
+        %slice = iree_gpu.shuffle_tensor %in[%2, %3] [2, 4] [1, 1] to %alloc[0, %4] [4, 16] [1, 1]
+          : tensor<2x4xf32> -> memref<4x128xf32> -> tensor<4x16xf32>
+        ...
+      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+    ```
+
+    A shuffle can be lowered to a shared allocation with a write of the source
+    slice, a barrier, and a read of the result slice. Note that because workers
+    must avoid conflicting writes and must all execute the barrier, any lowering
+    of the enclosing `scf.forall` to serial loops is invalid. In other words,
+    the lowering/hardware must provide the number of workers requested by the loop.
+
+    This op takes an input |source| tensor representing the slice held by this
+    worker before the shuffle and an intermediate memref |shared_alloc| that all
+    workers insert into, and yields a |result| slice of the intermediate memref
+    read by this worker after the shuffle completes.
+
+    It is undefined behavior if the source or result tensor slices are out of
+    bounds of the intermediate allocation.
+
+    Motivation and Intended Use Cases:
+
+    This op is primarily generated when fusing parallel loops with tensor
+    results. It helps make lowerings more progressive and flexible in the
+    following ways:
+      - Rather than lowering straight to vector ops for the shuffle's
+        reads/writes, this allows separating the vectorization of shared
+        memory accesses from earlier tiling steps.
+      - Lowering directly to an alloc + reads and writes breaks the dependency
+        chain, making transformations like barrier placement and pipelining
+        potentially more difficult.
+      - Allows the option of non-vector-based lowering paths.
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    Variadic<Index>:$source_offsets,
+    Variadic<Index>:$source_sizes,
+    Variadic<Index>:$source_strides,
+    DenseI64ArrayAttr:$static_source_offsets,
+    DenseI64ArrayAttr:$static_source_sizes,
+    DenseI64ArrayAttr:$static_source_strides,
+    AnyMemRef:$shared_alloc,
+    Variadic<Index>:$result_offsets,
+    Variadic<Index>:$result_sizes,
+    Variadic<Index>:$result_strides,
+    DenseI64ArrayAttr:$static_result_offsets,
+    DenseI64ArrayAttr:$static_result_sizes,
+    DenseI64ArrayAttr:$static_result_strides
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+
+  let assemblyFormat = [{
+    $source ``
+    custom<DynamicIndexList>($source_offsets, $static_source_offsets)
+    custom<DynamicIndexList>($source_sizes, $static_source_sizes)
+    custom<DynamicIndexList>($source_strides, $static_source_strides)
+    `to` $shared_alloc
+    custom<DynamicIndexList>($result_offsets, $static_result_offsets)
+    custom<DynamicIndexList>($result_sizes, $static_result_sizes)
+    custom<DynamicIndexList>($result_strides, $static_result_strides)
+    attr-dict `:` type($source) `->` type($shared_alloc) `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    RankedTensorType getSourceType() {
+      return getSource().getType();
+    }
+
+    MemRefType getSharedAllocType() {
+      return getSharedAlloc().getType();
+    }
+
+    // Because we have two sets of offsets/sizes/strides, we cannot use the
+    // view-like interface boilerplate and instead redefine the getters here.
+
+    // Source slice view-like getters.
+    ::llvm::SmallVector<::mlir::OpFoldResult, 4> getMixedSourceOffsets() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticSourceOffsets(),
+                                    getSourceOffsets(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult, 4> getMixedSourceSizes() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticSourceSizes(),
+                                    getSourceSizes(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult, 4> getMixedSourceStrides() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticSourceStrides(),
+                                    getSourceStrides(), b);
+    }
+
+    // Result slice view-like getters.
+    ::llvm::SmallVector<::mlir::OpFoldResult, 4> getMixedResultOffsets() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticResultOffsets(),
+                                    getResultOffsets(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult, 4> getMixedResultSizes() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticResultSizes(),
+                                    getResultSizes(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult, 4> getMixedResultStrides() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticResultStrides(),
+                                    getResultStrides(), b);
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+#endif // IREE_CODEGEN_DIALECT_IREEGPUOPS
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel
index 9e35ca1..93cd134 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel
@@ -18,6 +18,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "iree_gpu_ops.mlir",
             "mma_attrs.mlir",
         ],
         include = ["*.mlir"],
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/CMakeLists.txt
index 45f7830..f760a6b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "iree_gpu_ops.mlir"
     "mma_attrs.mlir"
   TOOLS
     FileCheck
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
new file mode 100644
index 0000000..af96e6e
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -0,0 +1,32 @@
+// RUN: iree-opt %s --split-input-file | FileCheck %s
+
+func.func @shuffle_tensor(%init: memref<6x6xf32>, %arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
+  %0 = iree_gpu.shuffle_tensor %arg0[0, 0] [2, 3] [1, 1] to %init[0, 0] [3, 2] [1, 1] : tensor<2x3xf32> -> memref<6x6xf32> -> tensor<3x2xf32>
+  return %0 : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: func @shuffle_tensor
+//       CHECK:   iree_gpu.shuffle_tensor %arg1[0, 0] [2, 3] [1, 1] to
+//  CHECK-SAME:     %arg0 [0, 0] [3, 2] [1, 1] : tensor<2x3xf32> -> memref<6x6xf32> -> tensor<3x2xf32>
+
+// -----
+
+func.func @rank_reducing_shuffle_tensor(%init: memref<1x6x6xf32>, %arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
+  %0 = iree_gpu.shuffle_tensor %arg0[0, 0, 0] [1, 2, 3] [1, 1, 1] to %init[0, 0, 0] [1, 3, 2] [1, 1, 1] : tensor<2x3xf32> -> memref<1x6x6xf32> -> tensor<3x2xf32>
+  return %0 : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: func @rank_reducing_shuffle_tensor
+//       CHECK:   iree_gpu.shuffle_tensor %arg1[0, 0, 0] [1, 2, 3] [1, 1, 1] to
+//  CHECK-SAME:     %arg0 [0, 0, 0] [1, 3, 2] [1, 1, 1] : tensor<2x3xf32> -> memref<1x6x6xf32> -> tensor<3x2xf32>
+
+// -----
+
+func.func @dynamic_alloc_shuffle_tensor(%init: memref<?x?xf32>, %arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
+  %0 = iree_gpu.shuffle_tensor %arg0[0, 0] [2, 3] [1, 1] to %init[0, 0] [3, 2] [1, 1] : tensor<2x3xf32> -> memref<?x?xf32> -> tensor<3x2xf32>
+  return %0 : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: func @dynamic_alloc_shuffle_tensor
+//       CHECK:   iree_gpu.shuffle_tensor %arg1[0, 0] [2, 3] [1, 1] to
+//  CHECK-SAME:     %arg0 [0, 0] [3, 2] [1, 1] : tensor<2x3xf32> -> memref<?x?xf32> -> tensor<3x2xf32>