[Codegen] Add pattern for lowering iree_gpu.shuffle_tensor (#17269)

This is a direct lowering that relies on bufferization to produce the
correct allocation and inserts a barrier between the slices. This is
an experimental lowering that plays fast and loose with barrier placement
(as does all barrier usage within the compiler).
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 471d40b..5f2726b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -51,8 +51,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -177,6 +175,57 @@
 }
 
 //===---------------------------------------------------------------------===//
+// ApplyLowerShuffleTensorPatternsOp
+//===---------------------------------------------------------------------===//
+
+namespace {
+struct LowerShuffleTensor
+    : public OpRewritePattern<IREE::GPU::ShuffleTensorOp> {
+  using OpRewritePattern<IREE::GPU::ShuffleTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(IREE::GPU::ShuffleTensorOp shuffleOp,
+                                PatternRewriter &rewriter) const final {
+    Location loc = shuffleOp.getLoc();
+
+    MemRefType allocType = shuffleOp.getSharedAllocType();
+    auto tensorType =
+        RankedTensorType::get(allocType.getShape(), allocType.getElementType());
+    Value tensorAlloc = rewriter.create<bufferization::ToTensorOp>(
+        loc, tensorType, shuffleOp.getSharedAlloc(), /*restrict=*/true,
+        /*writeable=*/true);
+
+    // Step 1. Insert the source slice into the intermediate tensor.
+    SmallVector<OpFoldResult, 4> sourceOffsets =
+        shuffleOp.getMixedSourceOffsets();
+    SmallVector<OpFoldResult, 4> sourceSizes = shuffleOp.getMixedSourceSizes();
+    SmallVector<OpFoldResult, 4> sourceStrides =
+        shuffleOp.getMixedSourceStrides();
+    Value insertedSlice = rewriter.create<tensor::InsertSliceOp>(
+        loc, shuffleOp.getSource(), tensorAlloc, sourceOffsets, sourceSizes,
+        sourceStrides);
+
+    // Step 2. Synchronize the workers.
+    rewriter.create<gpu::BarrierOp>(loc);
+
+    // Step 3. Extract the result slice.
+    SmallVector<OpFoldResult, 4> resultOffsets =
+        shuffleOp.getMixedResultOffsets();
+    SmallVector<OpFoldResult, 4> resultSizes = shuffleOp.getMixedResultSizes();
+    SmallVector<OpFoldResult, 4> resultStrides =
+        shuffleOp.getMixedResultStrides();
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        shuffleOp, shuffleOp.getType(), insertedSlice, resultOffsets,
+        resultSizes, resultStrides);
+    return success();
+  }
+};
+} // namespace
+
+void transform_dialect::ApplyLowerShuffleTensorPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<LowerShuffleTensor>(patterns.getContext());
+}
+
+//===---------------------------------------------------------------------===//
 // ApplyUnrollVectorsGpuMmaSyncPatternsOp
 //===---------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 62f3000..f2a795b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -114,6 +114,19 @@
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyLowerShuffleTensorPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.iree.lower_shuffle_tensor",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Populate patterns that lowers iree_gpu.shuffle_tensor ops to allocations
+    and copies.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyPrepareVectorToMMAPatternsOp : Op<Transform_Dialect,
     "apply_patterns.iree.prepare_vector_to_mma",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index f985335..bba1d56 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -67,6 +67,7 @@
             "transform_buffer_opt.mlir",
             "transform_copy_operand.mlir",
             "transform_fuse_forall.mlir",
+            "transform_lower_shuffle.mlir",
             "transform_match_partial_reduction.mlir",
             "transform_ops_invalid.mlir",
             "transpose_canonicalization.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index a6116a7..3c3bdff 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -63,6 +63,7 @@
     "transform_buffer_opt.mlir"
     "transform_copy_operand.mlir"
     "transform_fuse_forall.mlir"
+    "transform_lower_shuffle.mlir"
     "transform_match_partial_reduction.mlir"
     "transform_ops_invalid.mlir"
     "transpose_canonicalization.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_lower_shuffle.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_lower_shuffle.mlir
new file mode 100644
index 0000000..2413151
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_lower_shuffle.mlir
@@ -0,0 +1,59 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+func.func @shuffle_tensor(%init: memref<6x6xf32>, %arg0: tensor<2x3xf32>, %x: index) -> tensor<3x2xf32> {
+  %0 = iree_gpu.shuffle_tensor %arg0[%x, 0] [2, 3] [1, 1] to %init[0, %x] [3, 2] [1, 1] : tensor<2x3xf32> -> memref<6x6xf32> -> tensor<3x2xf32>
+  return %0 : tensor<3x2xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.iree.lower_shuffle_tensor
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @shuffle_tensor
+//  CHECK-SAME:   %[[INIT:[A-Za-z0-9]+]]: memref<6x6xf32>
+//  CHECK-SAME:   %[[ARG1:[A-Za-z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:   %[[X:[A-Za-z0-9]+]]: index
+
+//       CHECK:   %[[TENSOR:.+]] = bufferization.to_tensor %[[INIT]]
+//  CHECK-SAME:     restrict
+//  CHECK-SAME:     writable
+//       CHECK:   %[[IN:.+]] = tensor.insert_slice %[[ARG1]] into %[[TENSOR]][%[[X]], 0] [2, 3] [1, 1] : tensor<2x3xf32> into tensor<6x6xf32>
+//       CHECK:   gpu.barrier
+//       CHECK:   %[[OUT:.+]] = tensor.extract_slice %[[IN]][0, %[[X]]] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
+//       CHECK:   return %[[OUT]] : tensor<3x2xf32>
+
+// -----
+
+func.func @rank_reducing_shuffle_tensor(%init: memref<1x6x6xf32>, %arg0: tensor<2x3xf32>, %x: index, %y: index) -> tensor<3x2xf32> {
+  %0 = iree_gpu.shuffle_tensor %arg0[0, %x, %y] [1, 2, 3] [1, 1, 1] to %init[0, %y, %x] [1, 3, 2] [1, 1, 1] : tensor<2x3xf32> -> memref<1x6x6xf32> -> tensor<3x2xf32>
+  return %0 : tensor<3x2xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.iree.lower_shuffle_tensor
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @rank_reducing_shuffle_tensor
+//  CHECK-SAME:   %[[INIT:[A-Za-z0-9]+]]: memref<1x6x6xf32>
+//  CHECK-SAME:   %[[ARG1:[A-Za-z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:   %[[X:[A-Za-z0-9]+]]: index
+//  CHECK-SAME:   %[[Y:[A-Za-z0-9]+]]: index
+
+//       CHECK:   %[[TENSOR:.+]] = bufferization.to_tensor %[[INIT]]
+//  CHECK-SAME:     restrict
+//  CHECK-SAME:     writable
+//       CHECK:   %[[IN:.+]] = tensor.insert_slice %[[ARG1]] into %[[TENSOR]][0, %[[X]], %[[Y]]] [1, 2, 3] [1, 1, 1] : tensor<2x3xf32> into tensor<1x6x6xf32>
+//       CHECK:   gpu.barrier
+//       CHECK:   tensor.extract_slice %[[IN]][0, %[[Y]], %[[X]]] [1, 3, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<3x2xf32>