[Codegen][GPU] Allow iree_gpu.tensor_barrier to take vectors (#17479)

This allows synchronizing on vectors as well as tensors with similar
semantics. In a typical lowering flow, this will represent the
read equivalent to a tensor barrier, in that a tensor barrier represents
a wait until all writes to a shared allocation has finished, while this
represents a wait until all threads have read the value they need from
that shared allocation.

Renames the operation to iree_gpu.value_barrier for clarity.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index b6dde8f..911a44d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -2580,7 +2580,7 @@
   %c0 = arith.constant 0 : index
   %alloc = bufferization.alloc_tensor() : tensor<2xf32>
   %tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
-  %barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32>
+  %barrier = iree_gpu.value_barrier %tmp : tensor<2xf32>
   %res = vector.transfer_read %barrier[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
   return %res : vector<2xf32>
 }
@@ -2601,7 +2601,7 @@
   %alloc = bufferization.alloc_tensor() : tensor<2xf32>
   %loop = scf.for %arg0 = %c0 to %c10 step %c1 iter_args(%init = %alloc) -> tensor<2xf32> {
     %tmp = vector.transfer_write %cst, %init[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
-    %barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32>
+    %barrier = iree_gpu.value_barrier %tmp : tensor<2xf32>
     scf.yield %barrier : tensor<2xf32>
   }
   %res = vector.transfer_read %loop[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
@@ -2614,3 +2614,23 @@
 //  CHECK-NEXT:     gpu.barrier
 //  CHECK-NEXT:   }
 //       CHECK:   vector.transfer_read %[[ALLOC]]
+
+// -----
+
+func.func @vector_barrier() -> vector<2xf32> {
+  %cst = arith.constant dense<0.0> : vector<2xf32>
+  %cst0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %alloc = bufferization.alloc_tensor() : tensor<2xf32>
+  %tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
+  %read = vector.transfer_read %tmp[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
+  %barrier = iree_gpu.value_barrier %read : vector<2xf32>
+  return %barrier : vector<2xf32>
+}
+
+// Verify that the dual-modes of `value_barrier` are adhered to.
+// CHECK-LABEL: func @vector_barrier()
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<2xf32>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[ALLOC]]
+//  CHECK-NEXT:   %[[RD:.+]] = vector.transfer_read %[[ALLOC]]
+//  CHECK-NEXT:   iree_gpu.value_barrier %[[RD]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
index 3950ccc..2e106e1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
@@ -237,12 +237,4 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// TensorBarrierOp
-//===----------------------------------------------------------------------===//
-
-MutableOperandRange TensorBarrierOp::getDpsInitsMutable() {
-  return getInputMutable();
-}
-
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 4c360a7..6b18e5c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -352,36 +352,39 @@
 }
 
 //===----------------------------------------------------------------------===//
-// TensorBarrierOp
+// ValueBarrierOp
 //===----------------------------------------------------------------------===//
 
-def IREEGPU_TensorBarrierOp : Op<IREEGPU_Dialect, "tensor_barrier", [
+def IREEGPU_ValueBarrierOp : Op<IREEGPU_Dialect, "value_barrier", [
     Pure,
-    DeclareOpInterfaceMethods<DestinationStyleOpInterface>,
     AllTypesMatch<["input", "result"]>,
     ]> {
   let summary = "Shuffles a private tensor across a shared allocation";
   let description = [{
-    This operation acts as a barrier on a tensor value. It takes a single
-    tensor operand and produces an equivalent tensor. This does not have copy
-    and/or data movement semantics and simply represents a barrier on all writes
-    to the input tensor.
+    This operation acts as a barrier on a value semantic SSA value (tensor or
+    vector). It takes a single operand and produces a value equivalent to the
+    input. This does not have copy and/or data movement semantics and simply
+    represents a barrier on all writes in the tensor case, and a barrier until
+    all threads acquire the input vector in the vector case.
 
     This operation is a no-op when not present in a parallel context. This
     operation is pure as it only requires synchronization for the value it
     produces.
   }];
 
-  let arguments = (ins AnyRankedTensor:$input);
-  let results = (outs AnyRankedTensor:$result);
+  let arguments = (ins AnyRankedTensorOrVector:$input);
+  let results = (outs AnyRankedTensorOrVector:$result);
 
   let assemblyFormat = [{
     $input attr-dict `:` type($result)
   }];
 
   let extraClassDeclaration = [{
-    RankedTensorType getInputType() {
-      return getInput().getType();
+    bool hasTensorSemantics() {
+      return isa<::mlir::RankedTensorType>(getInput().getType());
+    }
+    ::mlir::ShapedType getInputType() {
+      return ::llvm::cast<::mlir::ShapedType>(getInput().getType());
     }
   }];
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
index 08b2f17..e2744fb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -134,10 +134,21 @@
 // -----
 
 func.func @tensor_barrier(%input: tensor<?xf16>) -> tensor<?xf16> {
-  %out = iree_gpu.tensor_barrier %input : tensor<?xf16>
+  %out = iree_gpu.value_barrier %input : tensor<?xf16>
   return %out : tensor<?xf16>
 }
 
 // CHECK-LABEL: func @tensor_barrier
 //  CHECK-SAME:   %[[INPUT:[A-Za-z0-9]+]]: tensor<?xf16>
-//       CHECK:   iree_gpu.tensor_barrier %[[INPUT]] : tensor<?xf16>
+//       CHECK:   iree_gpu.value_barrier %[[INPUT]] : tensor<?xf16>
+
+// -----
+
+func.func @vector_barrier(%input: vector<8xf16>) -> vector<8xf16> {
+  %out = iree_gpu.value_barrier %input : vector<8xf16>
+  return %out : vector<8xf16>
+}
+
+// CHECK-LABEL: func @vector_barrier
+//  CHECK-SAME:   %[[INPUT:[A-Za-z0-9]+]]: vector<8xf16>
+//       CHECK:   iree_gpu.value_barrier %[[INPUT]] : vector<8xf16>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
index f8cb914..8dcabd1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -26,6 +26,15 @@
 }
 
 //===---------------------------------------------------------------------===//
+// ApplyLowerValueBarrierOp
+//===---------------------------------------------------------------------===//
+
+void transform_dialect::ApplyLowerValueBarrierOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  IREE::GPU::populateIREEGPULowerValueBarrierPatterns(patterns);
+}
+
+//===---------------------------------------------------------------------===//
 // ApplyUnrollMultiMmaOp
 //===---------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index 08e13ec..dc69083 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -14,6 +14,19 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyLowerValueBarrierOp : Op<Transform_Dialect,
+    "apply_patterns.iree.lower_value_barrier",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Populate patterns to convert value barriers on vectors into gpu.barrier ops.
+    Barriers on tensors are ignored.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyUnrollMultiMmaOp : Op<Transform_Dialect,
     "apply_patterns.iree.unroll_multi_mma",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
index 11c0c8e..f78f6d0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
@@ -18,6 +18,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "lower_vector_barrier.mlir",
             "transform_fuse_forall.mlir",
             "vectorize_multi_mma.mlir",
             "unroll_multi_mma.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
index 686ce93..f3e2e40 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "lower_vector_barrier.mlir"
     "transform_fuse_forall.mlir"
     "unroll_multi_mma.mlir"
     "vectorize_multi_mma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir
new file mode 100644
index 0000000..b8391a3
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+func.func @lower_value_barrier(%input: vector<4xf32>) -> vector<4xf32> {
+  %0 = iree_gpu.value_barrier %input : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.iree.lower_value_barrier
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @lower_value_barrier
+//  CHECK-SAME:   %[[INPUT:[A-Za-z0-9]+]]: vector<4xf32>
+//  CHECK-NEXT:   gpu.barrier
+//  CHECK-NEXT:   return %[[INPUT]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
index 1071d91..4dc7621 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
@@ -24,9 +24,9 @@
 
 /// Bufferization of iree_gpu.tensor_barrier. Always just bufferizes in place
 /// and replaces with a barrier.
-struct TensorBarrierOpBufferizationInterface
+struct ValueBarrierOpBufferizationInterface
     : public BufferizableOpInterface::ExternalModel<
-          TensorBarrierOpBufferizationInterface, IREE::GPU::TensorBarrierOp> {
+          ValueBarrierOpBufferizationInterface, IREE::GPU::ValueBarrierOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     // This op never needs to bufferize to a copy.
@@ -47,8 +47,11 @@
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
-    auto barrierOp = cast<IREE::GPU::TensorBarrierOp>(op);
+    auto barrierOp = cast<IREE::GPU::ValueBarrierOp>(op);
     assert(value == barrierOp.getResult() && "invalid value");
+    if (!barrierOp.hasTensorSemantics()) {
+      return failure();
+    }
     auto srcMemrefType = bufferization::getBufferType(barrierOp.getInput(),
                                                       options, invocationStack);
     if (failed(srcMemrefType))
@@ -58,7 +61,10 @@
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
-    auto barrierOp = cast<IREE::GPU::TensorBarrierOp>(op);
+    auto barrierOp = cast<IREE::GPU::ValueBarrierOp>(op);
+    if (!barrierOp.hasTensorSemantics()) {
+      return failure();
+    }
     FailureOr<Value> buffer =
         getBuffer(rewriter, barrierOp.getInput(), options);
     if (failed(buffer)) {
@@ -78,8 +84,8 @@
 void registerIREEGPUBufferizationInterfaces(DialectRegistry &registry) {
   registry.addExtension(
       +[](MLIRContext *context, IREE::GPU::IREEGPUDialect *dialect) {
-        IREE::GPU::TensorBarrierOp::attachInterface<
-            TensorBarrierOpBufferizationInterface>(*context);
+        IREE::GPU::ValueBarrierOp::attachInterface<
+            ValueBarrierOpBufferizationInterface>(*context);
       });
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index a01ea45..5e982a8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -11,8 +11,6 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -406,4 +404,28 @@
   patterns.add<VectorizeStaticMultiMmaOpPattern>(patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// VectorBarrierOp Lowering
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct LowerValueBarrierPattern
+    : public OpRewritePattern<IREE::GPU::ValueBarrierOp> {
+  using OpRewritePattern<IREE::GPU::ValueBarrierOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(IREE::GPU::ValueBarrierOp barrier,
+                                PatternRewriter &rewriter) const override {
+    if (barrier.hasTensorSemantics()) {
+      return failure();
+    }
+    rewriter.create<gpu::BarrierOp>(barrier.getLoc());
+    rewriter.replaceOp(barrier, barrier.getInput());
+    return success();
+  }
+};
+} // namespace
+
+void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns) {
+  patterns.add<LowerValueBarrierPattern>(patterns.getContext());
+}
+
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index 56d1589..0e2afa3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -39,6 +39,8 @@
 
 void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns);
 
+void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
+
 } // namespace mlir::iree_compiler::IREE::GPU
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_