Better support multidevice placement with `stream.async.barrier` (#19651)

Barriers / transfers should have semantics that attempt to parallelize
partitioning. If a value has a barrier placed on it, we should divide
partitions there to avoid spanning behavior with cross-device dependencies.

We want to place intermediate and ending transfers on the producing
partition so that each producing partition ends by making the value
available at the needed destination.

For incoming transfers we place them in the destination partition, as these
will not add a dependency on the incoming data.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 88ac020..ebbbfe2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -1159,6 +1159,13 @@
 }
 
 //===----------------------------------------------------------------------===//
+// flow.tensor.barrier
+//===----------------------------------------------------------------------===//
+
+void TensorBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {}
+
+//===----------------------------------------------------------------------===//
 // flow.tensor.transfer
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index e4732c3..df56f0d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1837,6 +1837,12 @@
 }
 
 //===----------------------------------------------------------------------===//
+// flow.tensor.barrier
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorBarrierOp::verify() { return success(); }
+
+//===----------------------------------------------------------------------===//
 // flow.tensor.transfer
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 98cdf5b..2666072 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1502,6 +1502,52 @@
   let hasFolder = 1;
 }
 
+def FLOW_TensorBarrierOp : FLOW_PureOp<"tensor.barrier", [
+  AllTypesMatch<["operand", "result"]>,
+  DeclareOpInterfaceMethods<Util_HoistableOpInterface>,
+  Util_ShapeAwareOp,
+]> {
+  let summary = [{indicates a barrier point for a tensor value}];
+  let description = [{
+  }];
+
+  let arguments = (ins
+    FLOW_Tensor:$operand,
+    FLOW_ShapeDynamicDims:$argument_dims,
+    AnyAttr:$target
+  );
+  let results = (outs
+    FLOW_Tensor:$result
+  );
+
+  let assemblyFormat = [{
+    $operand `:` type($result) (`{` $argument_dims^ `}`)?
+    `on` $target
+    attr-dict-with-keyword
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$operand, "Attribute":$target),
+    [{
+      build($_builder, $_state,
+          operand.getType(),
+          operand,
+          IREE::Util::buildDynamicDimsForValue($_state.location, operand, $_builder),
+          target);
+    }]>,
+  ];
+
+  let extraClassDeclaration = [{
+    bool isHoistableLeafOp() { return false; }
+
+    ValueRange getOperandDynamicDims(unsigned idx) { return getArgumentDims(); }
+    ValueRange getResultDynamicDims(unsigned idx) { return getArgumentDims(); }
+  }];
+
+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
+}
+
 def FLOW_TensorTransferOp : FLOW_PureOp<"tensor.transfer", [
   AllTypesMatch<["operand", "result"]>,
   DeclareOpInterfaceMethods<Util_HoistableOpInterface>,
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
index 9cb1013..79d147a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
@@ -6,10 +6,12 @@
 
 #include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h"
 #include "iree/compiler/Dialect/Stream/Analysis/ResourceHazards.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/PatternMatch.h"
 
@@ -138,6 +140,8 @@
 
   auto asmState = getRootAsmState(block);
 
+  llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> syncOps;
+
   for (auto &op : llvm::reverse(*block)) {
     // Skip constants; they just add noise (and since they are heavily CSE'd
     // they have lots of users to test).
@@ -163,6 +167,21 @@
       // Even though not a streamable op we still want to track it below.
     }
 
+    // Synchronizing operations should join with their producers if the producer
+    // is streamable.
+    if (dyn_cast<IREE::Stream::AsyncBarrierOp>(op) ||
+        dyn_cast<IREE::Stream::AsyncTransferOp>(op)) {
+      auto producer = op.getOperand(0).getDefiningOp();
+      auto streamable =
+          dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(producer);
+      if (streamable) {
+        if (!syncOps.contains(producer))
+          syncOps[producer] = llvm::SmallVector<Operation *>();
+        syncOps[producer].push_back(&op);
+        continue;
+      }
+    }
+
     // Initialize op info for this op - whether streamable or not. We track
     // transitive hazards on each op. Note that thanks to the ordering of ops
     // in SSA form (_reversed here!_) we know that once we visit this op no
@@ -202,6 +221,21 @@
       opInfo.hazards |= userInfo.membership;
       opInfo.hazards |= userInfo.hazards;
     }
