[Codegen][GPU] Add iree_gpu.tensor_barrier op (#17478)

Because tensors have value semantics, operations on tensors (and
vectors) have the potential to freely change places with side effecting
ops like `gpu.barrier`. This adds an `iree_gpu.tensor_barrier` operation
to keep barriers as a part of the SSA chain. This has the added benefit
of automatic DCE of such barriers, and improved analyzability. In the
future we may want to allow this operation to take multiple tensors to
wait on simultaneously.
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index 3283e27..f8a64c6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -72,6 +73,7 @@
                 arith::ArithDialect,
                 bufferization::BufferizationDialect,
                 func::FuncDialect,
+                gpu::GPUDialect,
                 IREE::Flow::FlowDialect,
                 IREE::LinalgExt::IREELinalgExtDialect,
                 IREE::Util::UtilDialect,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index 82d14df..b6dde8f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -2571,3 +2571,46 @@
 //       CHECK:   %[[C64:.+]] = arith.constant 64 : index
 //       CHECK:   hal.interface.binding.subspan set(0) binding(0)
 //  CHECK-SAME:       memref<64xi4, strided<[1], offset: 128>
+
+// -----
+
+func.func @tensor_barrier() -> vector<2xf32> {
+  %cst = arith.constant dense<0.0> : vector<2xf32>
+  %cst0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %alloc = bufferization.alloc_tensor() : tensor<2xf32>
+  %tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
+  %barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32>
+  %res = vector.transfer_read %barrier[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
+  return %res : vector<2xf32>
+}
+// CHECK-LABEL: func @tensor_barrier()
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<2xf32>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[ALLOC]]
+//  CHECK-NEXT:   gpu.barrier
+//  CHECK-NEXT:   vector.transfer_read %[[ALLOC]]
+
+// -----
+
+func.func @tensor_barrier_in_loop() -> vector<2xf32> {
+  %cst = arith.constant dense<0.0> : vector<2xf32>
+  %cst0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %alloc = bufferization.alloc_tensor() : tensor<2xf32>
+  %loop = scf.for %arg0 = %c0 to %c10 step %c1 iter_args(%init = %alloc) -> tensor<2xf32> {
+    %tmp = vector.transfer_write %cst, %init[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
+    %barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32>
+    scf.yield %barrier : tensor<2xf32>
+  }
+  %res = vector.transfer_read %loop[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
+  return %res : vector<2xf32>
+}
+// CHECK-LABEL: func @tensor_barrier_in_loop()
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<2xf32>
+//       CHECK:   scf.for
+//  CHECK-NEXT:     vector.transfer_write %{{.*}}, %[[ALLOC]]
+//  CHECK-NEXT:     gpu.barrier
+//  CHECK-NEXT:   }
+//       CHECK:   vector.transfer_read %[[ALLOC]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
index 53d772a..a97a0c5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpImplementation.h"
 #include "mlir/Support/LLVM.h"
 
 // clang-format off
@@ -232,4 +231,12 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TensorBarrierOp
+//===----------------------------------------------------------------------===//
+
+MutableOperandRange TensorBarrierOp::getDpsInitsMutable() {
+  return getInputMutable();
+}
+
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 800af39..61f55af 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -11,6 +11,7 @@
 include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td"
 include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/OpBase.td"
@@ -328,6 +329,41 @@
 }
 
 //===----------------------------------------------------------------------===//
+// TensorBarrierOp
+//===----------------------------------------------------------------------===//
+
+def IREEGPU_TensorBarrierOp : Op<IREEGPU_Dialect, "tensor_barrier", [
+    Pure,
+    DeclareOpInterfaceMethods<DestinationStyleOpInterface>,
+    AllTypesMatch<["input", "result"]>,
+    ]> {
+  let summary = "Shuffles a private tensor across a shared allocation";
+  let description = [{
+    This operation acts as a barrier on a tensor value. It takes a single
+    tensor operand and produces an equivalent tensor. This does not have copy
+    and/or data movement semantics and simply represents a barrier on all writes
+    to the input tensor.
+
+    This operation is a no-op when not present in a parallel context. This
+    operation is pure as it only requires synchronization for the value it
+    produces.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input);
+  let results = (outs AnyRankedTensor:$result);
+
+  let assemblyFormat = [{
+    $input attr-dict `:` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    RankedTensorType getInputType() {
+      return getInput().getType();
+    }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
index bb8c03b..08b2f17 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -130,3 +130,14 @@
 //  CHECK-SAME:       iterator_types = []
 //  CHECK-SAME:       kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
 //  CHECK-SAME:     : vector<4xf16>, vector<4xf16> into vector<4xf32>
+
+// -----
+
+func.func @tensor_barrier(%input: tensor<?xf16>) -> tensor<?xf16> {
+  %out = iree_gpu.tensor_barrier %input : tensor<?xf16>
+  return %out : tensor<?xf16>
+}
+
+// CHECK-LABEL: func @tensor_barrier
+//  CHECK-SAME:   %[[INPUT:[A-Za-z0-9]+]]: tensor<?xf16>
+//       CHECK:   iree_gpu.tensor_barrier %[[INPUT]] : tensor<?xf16>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
index 1780dbb..46072ea 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -13,6 +13,27 @@
 )
 
 iree_compiler_cc_library(
+    name = "BufferizationInterfaces",
+    srcs = [
+        "BufferizationInterfaces.cpp",
+    ],
+    hdrs = [
+        "BufferizationInterfaces.h",
+    ],
+    deps = [
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+        "@llvm-project//mlir:BufferizationDialect",
+        "@llvm-project//mlir:BufferizationInterfaces",
+        "@llvm-project//mlir:BufferizationTransforms",
+        "@llvm-project//mlir:DialectUtils",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+iree_compiler_cc_library(
     name = "GPUTransforms",
     srcs = [
         "Transforms.cpp",
@@ -25,6 +46,7 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:Transforms",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
new file mode 100644
index 0000000..1071d91
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
@@ -0,0 +1,86 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.h"
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+
+using mlir::bufferization::AnalysisState;
+using mlir::bufferization::BufferizableOpInterface;
+using mlir::bufferization::BufferizationOptions;
+using mlir::bufferization::BufferRelation;
+using mlir::bufferization::replaceOpWithBufferizedValues;
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+/// Bufferization of iree_gpu.tensor_barrier. Always just bufferizes in place
+/// and replaces with a barrier.
+struct TensorBarrierOpBufferizationInterface
+    : public BufferizableOpInterface::ExternalModel<
+          TensorBarrierOpBufferizationInterface, IREE::GPU::TensorBarrierOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // This op never needs to bufferize to a copy.
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  bufferization::AliasingValueList
+  getAliasingValues(Operation *op, OpOperand &opOperand,
+                    const AnalysisState &state) const {
+    return {{op->getOpResult(0), BufferRelation::Equivalent}};
+  }
+
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                SmallVector<Value> &invocationStack) const {
+    auto barrierOp = cast<IREE::GPU::TensorBarrierOp>(op);
+    assert(value == barrierOp.getResult() && "invalid value");
+    auto srcMemrefType = bufferization::getBufferType(barrierOp.getInput(),
+                                                      options, invocationStack);
+    if (failed(srcMemrefType))
+      return failure();
+    return srcMemrefType;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto barrierOp = cast<IREE::GPU::TensorBarrierOp>(op);
+    FailureOr<Value> buffer =
+        getBuffer(rewriter, barrierOp.getInput(), options);
+    if (failed(buffer)) {
+      return failure();
+    }
+
+    rewriter.create<gpu::BarrierOp>(barrierOp.getLoc());
+
+    // This operation bufferizes in place
+    bufferization::replaceOpWithBufferizedValues(rewriter, op, *buffer);
+    return success();
+  }
+};
+
+} // namespace
+
+void registerIREEGPUBufferizationInterfaces(DialectRegistry &registry) {
+  registry.addExtension(
+      +[](MLIRContext *context, IREE::GPU::IREEGPUDialect *dialect) {
+        IREE::GPU::TensorBarrierOp::attachInterface<
+            TensorBarrierOpBufferizationInterface>(*context);
+      });
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.h
new file mode 100644
index 0000000..0d3623a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.h
@@ -0,0 +1,19 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_BUFFERIZATIONINTERFACES_H_
+#define IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_BUFFERIZATIONINTERFACES_H_
+
+#include "mlir/IR/Dialect.h"
+
+namespace mlir::iree_compiler {
+
+// Register all interfaces needed for bufferization.
+void registerIREEGPUBufferizationInterfaces(DialectRegistry &registry);
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_BUFFERIZATIONINTERFACES_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
index 07c8f91..9ab5e2e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -12,6 +12,24 @@
 
 iree_cc_library(
   NAME
+    BufferizationInterfaces
+  HDRS
+    "BufferizationInterfaces.h"
+  SRCS
+    "BufferizationInterfaces.cpp"
+  DEPS
+    MLIRBufferizationDialect
+    MLIRBufferizationTransforms
+    MLIRGPUDialect
+    MLIRIR
+    MLIRMemRefDialect
+    MLIRSupport
+    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
     GPUTransforms
   HDRS
     "Transforms.h"
@@ -21,6 +39,7 @@
     LLVMSupport
     MLIRArithDialect
     MLIRFuncDialect
+    MLIRGPUDialect
     MLIRIR
     MLIRTransformUtils
     MLIRTransforms
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
index d1af546..5c11c61 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
@@ -87,6 +87,7 @@
         "BufferizationInterfaces.h",
     ],
     deps = [
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:BufferizationInterfaces",
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index 614499d..0422f16 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
 
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
@@ -604,6 +605,7 @@
   vector::registerBufferizableOpInterfaceExternalModels(registry);
 
   // Register IREE operations.
+  registerIREEGPUBufferizationInterfaces(registry);
   registry.addExtension(
       +[](MLIRContext *ctx, IREE::Flow::FlowDialect *dialect) {
         // DispatchTensorLoadOp
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index 4367260..53e91f5 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -74,6 +74,7 @@
     MLIRTensorDialect
     MLIRTensorTransforms
     MLIRVectorTransforms
+    iree::compiler::Codegen::Dialect::GPU::Transforms::BufferizationInterfaces
     iree::compiler::Codegen::Utils
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR