[Codegen][GPU] Allow iree_gpu.barrier_region to take multiple operands/results (#18490)

The restriction to a single input and output was artificial as this op
simply represents synchronization on input and output values.
Additionally this removes the restriction on tensor/vector types, but
for the time being this op is still only used with those types.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
index 780fc09..0705895 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
@@ -11,6 +11,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -32,18 +33,18 @@
 
-// Build a BarrierRegionOp with an empty.
+// Build a BarrierRegionOp with an empty body block (one argument per input).
 void BarrierRegionOp::build(OpBuilder &b, OperationState &result,
-                            Type resultType, Value dest) {
-  result.addOperands(dest);
+                            TypeRange resultTypes, ValueRange inputs) {
+  result.addOperands(inputs);
   (void)result.addRegion();
-  result.addTypes(resultType);
+  result.addTypes(resultTypes);
+  SmallVector<Location> blockArgLocs(inputs.size(), result.location);
 
   Region *region = result.regions[0].get();
 
   // `builder.createBlock` changes the insertion point within the block. Create
   // a guard to reset the insertion point of the builder after it is destroyed.
   OpBuilder::InsertionGuard guard(b);
-  b.createBlock(region, region->end(), ArrayRef<Type>{dest.getType()},
-                ArrayRef<Location>{result.location});
+  b.createBlock(region, region->end(), inputs.getTypes(), blockArgLocs);
 }
 
 LogicalResult BarrierRegionOp::verify() { return success(); }
@@ -51,19 +52,26 @@
 LogicalResult BarrierRegionOp::verifyRegions() {
   auto &region = getRegion();
   Block &block = region.front();
-  if (block.getNumArguments() != 1) {
-    return emitError("expected the block to have a single argument");
+  if (block.getNumArguments() != getNumOperands()) {
+    return emitError(
+        "expected the block argument count to match operand count");
   }
 
-  if (block.getArgumentTypes()[0] != getDestType()) {
-    return emitError("expected block to have single argument type of")
-           << getDestType();
+  if (!llvm::all_of_zip(block.getArgumentTypes(), getOperandTypes(),
+                        [](Type a, Type b) { return a == b; })) {
+    return emitError("expected block argument types to match operand types");
   }
 
   // Ensure that the region yields an element of the right type.
   auto yieldOp = llvm::cast<GPU::YieldOp>(block.getTerminator());
-  if (yieldOp.getValue().getType() != getResult().getType()) {
-    return emitOpError("expected yield type to match result type");
+  if (yieldOp->getNumOperands() != getNumResults()) {
+    return emitOpError(
+        "expected body to yield same number of values as results");
+  }
+
+  if (!llvm::all_of_zip(yieldOp->getOperandTypes(), getResultTypes(),
+                        [](Type a, Type b) { return a == b; })) {
+    return emitError("expected yielded value types to match result types");
   }
 
   return success();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 95f53fca..f2d9586 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -28,8 +28,8 @@
     ]> {
   let summary = "Synchronizes uses of a shared tensor.";
   let description = [{
-    This op is designed to represent synchronization of workers on a
-    particular shared tensor. This operation naturally arises when combining
+    This op is designed to represent synchronization of workers on the operands
+    and results of the given region. This operation naturally arises when combining
     the regions of producer-consumer `scf.forall` operations that share a
     mapping type.
 
@@ -58,27 +58,26 @@
 
     ```mlir
       %0 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<4x128xf32>) {
-        %ids = affine.delinearize_index %idy * 8 + %idx to (2, 32) : index
-        %in = ...
-        %2 = affine.apply #affine_map<(d0) -> (d0 * 2)> (%ids#0)
-        %3 = affine.apply #affine_map<(d0) -> (d0 * 4)> (%ids#1)
-        %4 = affine.apply #affine_map<(d0) -> (d0 * 16)> (%idx)
         %alloc = bufferization.alloc_tensor {memory_space = #gpu.address_space<workgroup>}
           : tensor<4x128xf32>
-        %inserted_slice = tensor.insert_slice %in into %alloc[%2, %3] [2, 4] [1, 1]
-          : tensor<2x4xf32> to tensor<4x128xf32>
-        %slice = iree_gpu.barrier_region %inserted_slice {
+        %barrier = iree_gpu.barrier_region ins(%alloc : tensor<4x128xf32>) {
         ^bb0(%shared: tensor<4x128xf32>):
-          %slice = tensor.extract_slice %shared[0, %4] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
+          %ids = affine.delinearize_index %idy * 8 + %idx to (2, 32) : index
+          %in = ...
+          %2 = affine.apply #affine_map<(d0) -> (d0 * 2)> (%ids#0)
+          %3 = affine.apply #affine_map<(d0) -> (d0 * 4)> (%ids#1)
+          %inserted_slice = tensor.insert_slice %in into %shared[%2, %3] [2, 4] [1, 1]
+            : tensor<2x4xf32> to tensor<4x128xf32>
-          iree_gpu.yield %slice : tensor<4x16xf32>
-        } : tensor<4x128xf32> -> tensor<4x16xf32>
+          iree_gpu.yield %inserted_slice : tensor<4x128xf32>
+        } : tensor<4x128xf32>
+        %4 = affine.apply #affine_map<(d0) -> (d0 * 16)> (%idx)
+        %slice = tensor.extract_slice %barrier[0, %4] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
         ...
       } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     ```
 
-    A barrier_region can be lowered to two barriers, one on the |dest| input
-    operand, and a second one on the result. Note that the |dest| operand must
-    bufferize with memory space `#gpu.address_space<workgroup>`.
+    A barrier_region can be lowered to two barriers, one on the input operands
+    and a second one on the results.
 
-    Movtivation and Intended Use Cases:
+    Motivation and Intended Use Cases:
 
@@ -92,26 +91,20 @@
   }];
 
   let arguments = (ins
-    AnyRankedTensor:$dest
+    Variadic<AnyType>:$inputs
   );
   let regions = (region SizedRegion<1>:$region);
-  let results = (outs AnyRankedTensorOrVector:$result);
+  let results = (outs Variadic<AnyType>:$results);
 
   let assemblyFormat = [{
-    $dest $region attr-dict
-    `:` type($dest) `->` type($result)
+    (`ins` `(` $inputs^ `:` type($inputs) `)` )?
+    $region attr-dict `:` type($results)
   }];
 
   let builders = [
-    OpBuilder<(ins "Type":$result_type, "Value":$dest)>
+    OpBuilder<(ins "TypeRange":$result_types, "ValueRange":$inputs)>
   ];
 
-  let extraClassDeclaration = [{
-    RankedTensorType getDestType() {
-      return getDest().getType();
-    }
-  }];
-
   let skipDefaultBuilders = 1;
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
@@ -448,14 +441,16 @@
 def IREEGPU_YieldOp : Op<IREEGPU_Dialect, "yield", [
     Pure, ReturnLike, Terminator,
     HasParent<"::mlir::iree_compiler::IREE::GPU::BarrierRegionOp">]> {
-  let summary = "Yield a value from a region";
+  let summary = "Yield values from a region";
   let description = [{
-     This operation is used to yield a single value from a within a region.
+     This operation is used to yield values from within a region.
   }];
 
-  let arguments = (ins AnyType:$value);
-  let assemblyFormat = "$value attr-dict `:` type($value)";
+  let arguments = (ins Variadic<AnyType>:$values);
   let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+
+  let assemblyFormat =
+      [{  attr-dict ($values^ `:` type($values))? }];
 }
 
 #endif // IREE_CODEGEN_DIALECT_IREEGPUOPS
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
index e706645..b174f8f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -1,39 +1,53 @@
 // RUN: iree-opt %s --split-input-file | FileCheck %s
 
 func.func @barrier_region(%init: tensor<6x6xf32>) -> tensor<3x2xf32> {
-  %0 = iree_gpu.barrier_region %init {
+  %0 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
   ^bb0(%intermediate: tensor<6x6xf32>):
     %slice = tensor.extract_slice %intermediate[0, 0] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
     iree_gpu.yield %slice : tensor<3x2xf32>
-  } : tensor<6x6xf32> -> tensor<3x2xf32>
+  } : tensor<3x2xf32>
   return %0 : tensor<3x2xf32>
 }
 
 // CHECK-LABEL: func @barrier_region
 //  CHECK-SAME:   %[[INIT:[A-Za-z0-9]+]]: tensor<6x6xf32>
-//       CHECK:   iree_gpu.barrier_region %[[INIT]] {
+//       CHECK:   iree_gpu.barrier_region ins(%[[INIT]] : tensor<6x6xf32>) {
 //       CHECK:     ^bb0(%[[INTERMEDIATE:.+]]: tensor<6x6xf32>):
 //       CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][0, 0] [3, 2] [1, 1]
 //       CHECK:       iree_gpu.yield %[[SLICE]] : tensor<3x2xf32>
-//       CHECK:   } : tensor<6x6xf32> -> tensor<3x2xf32>
+//       CHECK:   } : tensor<3x2xf32>
 
 // -----
 
-func.func @reshape_barrier_region(%init: tensor<12x12xf32>) -> tensor<2x1x3x2xf32> {
-  %0 = iree_gpu.barrier_region %init {
+func.func @multi_result_barrier_region(%init: tensor<12x12xf32>) -> (tensor<2x1x3x2xf32>, index) {
+  %0:2 = iree_gpu.barrier_region ins(%init : tensor<12x12xf32>) {
   ^bb0(%intermediate: tensor<12x12xf32>):
     %expand = tensor.expand_shape %intermediate [[0, 1], [2, 3]] output_shape [4, 3, 3, 4] : tensor<12x12xf32> into tensor<4x3x3x4xf32>
     %slice = tensor.extract_slice %expand[0, 0, 0, 0] [2, 1, 3, 2] [1, 1, 1, 1] : tensor<4x3x3x4xf32> to tensor<2x1x3x2xf32>
-    iree_gpu.yield %slice : tensor<2x1x3x2xf32>
-  } : tensor<12x12xf32> -> tensor<2x1x3x2xf32>
-  return %0 : tensor<2x1x3x2xf32>
+    %c0 = arith.constant 0 : index
+    iree_gpu.yield %slice, %c0 : tensor<2x1x3x2xf32>, index
+  } : tensor<2x1x3x2xf32>, index
+  return %0#0, %0#1 : tensor<2x1x3x2xf32>, index
 }
 
-// CHECK-LABEL: func @reshape_barrier_region
-//       CHECK:   iree_gpu.barrier_region
-//       CHECK:       tensor.expand_shape
-//       CHECK:       tensor.extract_slice
-//       CHECK:   } : tensor<12x12xf32> -> tensor<2x1x3x2xf32>
+// CHECK-LABEL: func @multi_result_barrier_region
+//       CHECK:   %{{.*}}:2 = iree_gpu.barrier_region ins(%{{.*}} : tensor<12x12xf32>)
+//       CHECK:   } : tensor<2x1x3x2xf32>, index
+
+// -----
+
+func.func @multi_input_barrier_region(%x: index, %y: index) -> index {
+  %0 = iree_gpu.barrier_region ins(%x, %y : index, index) {
+  ^bb0(%ix: index, %iy: index):
+    %sum = arith.addi %ix, %iy : index
+    iree_gpu.yield %sum : index
+  } : index
+  return %0 : index
+}
+
+// CHECK-LABEL: func @multi_input_barrier_region
+//       CHECK:   %{{.*}} = iree_gpu.barrier_region ins(%{{.*}}, %{{.*}} : index, index)
+//       CHECK:   } : index
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
index b70768d..d1a3201 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
@@ -60,11 +60,11 @@
 //       CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1]
 //       CHECK:       scf.yield %[[INSERT]]
 
-//       CHECK:     %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
+//       CHECK:     %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
 //       CHECK:     ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
 //       CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
 //       CHECK:       iree_gpu.yield %[[SLICE]]
-//       CHECK:     } : tensor<128x128xf32> -> tensor<16x16xf32>
+//       CHECK:     } : tensor<16x16xf32>
 //       CHECK:     %[[OUTSLICE:.+]] = tensor.extract_slice %[[INIT]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
 //       CHECK:     %[[MM:.+]] = linalg.matmul ins(%[[SHUFFLE]], %[[SHUFFLE]] : tensor<16x16xf32>, tensor<16x16xf32>)
 //  CHECK-SAME:       outs(%[[OUTSLICE]] : tensor<16x16xf32>) -> tensor<16x16xf32>
@@ -124,8 +124,8 @@
 //       CHECK:   scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
 //       CHECK:     %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
 //       CHECK:       %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
-//       CHECK:     %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
-//       CHECK:       } : tensor<128x128xf32> -> tensor<16x16xf32>
+//       CHECK:     %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
+//       CHECK:       } : tensor<16x16xf32>
 //       CHECK:   } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
 
 // -----
@@ -180,12 +180,12 @@
 //       CHECK:   scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
 //       CHECK:     %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
 //       CHECK:       %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
-//       CHECK:     %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
+//       CHECK:     %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
 //       CHECK:     ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
 //       CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[INTERMEDIATE]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128]
 //       CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]][0, %{{.*}}, %{{.*}}] [1, 16, 16] [1, 1, 1] : tensor<2x64x128xf32> to tensor<16x16xf32>
 //       CHECK:       iree_gpu.yield %[[SLICE]]
-//       CHECK:       } : tensor<128x128xf32> -> tensor<16x16xf32>
+//       CHECK:       } : tensor<16x16xf32>
 //       CHECK:   } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
 
 // -----
@@ -253,8 +253,8 @@
 //       CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[IDX]], %[[IDS]]#0] [2, 128]
 //       CHECK:         scf.yield %[[INSERT]]
 
-//       CHECK:       %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
-//       CHECK:       } : tensor<128x128xf32> -> tensor<16x16xf32>
+//       CHECK:       %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
+//       CHECK:       } : tensor<16x16xf32>
 //       CHECK:     } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]}
 //       CHECK:   } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
 
@@ -308,7 +308,7 @@
 //       CHECK:     %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %c32{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
 //       CHECK:       %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%c32) : index
 //       CHECK:       scf.yield
-//       CHECK:     iree_gpu.barrier_region %[[LOOP]]
+//       CHECK:     iree_gpu.barrier_region ins(%[[LOOP]]
 //       CHECK:   } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
 
 // -----
@@ -363,5 +363,5 @@
 //       CHECK:     %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %[[PRODCOUNT]] step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
 //       CHECK:       %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%[[Z]], %[[Y]], %[[X]]) : index
 //       CHECK:       scf.yield
-//       CHECK:     iree_gpu.barrier_region %[[LOOP]]
+//       CHECK:     iree_gpu.barrier_region ins(%[[LOOP]]
 //       CHECK:   } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_lower_barrier_region.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_lower_barrier_region.mlir
index 8e901c3..9f2a811 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_lower_barrier_region.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_lower_barrier_region.mlir
@@ -1,11 +1,11 @@
 // RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
 
 func.func @barrier_region(%init: tensor<6x6xf32>, %x: index) -> tensor<3x2xf32> {
-  %0 = iree_gpu.barrier_region %init {
+  %0 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
   ^bb0(%intermediate: tensor<6x6xf32>):
     %slice = tensor.extract_slice %intermediate[0, %x] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
     iree_gpu.yield %slice : tensor<3x2xf32>
-  } : tensor<6x6xf32> -> tensor<3x2xf32>
+  } : tensor<3x2xf32>
   return %0 : tensor<3x2xf32>
 }
 
@@ -33,12 +33,12 @@
 func.func @reshape_barrier_region(%init: tensor<12x12xf32>) -> vector<2x1x3x2xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
-  %0 = iree_gpu.barrier_region %init {
+  %0 = iree_gpu.barrier_region ins(%init : tensor<12x12xf32>) {
   ^bb0(%intermediate: tensor<12x12xf32>):
     %expand = tensor.expand_shape %intermediate [[0, 1], [2, 3]] output_shape [4, 3, 3, 4] : tensor<12x12xf32> into tensor<4x3x3x4xf32>
     %read = vector.transfer_read %expand[%c0, %c0, %c0, %c0], %cst : tensor<4x3x3x4xf32>, vector<2x1x3x2xf32>
     iree_gpu.yield %read : vector<2x1x3x2xf32>
-  } : tensor<12x12xf32> -> vector<2x1x3x2xf32>
+  } : vector<2x1x3x2xf32>
   return %0 : vector<2x1x3x2xf32>
 }
 
@@ -59,3 +59,31 @@
 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[WRITE_BARRIER]]
 //       CHECK:   %[[READ:.+]] = vector.transfer_read %[[EXPAND]]
 //       CHECK:   %[[READ_BARRIER:.+]] = iree_gpu.value_barrier %[[READ]]
+
+// -----
+
+func.func @multi_barrier_region(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> (tensor<3xf32>, tensor<2xf32>) {
+  %0:2 = iree_gpu.barrier_region ins(%arg0, %arg1 : tensor<2xf32>, tensor<3xf32>) {
+  ^bb0(%in0: tensor<2xf32>, %in1: tensor<3xf32>):
+    iree_gpu.yield %in1, %in0 : tensor<3xf32>, tensor<2xf32>
+  } : tensor<3xf32>, tensor<2xf32>
+  return %0#0, %0#1 : tensor<3xf32>, tensor<2xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.iree.lower_barrier_region
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @multi_barrier_region
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:   %[[ARG1:[A-Za-z0-9]+]]: tensor<3xf32>
+
+//       CHECK:   %[[WB:.+]]:2 = iree_gpu.value_barrier %[[ARG0]], %[[ARG1]]
+//       CHECK:   %[[RB:.+]]:2 = iree_gpu.value_barrier %[[WB]]#1, %[[WB]]#0
+//       CHECK:   return %[[RB]]#0, %[[RB]]#1 : tensor<3xf32>, tensor<2xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_iree_gpu_ops.mlir
index d7f8457..07bd94b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_iree_gpu_ops.mlir
@@ -75,11 +75,11 @@
 // -----
 
 func.func @barrier_region(%init: tensor<6x6xf32>) -> tensor<3x2xf32> {
-  %0 = iree_gpu.barrier_region %init {
+  %0 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
   ^bb0(%intermediate: tensor<6x6xf32>):
     %slice = tensor.extract_slice %intermediate[0, 0] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
     iree_gpu.yield %slice : tensor<3x2xf32>
-  } : tensor<6x6xf32> -> tensor<3x2xf32>
+  } : tensor<3x2xf32>
   return %0 : tensor<3x2xf32>
 }
 
@@ -99,6 +99,38 @@
 //       CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][0, 0] [3, 2] [1, 1]
 //       CHECK:       %[[READ:.+]] = vector.transfer_read {{.*}} : tensor<3x2xf32>, vector<3x2xf32>
 //       CHECK:       iree_gpu.yield %[[READ]] : vector<3x2xf32>
-//       CHECK:   } : tensor<6x6xf32> -> vector<3x2xf32>
+//       CHECK:   } : vector<3x2xf32>
 //       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<3x2xf32>
 //       CHECK:   vector.transfer_write %[[SHUFFLE]], %[[EMPTY]]
+
+// -----
+
+func.func @multi_result_barrier_region(%init: tensor<6x6xf32>) -> (index, tensor<3x2xf32>) {
+  %0:2 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
+  ^bb0(%intermediate: tensor<6x6xf32>):
+    %slice = tensor.extract_slice %intermediate[0, 0] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
+    %c0 = arith.constant 0 : index
+    iree_gpu.yield %c0, %slice : index, tensor<3x2xf32>
+  } : index, tensor<3x2xf32>
+  return %0#0, %0#1 : index, tensor<3x2xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.iree.vectorize_iree_gpu
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @multi_result_barrier_region
+//       CHECK:   %[[SHUFFLE:.+]]:2 = iree_gpu.barrier_region
+//       CHECK:     ^bb0(%[[INTERMEDIATE:.+]]: tensor<6x6xf32>):
+//       CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][0, 0] [3, 2] [1, 1]
+//       CHECK:       %[[READ:.+]] = vector.transfer_read {{.*}} : tensor<3x2xf32>, vector<3x2xf32>
+//       CHECK:       iree_gpu.yield %c0, %[[READ]] : index, vector<3x2xf32>
+//       CHECK:   } : index, vector<3x2xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<3x2xf32>
+//       CHECK:   vector.transfer_write %[[SHUFFLE]]#1, %[[EMPTY]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 60360a7..54e840b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -13,6 +13,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
@@ -34,6 +35,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 
 #define DEBUG_TYPE "iree-codegen-gpu-transforms"
 
@@ -128,8 +130,8 @@
   }
   (*consumerChain.begin())
       ->replaceUsesOfWith(source, barrierRegionOp.getBody()->getArgument(0));
-  rewriter.replaceAllUsesExcept(extractSlice.getResult(), barrierRegionOp,
-                                terminator);
+  rewriter.replaceAllUsesExcept(extractSlice.getResult(),
+                                barrierRegionOp.getResult(0), terminator);
 }
 
 LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
@@ -975,20 +977,19 @@
 
-    // Step 1. Synchronize the workers on the shared dest.
+    // Step 1. Synchronize the workers on the shared inputs.
     auto writeBarrier = rewriter.create<IREE::GPU::ValueBarrierOp>(
-        loc, barrierRegionOp.getDest());
+        loc, barrierRegionOp.getInputs());
 
     // Step 2. Inline the barrier op region.
     auto terminator = barrierRegionOp.getBody()->getTerminator();