+
+    for (auto syncOp : syncOps[&op]) {
+      for (auto user : syncOp->getUsers()) {
+        auto userInfoIt = opInfos.find(user);
+        if (userInfoIt == opInfos.end())
+          continue;
+        auto &userInfo = userInfoIt->second;
+        opInfo.hazards |= userInfo.membership;
+        opInfo.hazards |= userInfo.hazards;
+        consumers.reset();
+      }
+    }
+
+    // Any sync ops that do not use this op's results need to be placed in a
+    // non-consumer partition:
     llvm::BitVector candidates(builders.size(), /*t=*/true);
     candidates ^= opInfo.hazards;
     candidates |= consumers;
@@ -216,6 +250,16 @@
       }
     }
 
+    for (auto syncOp : syncOps[&op]) {
+      for (auto ordinal : candidates.set_bits()) {
+        if (!canAddOpToPartition(*syncOp, opInfo, ordinal)) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Candidate partition " << ordinal << " incompatible\n");
+          candidates.reset(ordinal);
+        }
+      }
+    }
+
     // If this op is not streamable then bail here; we've still setup the hazard
     // map for following iteration.
     auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
@@ -227,63 +271,60 @@
     // First see which partitions are consuming this that we can also safely
     // move in to.
     consumers &= candidates;
+    if (consumers.any())
+      candidates = consumers;
 
     opInfo.membership.reserve(builders.size() + 1);
     opInfo.membership.resize(builders.size(), /*t=*/false);
 
-    // If we have one or more consumers we should go into those first.
-    if (consumers.any()) {
-      // If we are a clonable op (like splat) clone us into every partition.
-      // Otherwise we just pick the first we find (probably a bad heuristic).
-      if (streamableOp.preferCloneToConsumers() && consumers.count() > 1) {
-        for (auto consumerOrdinal : consumers.set_bits()) {
-          LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition "
-                                  << consumerOrdinal << "\n");
-          auto &consumerBuilder = builders[consumerOrdinal];
-          consumerBuilder->insert(&op, opInfo);
-          consumerBuilder->clonedOps.insert(&op);
-        }
-      } else {
-        int consumerOrdinal = consumers.find_last();
-        LLVM_DEBUG(llvm::dbgs() << "Moving into consumer partition "
+    // No consumers - if there's any candidate then we'll go into that.
+    int firstCandidateOrdinal = candidates.find_first();
+    if (firstCandidateOrdinal == -1) {
+      // Mark the op as having hazards against all other partitions.
+      // It is better to be safe than incorrect, especially with our current
+      // minimal test coverage. It's not always safe to reorder things - if
+      // anything we are unlikely to be conservative enough here - for example,
+      // if there's a stream.resource.load of a resource or a global we can't
+      // move anything that may affect that resource or global. This
+      // partitioning was designed to be conservative because debugging such
+      // issues is really difficult.
+      if (!builders.empty()) {
+        opInfo.hazards.set(0, builders.size() - 1);
+      }
+
+      // Create a new partition just for this op.
+      opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
+      auto builder = std::make_unique<PartitionBuilder>();
+      builder->ordinal = builders.size();
+      builders.push_back(std::move(builder));
+      usableBuilders.resize(builders.size(), /*t=*/true);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Created partition " << builder->ordinal << "\n");
+      firstCandidateOrdinal = builders.size() - 1;
+    }
+
+    auto &builder = builders[firstCandidateOrdinal];
+
+    // If we have synchronization operations, place them in the selected partition:
+    for (auto syncOp : syncOps[&op]) {
+      builder->insert(syncOp, opInfo);
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition "
+                            << firstCandidateOrdinal << " (continue)\n");
+    // If we are a clonable op (like splat) clone us into every partition.
+    // Otherwise we just pick the first we find (probably a bad heuristic).
+    if (consumers.count() > 1 && streamableOp.preferCloneToConsumers()) {
+      for (auto consumerOrdinal : consumers.set_bits()) {
+        LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition "
                                 << consumerOrdinal << "\n");
         auto &consumerBuilder = builders[consumerOrdinal];
         consumerBuilder->insert(&op, opInfo);
+        consumerBuilder->clonedOps.insert(&op);
       }
-      LLVM_DEBUG(llvm::dbgs() << "Handled streamable (continue)\n");
-      continue;
+    } else {
+      builder->insert(&op, opInfo);
     }
-
-    // No consumers - if there's any candidate then we'll go into that.
-    int firstCandidateOrdinal = candidates.find_first();
-    if (firstCandidateOrdinal != -1) {
-      LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition "
-                              << firstCandidateOrdinal << " (continue)\n");
-      builders[firstCandidateOrdinal]->insert(&op, opInfo);
-      continue;
-    }
-
-    // Mark the op as having hazards against all other partitions.
-    // It is better to be safe than incorrect, especially with our current
-    // minimal test coverage. It's not always safe to reorder things - if
-    // anything we are unlikely to be conservative enough here - for example,
-    // if there's a stream.resource.load of a resource or a global we can't
-    // move anything that may affect that resource or global. This partitioning
-    // was designed to be conservative because debugging such issues is really
-    // difficult.
-    if (!builders.empty()) {
-      opInfo.hazards.set(0, builders.size() - 1);
-    }
-
-    // Create a new partition just for this op.
-    opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
-    auto builder = std::make_unique<PartitionBuilder>();
-    builder->ordinal = builders.size();
-    builder->insert(&op, opInfo);
-    LLVM_DEBUG(llvm::dbgs()
-               << "Created partition " << builder->ordinal << "\n");
-    builders.push_back(std::move(builder));
-    usableBuilders.resize(builders.size(), /*t=*/true);
   }
 
   // Ops cloned into multiple partitions may still escape if there are
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 44c8a46..02939d8 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -237,6 +237,24 @@
   }
 };
 
+struct ConvertTensorBarrierOp
+    : public AffinityOpConversionPattern<IREE::Flow::TensorBarrierOp> {
+  using AffinityOpConversionPattern::AffinityOpConversionPattern;
+  LogicalResult matchAndRewriteOnAffinity(
+      IREE::Flow::TensorBarrierOp op, OneToNOpAdaptor adaptor,
+      IREE::Stream::AffinityAttr executionAffinityAttr,
+      ConversionPatternRewriter &rewriter) const override {
+    auto operand = resolveTensorOperands(op.getLoc(), op.getOperand(),
+                                         adaptor.getOperand(), rewriter);
+    auto barrierOp = rewriter.create<IREE::Stream::AsyncBarrierOp>(
+        op.getLoc(), operand.resource.getType(), operand.resource,
+        operand.resourceSize,
+        /*affinity=*/operand.affinity);
+    rewriter.replaceOpWithMultiple(op, {{barrierOp, operand.resourceSize}});
+    return success();
+  }
+};
+
 struct ConvertTensorTransferOp
     : public AffinityOpConversionPattern<IREE::Flow::TensorTransferOp> {
   using AffinityOpConversionPattern::AffinityOpConversionPattern;
@@ -1162,15 +1180,15 @@
     MLIRContext *context, TypeConverter &typeConverter,
     IREE::Stream::AffinityAnalysis *affinityAnalysis,
     RewritePatternSet &patterns) {
-  patterns
-      .insert<ConvertTensorConstantOp, ConvertTensorDynamicConstantOp,
-              ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>,
-              ConvertTensorCastLikeOp<IREE::Flow::TensorBitCastOp>,
-              ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
-              ConvertTensorCloneOp, ConvertTensorTransferOp,
-              ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
-              ConvertTensorStoreOp, ConvertTensorTraceOp>(
-          typeConverter, context, affinityAnalysis);
+  patterns.insert<
+      ConvertTensorConstantOp, ConvertTensorDynamicConstantOp,
+      ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>,
+      ConvertTensorCastLikeOp<IREE::Flow::TensorBitCastOp>,
+      ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
+      ConvertTensorCloneOp, ConvertTensorBarrierOp, ConvertTensorTransferOp,
+      ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
+      ConvertTensorStoreOp, ConvertTensorTraceOp>(typeConverter, context,
+                                                  affinityAnalysis);
   patterns.insert<ConvertChannelDefaultOp>(typeConverter, context,
                                            affinityAnalysis);
   patterns.insert<ConvertChannelSplitOp, ConvertChannelRankOp,
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
index ee68211..4fb2216 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -149,6 +149,19 @@
 
 // -----
 
+util.global private @device : !hal.device
+
+// CHECK-LABEL: @tensorBarrier
+//  CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM0:.+]]: index)
+util.func public @tensorBarrier(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
+  // CHECK: %[[TRANSFER:.+]] = stream.async.barrier %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> !stream.resource<*>
+  %transfer = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
+  // CHECK: util.return %[[TRANSFER]], %[[INPUT_SIZE]]
+  util.return %transfer : tensor<?x128xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @tensorSlice
 //  CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index)
 util.func public @tensorSlice(%input : tensor<5x24x48xf32>) -> tensor<3x24x48xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index b14bb08..0f4ef35 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -1965,6 +1965,13 @@
 }
 
 //===----------------------------------------------------------------------===//
+// stream.async.barrier
+//===----------------------------------------------------------------------===//
+
+void AsyncBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                 MLIRContext *context) {}
+
+//===----------------------------------------------------------------------===//
 // stream.async.transfer
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 3c89895..c3bf0cf 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -2007,6 +2007,14 @@
 }
 
 //===----------------------------------------------------------------------===//
+// stream.async.barrier
+//===----------------------------------------------------------------------===//
+
+bool AsyncBarrierOp::isMetadata() { return true; }
+
+LogicalResult AsyncBarrierOp::verify() { return success(); }
+
+//===----------------------------------------------------------------------===//
 // stream.async.transfer
 //===----------------------------------------------------------------------===//
 
@@ -2026,15 +2034,17 @@
       resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
     // TODO(multi-device): figure out how to model staging->staging transfers.
     return getSourceAffinityAttr();
-  } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+  } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::External ||
+             sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
     // If source is staging then the op should execute on the consumer.
     return getResultAffinityAttr();
-  } else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+  } else if (resultType.getLifetime() == IREE::Stream::Lifetime::External ||
+             resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
     // If result is staging then the op should execute on the producer.
     return getSourceAffinityAttr();
   } else {
     // Default to result affinity.
-    return getResultAffinityAttr();
+    return getSourceAffinityAttr();
   }
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index d499dd0..768cbde 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -2216,6 +2216,51 @@
   let hasCanonicalizer = 1;
 }
 
+def Stream_AsyncBarrierOp : Stream_Op<"async.barrier", [
+  Stream_AffinityOp,
+  Stream_AsyncPhaseOp,
+  DeclareOpInterfaceMethods<Stream_StreamableOp, [
+    "isMetadata",
+  ]>,
+  Util_SizeAwareOp,
+]> {
+  let summary = [{indicates a barrier point for a resource value}];
+  let description = [{
+  }];
+
+  let arguments = (ins
+    AnyTypeOf<[
+      Stream_AnyStreamResource,
+      Stream_StagingResource,
+    ]>:$source,
+    Stream_Size:$size,
+    OptionalAttr<Stream_AffinityAttr>:$affinity
+  );
+  let results = (outs
+    AnyTypeOf<[
+      Stream_AnyStreamResource,
+      Stream_StagingResource,
+    ]>:$result
+  );
+
+  let assemblyFormat = [{
+    $source `:` type($source)
+    `` `{` $size `}`
+    (`from` `(` $affinity^ `)`)?
+    `->`
+    type($result)
+    attr-dict-with-keyword
+  }];
+
+  let extraClassDeclaration = [{
+    Value getOperandSize(unsigned idx) { return getSize(); }
+    Value getResultSize(unsigned idx) { return getSize(); }
+  }];
+
+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
+}
+
 def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
   DeclareOpInterfaceMethods<Stream_AffinityOp, [
     "getAffinityAttr",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
index 123f090..525b8c4 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
@@ -427,6 +427,7 @@
                   ApplyStreamableOp<IREE::Stream::AsyncUpdateOp>,
                   ApplyStreamableOp<IREE::Stream::AsyncCopyOp>,
                   ApplyStreamableOp<IREE::Stream::AsyncCollectiveOp>,
+                  ApplyStreamableOp<IREE::Stream::AsyncBarrierOp>,
                   ApplyStreamableOp<IREE::Stream::AsyncTransferOp>,
                   ApplyStreamableOp<IREE::Stream::AsyncLoadOp>,
                   ApplyStreamableOp<IREE::Stream::AsyncStoreOp>,
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
index 43f5269..20ee572 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
@@ -665,6 +665,13 @@
   return success();
 }
 