-    Value replacement = terminator->getOperand(0);
     rewriter.inlineBlockBefore(barrierRegionOp.getBody(), barrierRegionOp,
-                               {writeBarrier.getResult(0)});
-    rewriter.setInsertionPointAfterValue(replacement);
-    Value barrier;
+                               writeBarrier.getResults());
+    rewriter.setInsertionPoint(terminator);
 
-    // Step 3. Synchronize the result value.
+    // Step 3. Synchronize the result values.
-    barrier = rewriter.create<IREE::GPU::ValueBarrierOp>(loc, replacement)
-                  .getResult(0);
-    rewriter.replaceAllUsesWith(barrierRegionOp.getResult(), barrier);
+    auto barrier = rewriter.create<IREE::GPU::ValueBarrierOp>(
+        loc, terminator->getOperands());
+    rewriter.replaceAllUsesWith(barrierRegionOp.getResults(),
+                                barrier.getResults());
     rewriter.eraseOp(terminator);
     return success();
   }
@@ -1068,48 +1069,79 @@
 static LogicalResult
 vectorizeStaticBarrierRegionResult(RewriterBase &rewriter,
                                    IREE::GPU::BarrierRegionOp barrier) {
-  auto tensorResultType =
-      dyn_cast<RankedTensorType>(barrier.getResult().getType());
-  if (!tensorResultType || !tensorResultType.hasStaticShape()) {
+  SmallVector<Type> newResultTypes(barrier->getResultTypes());
+  llvm::SmallBitVector vectorizationTargets(newResultTypes.size(), false);
+  for (auto [i, type] : llvm::enumerate(newResultTypes)) {
+    auto tensorResultType = dyn_cast<RankedTensorType>(type);
+    if (!tensorResultType || !tensorResultType.hasStaticShape()) {
+      continue;
+    }
+    vectorizationTargets[i] = true;
+    VectorType newResultType = VectorType::get(
+        tensorResultType.getShape(), tensorResultType.getElementType());
+    type = newResultType;
+  }
+
+  if (vectorizationTargets.none()) {
     return failure();
   }
 
-  VectorType newResultType = VectorType::get(tensorResultType.getShape(),
-                                             tensorResultType.getElementType());
-
-  auto paddingValue = rewriter.create<arith::ConstantOp>(
-      barrier.getLoc(), rewriter.getZeroAttr(newResultType.getElementType()));
-
   auto newBarrier = rewriter.create<IREE::GPU::BarrierRegionOp>(
-      barrier.getLoc(), newResultType, barrier.getDest());
-
+      barrier.getLoc(), newResultTypes, barrier.getInputs());
   auto currentTerminator =
       cast<IREE::GPU::YieldOp>(barrier.getBody()->getTerminator());
+  rewriter.setInsertionPointToEnd(newBarrier.getBody());
   rewriter.mergeBlocks(barrier.getBody(), newBarrier.getBody(),
                        newBarrier.getBody()->getArguments());
-  rewriter.setInsertionPointToEnd(newBarrier.getBody());
 
-  auto innerRead = vector::createReadOrMaskedRead(
-      rewriter, currentTerminator.getLoc(), currentTerminator->getOperand(0),
-      newResultType.getShape(), paddingValue,
-      /*useInBoundsInsteadOfMasking=*/true);
-  rewriter.create<IREE::GPU::YieldOp>(currentTerminator->getLoc(), innerRead);
+  // Create the tensor -> vector conversions within the body of the new op.
+  SmallVector<Value> newYields = currentTerminator.getOperands();
+  for (auto [i, val] : llvm::enumerate(newYields)) {
+    if (!vectorizationTargets[i]) {
+      continue;
+    }
+
+    auto resultType = cast<VectorType>(newResultTypes[i]);
+    auto paddingValue = rewriter.create<arith::ConstantOp>(
+        barrier.getLoc(), rewriter.getZeroAttr(resultType.getElementType()));
+
+    auto innerRead =
+        vector::createReadOrMaskedRead(rewriter, currentTerminator.getLoc(),
+                                       val, resultType.getShape(), paddingValue,
+                                       /*useInBoundsInsteadOfMasking=*/true);
+    val = innerRead;
+  }
+
+  rewriter.create<IREE::GPU::YieldOp>(currentTerminator->getLoc(), newYields);
   rewriter.eraseOp(currentTerminator);
 
   rewriter.setInsertionPointAfter(newBarrier);
 
-  // Create the write back to a tensor.
-  auto empty = rewriter.create<tensor::EmptyOp>(
-      barrier.getLoc(), tensorResultType.getShape(),
-      tensorResultType.getElementType());
-  int64_t rank = tensorResultType.getRank();
-  auto zero = rewriter.create<arith::ConstantIndexOp>(barrier.getLoc(), 0);
-  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-      barrier,
-      /*vector=*/newBarrier,
-      /*source=*/empty,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
+  // Create the writes back to tensor types.
+  SmallVector<Value> replacements = newBarrier.getResults();
+  for (auto [i, val] : llvm::enumerate(replacements)) {
+    if (!vectorizationTargets[i]) {
+      continue;
+    }
+
+    auto tensorResultType =
+        cast<RankedTensorType>(barrier->getResultTypes()[i]);
+    auto empty = rewriter.create<tensor::EmptyOp>(
+        barrier.getLoc(), tensorResultType.getShape(),
+        tensorResultType.getElementType());
+    int64_t rank = tensorResultType.getRank();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(barrier.getLoc(), 0);
+    auto write = rewriter.create<vector::TransferWriteOp>(
+        barrier.getLoc(),
+        /*vector=*/val,
+        /*dest=*/empty,
+        /*indices=*/SmallVector<Value>(rank, zero),
+        /*inBounds=*/SmallVector<bool>(rank, true));
+    val = write->getResult(0);
+  }
+
+  rewriter.replaceOp(barrier, replacements);
+
   return success();
 }