+static LogicalResult applyAsyncBarrierOp(IREE::Stream::AsyncBarrierOp barrierOp,
+                                         AllocationScope &scope,
+                                         OpBuilder builder) {
+  barrierOp.erase();
+  return success();
+}
+
 static LogicalResult applyAsyncTransferOp(IREE::Stream::AsyncTransferOp asyncOp,
                                           AllocationScope &scope,
                                           OpBuilder builder) {
@@ -987,6 +994,9 @@
                    .Case([&](IREE::Stream::AsyncCollectiveOp op) {
                      return applyAsyncCollectiveOp(op, scope, OpBuilder(op));
                    })
+                   .Case([&](IREE::Stream::AsyncBarrierOp op) {
+                     return applyAsyncBarrierOp(op, scope, OpBuilder(op));
+                   })
                    .Case([&](IREE::Stream::AsyncTransferOp op) {
                      return applyAsyncTransferOp(op, scope, OpBuilder(op));
                    })
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
index b5ecc53..d913525 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
@@ -34,6 +34,55 @@
 
 // -----
 
+// Tests partitioning multi device execution with barriers and transfers.
+// It validates that multi stream commands are created and run in parallel:
+
+// CHECK-LABEL: util.func public @deviceMultiDeviceSync
+util.func public @deviceMultiDeviceSync(%arg0: i1) -> !stream.resource<transient> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %c255_i32 = arith.constant 255 : i32
+  %0 = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%c128}
+  %1 = stream.async.dispatch on(#hal.device.affinity<@__device_0>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
+  %3 = stream.async.barrier %1 : !stream.resource<transient>{%c128} -> !stream.resource<transient>
+  %4 = stream.async.transfer %1 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@__device_0>) -> to(#hal.device.affinity<@__device_1>) !stream.resource<transient>{%c128}
+  // CHECK: stream.async.execute on(#hal.device.affinity<@__device_0>)
+  // CHECK: stream.async.splat
+  // CHECK: stream.async.dispatch
+  // CHECK: stream.async.barrier
+  // CHECK: stream.async.transfer
+
+  %2 = stream.async.dispatch on(#hal.device.affinity<@__device_1>) @ex::@dispatch1[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
+  %5 = stream.async.barrier %2 : !stream.resource<transient>{%c128} -> !stream.resource<transient>
+  %6 = stream.async.transfer %2 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@__device_1>) -> to(#hal.device.affinity<@__device_0>) !stream.resource<transient>{%c128}
+  // CHECK: stream.async.execute on(#hal.device.affinity<@__device_1>)
+  // CHECK: stream.async.splat
+  // CHECK: stream.async.dispatch
+  // CHECK: stream.async.barrier
+  // CHECK: stream.async.transfer
+
+  %7 = stream.async.dispatch on(#hal.device.affinity<@__device_0>) @ex::@dispatch2[%c1, %c1, %c1](%3[%c0 to %c128 for %c128], %6[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}, !stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
+  %9 = stream.async.barrier %7 : !stream.resource<transient>{%c128} -> !stream.resource<transient>
+  // CHECK: stream.async.execute on(#hal.device.affinity<@__device_0>)
+  // CHECK: stream.async.dispatch
+  // CHECK: stream.async.barrier
+
+  %8 = stream.async.dispatch on(#hal.device.affinity<@__device_1>) @ex::@dispatch2[%c1, %c1, %c1](%4[%c0 to %c128 for %c128], %5[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}, !stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
+  %10 = stream.async.transfer %8 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@__device_1>) -> to(#hal.device.affinity<@__device_0>) !stream.resource<transient>{%c128}
+  // CHECK: stream.async.execute on(#hal.device.affinity<@__device_1>)
+  // CHECK: stream.async.dispatch
+  // CHECK: stream.async.transfer
+
+  %11 = stream.async.dispatch on(#hal.device.affinity<@__device_0>) @ex::@dispatch2[%c1, %c1, %c1](%9[%c0 to %c128 for %c128], %10[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}, !stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
+  // CHECK: stream.async.execute on(#hal.device.affinity<@__device_0>)
+  // CHECK: stream.async.dispatch
+
+  util.return %11 : !stream.resource<transient>
+}
+
+// -----
+
 // Tests basic partitioning of sequential dispatches with differing affinities.
 // Dispatches with the same affinities should be placed into the same execution
 // regions.
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
index ab1adf0..763e219 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
@@ -47,6 +47,26 @@
   }
 };
 
+struct FlowBarrierTargetAffinityAttrExternalModel
+    : public IREE::Stream::AffinityOpInterface::ExternalModel<
+          FlowBarrierTargetAffinityAttrExternalModel,
+          IREE::Flow::TensorBarrierOp> {
+  static void add(MLIRContext *context) {
+    IREE::Flow::TensorBarrierOp::attachInterface<
+        FlowBarrierTargetAffinityAttrExternalModel>(*context);
+  }
+
+  bool requiresAffinity(Operation *op) const { return true; }
+
+  IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
+    return op->getAttrOfType<IREE::Stream::AffinityAttr>("target");
+  }
+
+  void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+    op->setAttr("target", value);
+  }
+};
+
 struct FlowTransferTargetAffinityAttrExternalModel
     : public IREE::Stream::AffinityOpInterface::ExternalModel<
           FlowTransferTargetAffinityAttrExternalModel,
@@ -173,6 +193,7 @@
   registry.insert<IREE::Flow::FlowDialect>();
   registry.addExtension(+[](MLIRContext *context,
                             IREE::Flow::FlowDialect *dialect) {
+    FlowBarrierTargetAffinityAttrExternalModel::add(context);
     FlowTransferTargetAffinityAttrExternalModel::add(context);
     AffinityOpAttrExternalModel<IREE::Flow::DispatchRegionOp>::add(context);
     AffinityOpAttrExternalModel<IREE::Flow::DispatchWorkgroupsOp>::add(context);