Support mhlo collective ops (#11988)

Add support for the following MHLO collective ops:
- mhlo.replica_id
- mhlo.all_gather
- mhlo.all_reduce
- mhlo.reduce_scatter

Since NCCL only supports splitting and concatenating on dim 0 for
all_gather and reduce_scatter, transposes are inserted when the
split/concat dimension is not 0.
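As an illustration, an mhlo.all_gather with all_gather_dim = 1 across two
ranks lowers roughly as follows (a hand-written sketch, not compiler
output; the SSA names and shapes are made up):

    %channel = flow.channel.default : !flow.channel
    %t0 = "mhlo.transpose"(%input) {permutation = dense<[1, 0]> : tensor<2xi64>}
        : (tensor<2x4xf32>) -> tensor<4x2xf32>
    %empty = tensor.empty() : tensor<8x2xf32>
    %gathered = flow.collective.all_gather f32, %empty, %t0, %channel
        : (tensor<8x2xf32>, tensor<4x2xf32>, !flow.channel) -> tensor<8x2xf32>
    %result = "mhlo.transpose"(%gathered) {permutation = dense<[1, 0]> : tensor<2xi64>}
        : (tensor<8x2xf32>) -> tensor<2x8xf32>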

To make the implementation simple and incremental, several stages are
planned as follows:

Stage 1 (this PR):

It assumes a deterministic order of collective operations and a 1:1
mapping from replica_id to rank. This means that there is a single
replica group in the MHLO op and that all replicas participate in the
collective operation. Since the order is deterministic and all ranks are
involved in the communication, we can simply use the default channel for
communication. In the MHLO ops, `use_global_device_ids` must be set so
that the flattened IDs are used.
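For example, an mhlo.all_reduce accepted in this stage has a single
replica group and `use_global_device_ids` set, mirroring the added lit
tests (shapes are illustrative):

    %out = "mhlo.all_reduce"(%input) ({
      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
        %sum = mhlo.add %arg0, %arg1 : tensor<f32>
        mhlo.return %sum : tensor<f32>
      }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
          replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
          use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>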

Note that the MHLO collective ops have multiple strategies to interpret
the `replica_groups` attribute, such as `flattened_ids`, `cross_replica`,
`cross_partition`, and `cross_replica_and_partition`. (See
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#parallel-execution
for more details of the strategies.)

Stage 2:

Supports multiple channels. This allows us to have multiple replica
groups for collective operations.

Stage 3:

Supports `partition_id` and the other strategies: `cross_replica`,
`cross_partition`, and `cross_replica_and_partition`.

Stage 4:

Supports `all_to_all` and `collective_permute`. These ops are composite
and will be lowered into the existing collective ops, so this stage will
need NCCL group markers to run multiple collective ops in parallel.

Stage 5:
PJRT integration and model-level testing.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td
index 456e6c1..1f6e1b6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td
@@ -156,4 +156,82 @@
   let mnemonic = "dummy";
 }
 
+//===----------------------------------------------------------------------===//
+// Flow enums
+//===----------------------------------------------------------------------===//
+
+def FLOW_CollectiveElementType_Sint8 : I32EnumAttrCase<"Sint8", 0, "si8">;
+def FLOW_CollectiveElementType_Uint8 : I32EnumAttrCase<"Uint8", 1, "ui8">;
+def FLOW_CollectiveElementType_Sint16 : I32EnumAttrCase<"Sint16", 2, "si16">;
+def FLOW_CollectiveElementType_Uint16 : I32EnumAttrCase<"Uint16", 3, "ui16">;
+def FLOW_CollectiveElementType_Sint32 : I32EnumAttrCase<"Sint32", 4, "si32">;
+def FLOW_CollectiveElementType_Uint32 : I32EnumAttrCase<"Uint32", 5, "ui32">;
+def FLOW_CollectiveElementType_Sint64 : I32EnumAttrCase<"Sint64", 6, "si64">;
+def FLOW_CollectiveElementType_Uint64 : I32EnumAttrCase<"Uint64", 7, "ui64">;
+def FLOW_CollectiveElementType_Float16 : I32EnumAttrCase<"Float16", 8, "f16">;
+def FLOW_CollectiveElementType_Float32 : I32EnumAttrCase<"Float32", 9, "f32">;
+def FLOW_CollectiveElementType_Float64 : I32EnumAttrCase<"Float64", 10, "f64">;
+def FLOW_CollectiveElementType_BFloat16 : I32EnumAttrCase<"BFloat16", 11, "bf16">;
+def FLOW_CollectiveElementTypeAttr :
+    I32EnumAttr<"CollectiveElementType", "valid CollectiveElementType", [
+      FLOW_CollectiveElementType_Sint8,
+      FLOW_CollectiveElementType_Uint8,
+      FLOW_CollectiveElementType_Sint16,
+      FLOW_CollectiveElementType_Uint16,
+      FLOW_CollectiveElementType_Sint32,
+      FLOW_CollectiveElementType_Uint32,
+      FLOW_CollectiveElementType_Sint64,
+      FLOW_CollectiveElementType_Uint64,
+      FLOW_CollectiveElementType_Float16,
+      FLOW_CollectiveElementType_Float32,
+      FLOW_CollectiveElementType_Float64,
+      FLOW_CollectiveElementType_BFloat16,
+    ]> {
+  let cppNamespace = "::mlir::iree_compiler::IREE::Flow";
+}
+
+//===----------------------------------------------------------------------===//
+// Flow channel type
+//===----------------------------------------------------------------------===//
+
+def FLOW_Channel : TypeDef<Flow_Dialect, "Channel", []> {
+  let mnemonic = "channel";
+  let summary = [{a collective communication channel}];
+  let description = [{
+    Represents a single participant in a collective clique. Multiple channels
+    may exist within the same program to allow for partial operations or
+    hierarchical operations.
+
+    In programs that have already been partitioned prior to being compiled,
+    there will often be only one channel, and `flow.channel.default` can be
+    used to reference it. In programs that model SPMD behavior internally,
+    channels can be created or provided by the hosting application.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Flow collective reduction op
+//===----------------------------------------------------------------------===//
+
+// NOTE: the enum values must exactly match the corresponding enum values of
+// the Stream reduction op.
+
+def FLOW_CollectiveReductionOp_None             : I32EnumAttrCase<"None", 0, "none">;
+def FLOW_CollectiveReductionOp_ReductionSum     : I32EnumAttrCase<"ReductionSum", 1, "sum">;
+def FLOW_CollectiveReductionOp_ReductionProduct : I32EnumAttrCase<"ReductionProduct", 2, "product">;
+def FLOW_CollectiveReductionOp_ReductionMinimum : I32EnumAttrCase<"ReductionMinimum", 3, "minimum">;
+def FLOW_CollectiveReductionOp_ReductionMaximum : I32EnumAttrCase<"ReductionMaximum", 4, "maximum">;
+def FLOW_CollectiveReductionOp_ReductionAverage : I32EnumAttrCase<"ReductionAverage", 5, "average">;
+def FLOW_CollectiveReductionOpAttr :
+    I32EnumAttr<"CollectiveReductionOp", "valid CollectiveReductionOp", [
+      FLOW_CollectiveReductionOp_None,
+      FLOW_CollectiveReductionOp_ReductionSum,
+      FLOW_CollectiveReductionOp_ReductionProduct,
+      FLOW_CollectiveReductionOp_ReductionMinimum,
+      FLOW_CollectiveReductionOp_ReductionMaximum,
+      FLOW_CollectiveReductionOp_ReductionAverage,
+    ]> {
+  let cppNamespace = "mlir::iree_compiler::IREE::Flow";
+}
+
 #endif  // IREE_DIALECT_FLOW_BASE
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 244dbd4..46596ff 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1545,6 +1545,118 @@
           context);
 }
 
+//===----------------------------------------------------------------------===//
+// flow.channel.count
+//===----------------------------------------------------------------------===//
+
+void ChannelCountOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "channel_count");
+}
+
+//===----------------------------------------------------------------------===//
+// flow.channel.default
+//===----------------------------------------------------------------------===//
+
+void ChannelDefaultOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "channel_default");
+}
+
+//===----------------------------------------------------------------------===//
+// flow.channel.rank
+//===----------------------------------------------------------------------===//
+
+void ChannelRankOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "channel_rank");
+}
+
+//===----------------------------------------------------------------------===//
+// flow.collective.all_gather
+//===----------------------------------------------------------------------===//
+
+Value CollectiveAllGatherOp::getTiedResult(unsigned resultIndex) {
+  return IREE::Util::TiedOpInterface::findTiedBaseValue(getTarget());
+}
+
+::llvm::Optional<unsigned> CollectiveAllGatherOp::getTiedResultOperandIndex(
+    unsigned resultIndex) {
+  return {0};  // target
+}
+
+SmallVector<int64_t, 4> CollectiveAllGatherOp::getTiedResultOperandIndices() {
+  return {0};  // target
+}
+
+void CollectiveAllGatherOp::build(OpBuilder &builder, OperationState &state,
+                                  CollectiveElementTypeAttr elementType,
+                                  Value target, Value source, Value channel) {
+  auto targetDims =
+      IREE::Util::buildDynamicDimsForValue(state.location, target, builder);
+
+  build(builder, state, elementType, target, targetDims, source, channel,
+        builder.getIndexArrayAttr({0}));
+}
+
+//===----------------------------------------------------------------------===//
+// flow.collective.all_reduce
+//===----------------------------------------------------------------------===//
+
+Value CollectiveAllReduceOp::getTiedResult(unsigned resultIndex) {
+  return IREE::Util::TiedOpInterface::findTiedBaseValue(getTarget());
+}
+
+::llvm::Optional<unsigned> CollectiveAllReduceOp::getTiedResultOperandIndex(
+    unsigned resultIndex) {
+  return {0};  // target
+}
+
+SmallVector<int64_t, 4> CollectiveAllReduceOp::getTiedResultOperandIndices() {
+  return {0};  // target
+}
+
+void CollectiveAllReduceOp::build(OpBuilder &builder, OperationState &state,
+                                  CollectiveReductionOpAttr reductionOp,
+                                  CollectiveElementTypeAttr elementType,
+                                  Value target, Value source, Value channel) {
+  auto targetDims =
+      IREE::Util::buildDynamicDimsForValue(state.location, target, builder);
+
+  build(builder, state, reductionOp, elementType, target, targetDims, source,
+        channel, builder.getIndexArrayAttr({0}));
+}
+
+//===----------------------------------------------------------------------===//
+// flow.collective.reduce_scatter
+//===----------------------------------------------------------------------===//
+
+Value CollectiveReduceScatterOp::getTiedResult(unsigned resultIndex) {
+  return IREE::Util::TiedOpInterface::findTiedBaseValue(getTarget());
+}
+
+::llvm::Optional<unsigned> CollectiveReduceScatterOp::getTiedResultOperandIndex(
+    unsigned resultIndex) {
+  return {0};  // target
+}
+
+SmallVector<int64_t, 4>
+CollectiveReduceScatterOp::getTiedResultOperandIndices() {
+  return {0};  // target
+}
+
+void CollectiveReduceScatterOp::build(OpBuilder &builder, OperationState &state,
+                                      CollectiveReductionOpAttr reductionOp,
+                                      CollectiveElementTypeAttr elementType,
+                                      Value target, Value source,
+                                      Value channel) {
+  auto targetDims =
+      IREE::Util::buildDynamicDimsForValue(state.location, target, builder);
+
+  build(builder, state, reductionOp, elementType, target, targetDims, source,
+        channel, builder.getIndexArrayAttr({0}));
+}
+
 }  // namespace Flow
 }  // namespace IREE
 }  // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 5e1acab..e6f8c29 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1304,4 +1304,185 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Collective communication ops
+//===----------------------------------------------------------------------===//
+
+def FLOW_ChannelDefaultOp : FLOW_Op<"channel.default", [
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = [{returns a default collective communication channel}];
+  let description = [{
+    Returns a channel initialized using the runtime environment.
+  }];
+
+  let results = (outs
+    FLOW_Channel:$result
+  );
+
+  let assemblyFormat = [{
+    `:` type($result)
+    attr-dict-with-keyword
+  }];
+}
+
+def FLOW_ChannelRankOp : FLOW_Op<"channel.rank", [
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+  let summary = [{returns the rank of the local participant in the group}];
+  let description = [{
+    Returns the rank the channel represents as a participant in a collective
+    group in `[0, count)`.
+  }];
+
+  let arguments = (ins
+    FLOW_Channel:$channel
+  );
+  let results = (outs
+    Index:$result
+  );
+
+  let assemblyFormat = [{
+    $channel `:` type($result)
+    attr-dict-with-keyword
+  }];
+}
+
+def FLOW_ChannelCountOp : FLOW_Op<"channel.count", [
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+  let summary = [{returns the total number of participants in the group}];
+  let description = [{
+    Returns the total participant count in the collective communicator group.
+  }];
+
+  let arguments = (ins
+    FLOW_Channel:$channel
+  );
+  let results = (outs
+    Index:$result
+  );
+
+  let assemblyFormat = [{
+    $channel `:` type($result)
+    attr-dict-with-keyword
+  }];
+}
+
+def FLOW_CollectiveAllGatherOp : FLOW_Op<"collective.all_gather", [
+  AllTypesMatch<["target", "result"]>,
+  DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+    "getTiedResult",
+    "getTiedResultOperandIndex",
+    "getTiedResultOperandIndices",
+  ]>,
+]> {
+  let summary = [{performs an all-gather operation}];
+  let description = [{Gathers data from all ranks in the channel and concatenates them on dimension 0.}];
+
+  let arguments = (ins
+    FLOW_CollectiveElementTypeAttr:$element_type,
+    FLOW_Tensor:$target,
+    FLOW_ShapeDynamicDims:$target_dims,
+    FLOW_Tensor:$source,
+    FLOW_Channel:$channel,
+    OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands
+  );
+  let results = (outs
+    FLOW_Tensor:$result
+  );
+  let assemblyFormat = [{
+    $element_type `,` $target `,` $source `,` $channel `:`
+    `(` type($target) `,` type($source) `,` type($channel) `)` `->`
+    custom<ShapedTiedResult>(type($result), $target_dims, $tied_operands)
+    attr-dict-with-keyword
+  }];
+  let builders = [
+    OpBuilder<(ins
+      "CollectiveElementTypeAttr":$element_type,
+      "Value":$target,
+      "Value":$source,
+      "Value":$channel)>,
+  ];
+}
+
+def FLOW_CollectiveAllReduceOp : FLOW_Op<"collective.all_reduce", [
+  AllTypesMatch<["source", "target", "result"]>,
+  DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+    "getTiedResult",
+    "getTiedResultOperandIndex",
+    "getTiedResultOperandIndices",
+  ]>,
+]> {
+  let summary = [{performs an all-reduce operation}];
+  let description = [{Reduces data across all ranks in the channel using the given reduction op.}];
+
+  let arguments = (ins
+    FLOW_CollectiveReductionOpAttr:$reduction_op,
+    FLOW_CollectiveElementTypeAttr:$element_type,
+    FLOW_Tensor:$target,
+    FLOW_ShapeDynamicDims:$target_dims,
+    FLOW_Tensor:$source,
+    FLOW_Channel:$channel,
+    OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands
+  );
+  let results = (outs
+    FLOW_Tensor:$result
+  );
+  let assemblyFormat = [{
+    $reduction_op `,` $element_type `,` $target `,` $source `,` $channel `:`
+    `(` type($target) `,` type($source) `,` type($channel) `)` `->`
+    custom<ShapedTiedResult>(type($result), $target_dims, $tied_operands)
+    attr-dict-with-keyword
+  }];
+  let builders = [
+    OpBuilder<(ins
+      "CollectiveReductionOpAttr":$reduction_op,
+      "CollectiveElementTypeAttr":$element_type,
+      "Value":$target,
+      "Value":$source,
+      "Value":$channel)>,
+  ];
+}
+
+def FLOW_CollectiveReduceScatterOp : FLOW_Op<"collective.reduce_scatter", [
+  AllTypesMatch<["target", "result"]>,
+  DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+    "getTiedResult",
+    "getTiedResultOperandIndex",
+    "getTiedResultOperandIndices",
+  ]>,
+]> {
+  let summary = [{performs a reduce-scatter operation}];
+  let description = [{Reduces data across all ranks in the channel and
+    scatters the result so that each rank receives its portion.}];
+
+  let arguments = (ins
+    FLOW_CollectiveReductionOpAttr:$reduction_op,
+    FLOW_CollectiveElementTypeAttr:$element_type,
+    FLOW_Tensor:$target,
+    FLOW_ShapeDynamicDims:$target_dims,
+    FLOW_Tensor:$source,
+    FLOW_Channel:$channel,
+    OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands
+  );
+  let results = (outs
+    FLOW_Tensor:$result
+  );
+  let assemblyFormat = [{
+    $reduction_op `,` $element_type `,` $target `,` $source `,` $channel `:`
+    `(` type($target) `,` type($source) `,` type($channel) `)` `->`
+    custom<ShapedTiedResult>(type($result), $target_dims, $tied_operands)
+    attr-dict-with-keyword
+  }];
+  let builders = [
+    OpBuilder<(ins
+      "CollectiveReductionOpAttr":$reduction_op,
+      "CollectiveElementTypeAttr":$element_type,
+      "Value":$target,
+      "Value":$source,
+      "Value":$channel)>,
+  ];
+}
+
 #endif  // IREE_DIALECT_FLOW_OPS
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
index 863b855..7161888 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
@@ -31,7 +31,8 @@
     }
     // Check that the following dimensions are match the order of `dims`
     for (unsigned j = 1, numDims = dims.size(); j < numDims; j++) {
-      if (map.getDimPosition(i + j) != dims[j]) {
+      unsigned pos = i + j;
+      if (pos >= map.getNumResults() || map.getDimPosition(pos) != dims[j]) {
         return false;
       }
     }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
index 2631e4e..25484ec 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
@@ -18,3 +18,22 @@
 // Check that we collapse dimensions.
 // CHECK: @multi_reduce_dim
 // CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "reduction"]
+
+// -----
+
+// Collapsing is not supported when an input is broadcasted; we can't collapse
+// the input from tensor<4xf32> to tensor<32xf32> for example.
+
+func.func @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
+  %empty = tensor.empty() : tensor<f32>
+  %reduce = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"]} ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4xf32>) outs(%empty : tensor<f32>) {
+  ^bb0(%arg2: f32, %arg3: f32, %out: f32):
+    %div = arith.divf %arg2, %arg3 : f32
+    %add = arith.addf %out, %div : f32
+    linalg.yield %add : f32
+  } -> tensor<f32>
+  return %reduce : tensor<f32>
+}
+
+// CHECK: @input_broadcast
+// CHECK-NOT: tensor.collapse_shape
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 3a9fb2f..5565bdf 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -482,6 +482,152 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Collective Ops
+//===----------------------------------------------------------------------===//
+
+struct ConvertAllGatherOp
+    : public OpConversionPattern<IREE::Flow::CollectiveAllGatherOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto shape = op.getSource().getType().cast<ShapedType>();
+    auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
+        op.getContext(), IREE::Stream::CollectiveKind::AllGather,
+        /*reduction=*/std::nullopt,
+        static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
+
+    auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    auto elementCount = rewriter.create<arith::ConstantIndexOp>(
+        op.getLoc(), shape.getNumElements());
+    auto newTargetCast =
+        consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+    auto newSourceCast =
+        consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+
+    rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
+        op, collectiveAttr, adaptor.getTarget(),
+        /*target_size=*/newTargetCast.resourceSize,
+        /*target_offset=*/zeroOffset,
+        /*target_end=*/newTargetCast.resourceSize,
+        /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(),
+        /*source_size=*/newSourceCast.resourceSize,
+        /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
+        /*source_length=*/newSourceCast.resourceSize, elementCount,
+        adaptor.getChannel(),
+        /*param=*/mlir::Value(), getAffinityFor(op));
+    return success();
+  }
+};
+
+struct ConvertAllReduceOp
+    : public OpConversionPattern<IREE::Flow::CollectiveAllReduceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto shape = op.getType().cast<ShapedType>();
+    auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
+        op.getContext(), IREE::Stream::CollectiveKind::AllReduce,
+        static_cast<IREE::Stream::CollectiveReductionOp>(op.getReductionOp()),
+        static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
+
+    auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    auto elementCount = rewriter.create<arith::ConstantIndexOp>(
+        op.getLoc(), shape.getNumElements());
+    auto newTargetCast =
+        consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+    auto newSourceCast =
+        consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+
+    rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
+        op, collectiveAttr, adaptor.getTarget(),
+        /*target_size=*/newTargetCast.resourceSize,
+        /*target_offset=*/zeroOffset,
+        /*target_end=*/newTargetCast.resourceSize,
+        /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(),
+        /*source_size=*/newSourceCast.resourceSize,
+        /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
+        /*source_length=*/newSourceCast.resourceSize, elementCount,
+        adaptor.getChannel(),
+        /*param=*/mlir::Value(), getAffinityFor(op));
+    return success();
+  }
+};
+
+struct ConvertReduceScatterOp
+    : public OpConversionPattern<IREE::Flow::CollectiveReduceScatterOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto shape = op.getType().cast<ShapedType>();
+    auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
+        op.getContext(), IREE::Stream::CollectiveKind::ReduceScatter,
+        static_cast<IREE::Stream::CollectiveReductionOp>(op.getReductionOp()),
+        static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
+
+    auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    auto elementCount = rewriter.create<arith::ConstantIndexOp>(
+        op.getLoc(), shape.getNumElements());
+    auto newTargetCast =
+        consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+    auto newSourceCast =
+        consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+
+    rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
+        op, collectiveAttr, adaptor.getTarget(),
+        /*target_size=*/newTargetCast.resourceSize,
+        /*target_offset=*/zeroOffset,
+        /*target_end=*/newTargetCast.resourceSize,
+        /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(),
+        /*source_size=*/newSourceCast.resourceSize,
+        /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
+        /*source_length=*/newSourceCast.resourceSize, elementCount,
+        adaptor.getChannel(),
+        /*param=*/mlir::Value(), getAffinityFor(op));
+    return success();
+  }
+};
+
+struct ConvertChannelDefaultOp
+    : public OpConversionPattern<IREE::Flow::ChannelDefaultOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    IREE::Stream::AffinityAttr affinityAttr;
+    rewriter.replaceOpWithNewOp<IREE::Stream::ChannelDefaultOp>(op,
+                                                                affinityAttr);
+    return success();
+  }
+};
+
+struct ConvertChannelCountOp
+    : public OpConversionPattern<IREE::Flow::ChannelCountOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREE::Flow::ChannelCountOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<IREE::Stream::ChannelCountOp>(
+        op, adaptor.getOperands());
+    return success();
+  }
+};
+
+struct ConvertChannelRankOp
+    : public OpConversionPattern<IREE::Flow::ChannelRankOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREE::Flow::ChannelRankOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<IREE::Stream::ChannelRankOp>(
+        op, adaptor.getOperands());
+    return success();
+  }
+};
+
 }  // namespace
 
 void populateFlowToStreamConversionPatterns(MLIRContext *context,
@@ -495,6 +641,13 @@
   patterns.insert<ConvertDispatchOp>(typeConverter, context);
   patterns.insert<ConvertExecutableOp>(typeConverter, context);
   patterns.insert<ConvertReturnOp>(typeConverter, context);
+  // collective ops
+  patterns.insert<ConvertAllGatherOp>(typeConverter, context);
+  patterns.insert<ConvertAllReduceOp>(typeConverter, context);
+  patterns.insert<ConvertChannelCountOp>(typeConverter, context);
+  patterns.insert<ConvertChannelDefaultOp>(typeConverter, context);
+  patterns.insert<ConvertChannelRankOp>(typeConverter, context);
+  patterns.insert<ConvertReduceScatterOp>(typeConverter, context);
 }
 
 void populateFlowToStreamConversionPatterns(MLIRContext *context,
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD
index c027a47..b6a902b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD
@@ -16,6 +16,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "collective_ops.mlir",
             "dispatch_ops.mlir",
             "executable_ops.mlir",
             "tensor_ops.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt
index 82ea148..c038d00 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "collective_ops.mlir"
     "dispatch_ops.mlir"
     "executable_ops.mlir"
     "tensor_ops.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/collective_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/collective_ops.mlir
new file mode 100644
index 0000000..9e800f6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/collective_ops.mlir
@@ -0,0 +1,68 @@
+// RUN: iree-opt --split-input-file --iree-stream-conversion %s | FileCheck %s
+
+// CHECK-LABEL: @channel_count
+func.func @channel_count() -> index {
+  // CHECK: [[CHANNEL:%.+]] = stream.channel.default : !stream.channel
+  // CHECK: [[COUNT:%.+]] = stream.channel.count [[CHANNEL]] : index
+  // CHECK: return [[COUNT]] : index
+  %channel_default = flow.channel.default : !flow.channel
+  %count = flow.channel.count %channel_default : index
+  return %count : index
+}
+
+// -----
+
+// CHECK-LABEL: @channel_rank
+func.func @channel_rank() -> index {
+  // CHECK: [[CHANNEL:%.+]] = stream.channel.default : !stream.channel
+  // CHECK: [[RANK:%.+]] = stream.channel.rank [[CHANNEL]] : index
+  // CHECK: return [[RANK]] : index
+  %channel_default = flow.channel.default : !flow.channel
+  %rank = flow.channel.rank %channel_default : index
+  return %rank : index
+}
+
+// -----
+
+// CHECK-LABEL: @all_reduce_sum
+func.func @all_reduce_sum(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+  // CHECK: stream.channel.default
+  // CHECK: stream.tensor.empty : tensor<2304xf32>
+  // CHECK: stream.async.collective<all_reduce with sum : f32>
+  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2304xf32>
+  %channel_default = flow.channel.default : !flow.channel
+  %1 = flow.tensor.empty : tensor<2304xf32>
+  %2 = flow.collective.all_reduce sum, f32, %1, %0, %channel_default : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> tensor<2304xf32>
+  %3 = hal.tensor.export %2 : tensor<2304xf32> -> !hal.buffer_view
+  return %3 : !hal.buffer_view
+}
+
+// -----
+
+// CHECK-LABEL: @allgather
+func.func @allgather(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+  // CHECK: stream.channel.default
+  // CHECK: stream.tensor.empty : tensor<1024xf32>
+  // CHECK: stream.async.collective<all_gather : f32>
+  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<512xf32>
+  %channel_default = flow.channel.default : !flow.channel
+  %1 = flow.tensor.empty : tensor<1024xf32>
+  %2 = flow.collective.all_gather f32, %1, %0, %channel_default : (tensor<1024xf32>, tensor<512xf32>, !flow.channel) -> tensor<1024xf32>
+  %3 = hal.tensor.export %2 : tensor<1024xf32> -> !hal.buffer_view
+  return %3 : !hal.buffer_view
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_scatter
+func.func @reduce_scatter(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+  // CHECK: stream.channel.default
+  // CHECK: stream.tensor.empty : tensor<2x2xf32>
+  // CHECK: stream.async.collective<reduce_scatter with sum : f32>
+  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<4x2xf32>
+  %channel_default = flow.channel.default : !flow.channel
+  %1 = flow.tensor.empty : tensor<2x2xf32>
+  %2 = flow.collective.reduce_scatter sum, f32, %1, %0, %channel_default : (tensor<2x2xf32>, tensor<4x2xf32>, !flow.channel) -> tensor<2x2xf32>
+  %3 = hal.tensor.export %2 : tensor<2x2xf32> -> !hal.buffer_view
+  return %3 : !hal.buffer_view
+}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
index baa36a5..0b0acc8 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
 #include "iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h"
 #include "iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h"
 #include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
@@ -202,8 +203,13 @@
 
     // Allow unknown types to pass through; these come from custom dialects that
     // may be mixed into the IR we are converting.
-    typeConverter.addConversion(
-        [](Type type) { return !type.isa<TensorType>() ? type : Type{}; });
+    typeConverter.addConversion([=](Type type) -> Type {
+      // convert flow.channel into stream.channel
+      if (type.isa<IREE::Flow::ChannelType>())
+        return IREE::Stream::ChannelType::get(context);
+
+      return !type.isa<TensorType>() ? type : Type{};
+    });
 
     // Disallow tensor dialects; the goal here is to remove all tensors and
     // turn them into stream resource ops.
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/BUILD b/compiler/src/iree/compiler/InputConversion/MHLO/BUILD
index 4127a84..2de86f4 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/BUILD
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/BUILD
@@ -47,6 +47,7 @@
     name = "MHLO",
     srcs = [
         "BroadcastingToLinalgPatterns.cpp",
+        "ConvertCollectiveOps.cpp",
         "ConvertComplexToReal.cpp",
         "ConvertMHLOToFlow.cpp",
         "ConvertMHLOToFlow.h",
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/MHLO/CMakeLists.txt
index c713b78..26d37ce 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/CMakeLists.txt
@@ -41,6 +41,7 @@
     "Passes.h"
   SRCS
     "BroadcastingToLinalgPatterns.cpp"
+    "ConvertCollectiveOps.cpp"
     "ConvertComplexToReal.cpp"
     "ConvertMHLOToFlow.cpp"
     "ConvertMHLOToFlow.h"
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
new file mode 100644
index 0000000..94607d9
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
@@ -0,0 +1,449 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <iree/compiler/Dialect/Flow/IR/FlowTypes.h>
+
+#include <optional>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
+#include "iree/compiler/InputConversion/MHLO/Passes.h"
+#include "iree/compiler/InputConversion/MHLO/Rewriters.h"
+#include "mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace MHLO {
+
+// Work in progress. The implementation is planned as several stages.
+//
+// For the first stage, a few simplifications are made to support simple models.
+//
+//   1. Single stream with deterministic order of execution
+//   2. Single replica group for all collective ops
+//   3. Only replicas without partition_id used
+//
+// These allow us to use a default channel for all communication, with a 1:1
+// mapping from replica IDs to communication ranks. The use_global_device_ids
+// attribute is always set in this case.
+//
+// The next stage is to support multiple replica groups. This requires creating
+// a channel for a subset of the processes, which in turn requires additional
+// communication among the group. A possible strategy is to have the root
+// process of the group (its first rank) create a channel and have the other
+// processes query the channel info from the root process. A key-value store
+// using gRPC might be a good solution.
+//
+// Supporting partition_id comes next. This includes support for the various
+// mode combinations of cross-replica and cross-partition communication. See
+// the StableHLO specification for more details about the different modes.
+
+namespace {
+
+static std::optional<IREE::Flow::CollectiveElementType>
+convertToFlowCollectiveElementType(Type type) {
+  if (type.isF32()) {
+    return IREE::Flow::CollectiveElementType::Float32;
+  }
+
+  if (type.isInteger(32)) {
+    if (type.isSignedInteger()) {
+      return IREE::Flow::CollectiveElementType::Sint32;
+    } else {
+      return IREE::Flow::CollectiveElementType::Uint32;
+    }
+  }
+
+  if (type.isF16()) {
+    return IREE::Flow::CollectiveElementType::Float16;
+  }
+
+  if (type.isInteger(8)) {
+    if (type.isSignedInteger()) {
+      return IREE::Flow::CollectiveElementType::Sint8;
+    } else {
+      return IREE::Flow::CollectiveElementType::Uint8;
+    }
+  }
+
+  if (type.isInteger(16)) {
+    if (type.isSignedInteger()) {
+      return IREE::Flow::CollectiveElementType::Sint16;
+    } else {
+      return IREE::Flow::CollectiveElementType::Uint16;
+    }
+  }
+
+  if (type.isBF16()) {
+    return IREE::Flow::CollectiveElementType::BFloat16;
+  }
+
+  if (type.isF64()) {
+    return IREE::Flow::CollectiveElementType::Float64;
+  }
+
+  if (type.isInteger(64)) {
+    if (type.isSignedInteger()) {
+      return IREE::Flow::CollectiveElementType::Sint64;
+    } else {
+      return IREE::Flow::CollectiveElementType::Uint64;
+    }
+  }
+
+  return std::nullopt;
+}
+
+static std::optional<IREE::Flow::CollectiveReductionOp>
+convertToFlowCollectiveReductionOp(const Operation &op) {
+  if (isa<mhlo::AddOp>(op)) {
+    return IREE::Flow::CollectiveReductionOp::ReductionSum;
+  } else if (isa<mhlo::MulOp>(op)) {
+    return IREE::Flow::CollectiveReductionOp::ReductionProduct;
+  } else if (isa<mhlo::MinOp>(op)) {
+    return IREE::Flow::CollectiveReductionOp::ReductionMinimum;
+  } else if (isa<mhlo::MaxOp>(op)) {
+    return IREE::Flow::CollectiveReductionOp::ReductionMaximum;
+  } else {
+    // TODO: we may be able to detect an average operation and convert it
+    // into IREE::Flow::CollectiveReductionOp::ReductionAverage.
+    return std::nullopt;
+  }
+}
+
+static IREE::Flow::CollectiveElementTypeAttr getCollectiveElementTypeAttr(
+    MLIRContext *context, RankedTensorType type) {
+  std::optional<IREE::Flow::CollectiveElementType> collectiveElemType =
+      convertToFlowCollectiveElementType(type.getElementType());
+  if (!collectiveElemType) {
+    return IREE::Flow::CollectiveElementTypeAttr();
+  }
+  return IREE::Flow::CollectiveElementTypeAttr::get(context,
+                                                    *collectiveElemType);
+}
+
+}  // namespace
+
+/// Converts mhlo.replica_id to flow.channel.default + flow.channel.rank.
+/// TODO(okkwon): this assumes that there is no partition so that there is a 1:1
+/// mapping between the replica ID and the process ID.
+struct ReplicaIdOpConversion : public OpConversionPattern<mhlo::ReplicaIdOp> {
+  using OpConversionPattern<mhlo::ReplicaIdOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ReplicaIdOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(loc);
+    auto rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+    auto resultType = op.getType().cast<RankedTensorType>();  // tensor<ui32>
+    auto elemType = resultType.getElementType();
+    // index -> ui32
+    auto rankElem = rewriter.create<arith::IndexCastUIOp>(loc, elemType, rank);
+    // tensor<ui32>
+    auto rankTensor = rewriter.create<tensor::FromElementsOp>(
+        loc, resultType, rankElem.getResult());
+    rewriter.replaceOp(op, rankTensor.getResult());
+    return success();
+  }
+};
+
+struct AllGatherOpConversion : public OpConversionPattern<mhlo::AllGatherOp> {
+  using OpConversionPattern<mhlo::AllGatherOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::AllGatherOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    if (!op.getUseGlobalDeviceIds()) {
+      return rewriter.notifyMatchFailure(op, "must use global device IDs");
+    }
+
+    // Check there is only one group in the replica_groups
+    ShapedType replicaGroupType = op.getReplicaGroups().getType();
+    if (replicaGroupType.getRank() != 2 ||
+        replicaGroupType.getDimSize(0) != 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "must have a single replica group");
+    }
+
+    // Currently only the default channel is used.
+
+    // Create a default channel.
+    auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(loc);
+
+    // Get the collective element type attribute.
+    auto resultType = op.getResult().getType().cast<RankedTensorType>();
+    IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
+        getCollectiveElementTypeAttr(op.getContext(), resultType);
+    if (!elementTypeAttr) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported element type for collective op");
+    }
+
+    // When all_gather_dim != 0, we need to transpose between 0 and
+    // all_gather_dim before and after the flow allgather op.
+    uint64_t allGatherDim = op.getAllGatherDim();
+    auto inputType = op.getOperand().getType().cast<RankedTensorType>();
+    SmallVector<int64_t> gatherInputShape(inputType.getShape());
+    Value gatherInput = op.getOperand();
+    DenseIntElementsAttr permutationAttr;
+    SmallVector<int64_t> gatherResultShape(resultType.getShape());
+
+    if (allGatherDim != 0) {
+      SmallVector<int64_t> permutation =
+          llvm::to_vector(llvm::seq<int64_t>(0, gatherResultShape.size()));
+      std::swap(permutation[0], permutation[allGatherDim]);
+      permutationAttr = rewriter.getI64VectorAttr(permutation);
+      std::swap(gatherInputShape[0], gatherInputShape[allGatherDim]);
+      std::swap(gatherResultShape[0], gatherResultShape[allGatherDim]);
+      // Transpose the input.
+      gatherInput = rewriter
+                        .create<mhlo::TransposeOp>(
+                            loc,
+                            RankedTensorType::get(gatherInputShape,
+                                                  resultType.getElementType()),
+                            gatherInput, permutationAttr)
+                        .getResult();
+    }
+
+    // Create an empty tensor for the result.
+    Value target = rewriter.create<tensor::EmptyOp>(
+        loc, gatherResultShape, resultType.getElementType());
+    Value gatherResult =
+        rewriter
+            .create<IREE::Flow::CollectiveAllGatherOp>(
+                op.getLoc(), elementTypeAttr, target, gatherInput, channel)
+            .getResult();
+
+    if (allGatherDim != 0) {
+      gatherResult = rewriter
+                         .create<mhlo::TransposeOp>(
+                             loc, resultType, gatherResult, permutationAttr)
+                         .getResult();
+    }
+
+    rewriter.replaceOp(op, gatherResult);
+    return success();
+  }
+};
+
+struct AllReduceOpConversion : public OpConversionPattern<mhlo::AllReduceOp> {
+  using OpConversionPattern<mhlo::AllReduceOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::AllReduceOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    if (!op.getUseGlobalDeviceIds()) {
+      return rewriter.notifyMatchFailure(op, "must use global device IDs");
+    }
+
+    // Check there is only one group in the replica_groups.
+    ShapedType replicaGroupType = op.getReplicaGroups().getType();
+    if (replicaGroupType.getRank() != 2 ||
+        replicaGroupType.getDimSize(0) != 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "must have a single replica group");
+    }
+
+    // Only a single elementwise op is supported.
+    Block &block = op.getComputation().front();
+
+    if (block.empty() || llvm::hasSingleElement(block) ||
+        std::next(block.begin(), 2) != block.end()) {
+      return rewriter.notifyMatchFailure(op, "must have two ops in the block");
+    }
+
+    if (block.getNumArguments() != 2) {
+      return rewriter.notifyMatchFailure(op, "must have two block args");
+    }
+
+    Operation &op1 = block.front();
+    Operation &op2 = *(++block.begin());
+
+    if (op1.getNumResults() != 1 ||
+        !op1.hasTrait<::mlir::OpTrait::Elementwise>()) {
+      return rewriter.notifyMatchFailure(op, "must have elementwise trait");
+    }
+
+    // Convert mhlo reduction op into flow reduction op.
+    std::optional<IREE::Flow::CollectiveReductionOp> redOp =
+        convertToFlowCollectiveReductionOp(op1);
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(op, "unsupported operation.");
+    }
+
+    if (!op2.mightHaveTrait<OpTrait::IsTerminator>()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "the second op must be a terminator");
+    }
+    // Currently only the default channel is used.
+
+    // Create a default channel.
+    auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(loc);
+
+    // Convert mhlo reduction op into flow reduction op.
+    auto reductionOpAttr =
+        IREE::Flow::CollectiveReductionOpAttr::get(op.getContext(), *redOp);
+
+    auto inputType = op.getOperand().getType().cast<RankedTensorType>();
+
+    // Get the collective element type attribute.
+    IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
+        getCollectiveElementTypeAttr(op.getContext(), inputType);
+    if (!elementTypeAttr) {
+      return rewriter.notifyMatchFailure(op, "unsupported input type");
+    }
+
+    // Create an empty tensor for the result.
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    Value target = rewriter.create<tensor::EmptyOp>(loc, inputShape,
+                                                    inputType.getElementType());
+    auto allReduceOp = rewriter.create<IREE::Flow::CollectiveAllReduceOp>(
+        op.getLoc(), reductionOpAttr, elementTypeAttr, target, op.getOperand(),
+        channel);
+    rewriter.replaceOp(op, allReduceOp.getResult());
+    return success();
+  }
+};
+
+struct ReduceScatterOpConversion
+    : public OpConversionPattern<mhlo::ReduceScatterOp> {
+  using OpConversionPattern<mhlo::ReduceScatterOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ReduceScatterOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    if (!op.getUseGlobalDeviceIds()) {
+      return rewriter.notifyMatchFailure(op, "must use global device IDs");
+    }
+
+    // Check if there is only one group in the replica_groups.
+    ShapedType replicaGroupType = op.getReplicaGroups().getType();
+    if (replicaGroupType.getRank() != 2 ||
+        replicaGroupType.getDimSize(0) != 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "must have a single replica group");
+    }
+
+    // Only a single elementwise op is supported.
+    Block &block = op.getComputation().front();
+
+    if (block.empty() || llvm::hasSingleElement(block) ||
+        std::next(block.begin(), 2) != block.end()) {
+      return rewriter.notifyMatchFailure(op, "must have two ops in the block");
+    }
+
+    if (block.getNumArguments() != 2) {
+      return rewriter.notifyMatchFailure(op, "must have two block args");
+    }
+
+    Operation &op1 = block.front();
+    Operation &op2 = *(++block.begin());
+
+    if (op1.getNumResults() != 1 ||
+        !op1.hasTrait<::mlir::OpTrait::Elementwise>()) {
+      return rewriter.notifyMatchFailure(op, "must have elementwise trait");
+    }
+
+    // Convert mhlo reduction op into flow reduction op.
+    std::optional<IREE::Flow::CollectiveReductionOp> redOp =
+        convertToFlowCollectiveReductionOp(op1);
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(op, "unsupported operation.");
+    }
+
+    if (!op2.mightHaveTrait<OpTrait::IsTerminator>()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "the second op must be a terminator");
+    }
+
+    // Convert mhlo reduction op into flow reduction op.
+    auto reductionOpAttr =
+        IREE::Flow::CollectiveReductionOpAttr::get(op.getContext(), *redOp);
+
+    // Currently only the default channel is used.
+
+    // Create a default channel.
+    auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(loc);
+
+    // Get the collective element type attribute.
+    auto resultType = op.getResult().getType().cast<RankedTensorType>();
+    IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
+        getCollectiveElementTypeAttr(op.getContext(), resultType);
+    if (!elementTypeAttr) {
+      return rewriter.notifyMatchFailure(op, "unsupported input type");
+    }
+
+    // When scatter_dimension != 0, we need to transpose between 0 and
+    // scatter_dimension before and after the flow reduce_scatter op.
+    uint64_t scatterDim = op.getScatterDimension();
+    auto inputType = op.getOperand().getType().cast<RankedTensorType>();
+    SmallVector<int64_t> reduceInputShape(inputType.getShape());
+    Value reduceInput = op.getOperand();
+    DenseIntElementsAttr permutationAttr;
+
+    SmallVector<int64_t> scatterResultShape(resultType.getShape());
+    auto elemType = resultType.getElementType();
+
+    if (scatterDim != 0) {
+      SmallVector<int64_t> permutation =
+          llvm::to_vector(llvm::seq<int64_t>(0, scatterResultShape.size()));
+      std::swap(permutation[0], permutation[scatterDim]);
+      permutationAttr = rewriter.getI64VectorAttr(permutation);
+      std::swap(reduceInputShape[0], reduceInputShape[scatterDim]);
+      std::swap(scatterResultShape[0], scatterResultShape[scatterDim]);
+      // Transpose the input.
+      reduceInput =
+          rewriter
+              .create<mhlo::TransposeOp>(
+                  loc, RankedTensorType::get(reduceInputShape, elemType),
+                  reduceInput, permutationAttr)
+              .getResult();
+    }
+
+    // Create an empty tensor for the result.
+    Value target = rewriter.create<tensor::EmptyOp>(
+        loc, scatterResultShape, resultType.getElementType());
+    Value scatterResult = rewriter
+                              .create<IREE::Flow::CollectiveReduceScatterOp>(
+                                  op.getLoc(), reductionOpAttr, elementTypeAttr,
+                                  target, reduceInput, channel)
+                              .getResult();
+
+    if (scatterDim != 0) {
+      scatterResult = rewriter
+                          .create<mhlo::TransposeOp>(
+                              loc, resultType, scatterResult, permutationAttr)
+                          .getResult();
+    }
+
+    rewriter.replaceOp(op, scatterResult);
+    return success();
+  }
+};
+
+void populateMHLOCollectiveOpsConversionPatterns(MLIRContext *context,
+                                                 TypeConverter &typeConverter,
+                                                 RewritePatternSet &patterns) {
+  patterns.insert<AllGatherOpConversion>(typeConverter, context);
+  patterns.insert<AllReduceOpConversion>(typeConverter, context);
+  patterns.insert<ReduceScatterOpConversion>(typeConverter, context);
+  patterns.insert<ReplicaIdOpConversion>(typeConverter, context);
+}
+
+}  // namespace MHLO
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index 03aa781..ea51f00 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -356,6 +356,7 @@
     // TODO: Collapse/rework all of these patterns once the consolidation
     // lands. There is little reason to have these so spread out.
     populateMHLOToFlowPatterns(context, patterns);
+
     chlo::populateDecomposeChloPatterns(context, &patterns);
     populateMHLOBroadcastingToLinalgPatterns(context, *typeConverter, patterns);
     mhlo::populateScalarHloToArithmeticConversionPatterns(
@@ -365,6 +366,8 @@
                                                     patterns);
     populateMHLOComplexToRealPatterns(context, *typeConverter, patterns);
 
+    populateMHLOCollectiveOpsConversionPatterns(context, *typeConverter,
+                                                patterns);
     // TODO(*): expose patterns that do this much better from
     // iree/compiler/Dialect/Util/Transforms/ConvertPrimitiveType.cpp
 
@@ -386,8 +389,13 @@
         context);
     patterns.insert<GenericTypeConvert>(
         ml_program::GlobalStoreOp::getOperationName(), *typeConverter, context);
-
+    // This is needed when converting mhlo::ReplicaIdOp.
+    patterns.insert<GenericTypeConvert>(
+        tensor::FromElementsOp::getOperationName(), *typeConverter, context);
+    patterns.insert<GenericTypeConvert>(
+        arith::IndexCastUIOp::getOperationName(), *typeConverter, context);
     ConversionTarget target(getContext());
+
     auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); };
     auto isLegallyTypedOp = [&](Operation *op) -> bool {
       for (Type type : op->getResultTypes()) {
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/Passes.td b/compiler/src/iree/compiler/InputConversion/MHLO/Passes.td
index 1855004..984b92a 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/Passes.td
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/Passes.td
@@ -55,4 +55,5 @@
   let constructor = "mlir::iree_compiler::MHLO::createTestMHLOConvertComplexToRealPass()";
 }
 
+
 #endif // IREE_COMPILER_INPUTCONVERSION_MHLO_PASSES
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/Rewriters.h b/compiler/src/iree/compiler/InputConversion/MHLO/Rewriters.h
index ac996f0..e839f38 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/Rewriters.h
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/Rewriters.h
@@ -27,6 +27,11 @@
                                               TypeConverter &typeConverter,
                                               RewritePatternSet &patterns);
 
+/// Populates patterns to convert MHLO collective ops to Flow collective ops.
+void populateMHLOCollectiveOpsConversionPatterns(MLIRContext *context,
+                                                 TypeConverter &typeConverter,
+                                                 RewritePatternSet &patterns);
+
 /// Populates patterns to convert MHLO/CHLO arithmetic on complex tensors to
 /// equivalent HLO level real arithmetic.
 void populateMHLOComplexToRealPatterns(MLIRContext *context,
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD b/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
index 99e5b4f..5ac112d 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
@@ -20,6 +20,7 @@
         [
             "broadcasting.mlir",
             "convert_mhlo_to_linalg_ext.mlir",
+            "convert_collective_ops.mlir",
             "convert_complex_to_real.mlir",
             "convert_structural_types.mlir",
             "dynamic_shape.mlir",
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
index 8404a20..8b9aa57 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "broadcasting.mlir"
+    "convert_collective_ops.mlir"
     "convert_complex_to_real.mlir"
     "convert_mhlo_to_linalg_ext.mlir"
     "convert_structural_types.mlir"
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
new file mode 100644
index 0000000..a93557d
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
@@ -0,0 +1,167 @@
+// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-on-tensors --canonicalize --cse %s | FileCheck %s
+
+// CHECK-LABEL: @replica_id
+func.func @replica_id() -> tensor<ui32> {
+  // CHECK-DAG: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK-DAG: [[RANK:%.+]] = flow.channel.rank [[CHANNEL]] : index
+  // CHECK-DAG: [[CAST:%.+]] = arith.index_castui [[RANK]] : index to i32
+  // CHECK-DAG: [[TENSOR:%.+]] = tensor.from_elements [[CAST]] : tensor<i32>
+  // CHECK-DAG: return [[TENSOR]] : tensor<i32>
+  %id = mhlo.replica_id : tensor<ui32>
+  return %id : tensor<ui32>
+}
+
+// -----
+
+// CHECK-LABEL: @all_reduce_sum
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>)
+func.func @all_reduce_sum(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+  // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32>
+  // CHECK: [[ALLREDUCE:%.+]] = flow.collective.all_reduce sum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32>
+  // CHECK: return [[ALLREDUCE]] : tensor<2304xf32>
+  %out = "mhlo.all_reduce"(%input) ({
+    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+      %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+      mhlo.return %sum : tensor<f32>
+    }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+        replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+        use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>
+  return %out : tensor<2304xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @all_reduce_product
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>)
+func.func @all_reduce_product(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+  // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32>
+  // CHECK: [[OP:%.+]] = flow.collective.all_reduce product, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32>
+  // CHECK: return [[OP]] : tensor<2304xf32>
+  %out = "mhlo.all_reduce"(%input) ({
+    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+      %mul = mhlo.multiply %arg0, %arg1 : tensor<f32>
+      mhlo.return %mul : tensor<f32>
+    }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+        replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+        use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>
+  return %out : tensor<2304xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @all_reduce_minimum
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>)
+func.func @all_reduce_minimum(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+  // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32>
+  // CHECK: [[OP:%.+]] = flow.collective.all_reduce minimum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32>
+  // CHECK: return [[OP]] : tensor<2304xf32>
+  %out = "mhlo.all_reduce"(%input) ({
+    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+      %min = mhlo.minimum %arg0, %arg1 : tensor<f32>
+      mhlo.return %min : tensor<f32>
+    }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+        replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+        use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>
+  return %out : tensor<2304xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @all_reduce_maximum
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>)
+func.func @all_reduce_maximum(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+  // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32>
+  // CHECK: [[OP:%.+]] = flow.collective.all_reduce maximum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32>
+  // CHECK: return [[OP]] : tensor<2304xf32>
+  %out = "mhlo.all_reduce"(%input) ({
+    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+      %max = mhlo.maximum %arg0, %arg1 : tensor<f32>
+      mhlo.return %max : tensor<f32>
+    }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+        replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+        use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>
+  return %out : tensor<2304xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @all_gather_dim_0
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<512xf32>) -> tensor<1024xf32>
+func.func @all_gather_dim_0(%input : tensor<512xf32>) -> tensor<1024xf32> {
+  // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<1024xf32>
+  // CHECK: [[OP:%.+]] = flow.collective.all_gather f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<1024xf32>, tensor<512xf32>, !flow.channel) -> [[EMPTY]] as tensor<1024xf32>
+  // CHECK: return [[OP]] : tensor<1024xf32>
+  %out = "mhlo.all_gather"(%input) {all_gather_dim = 0 : i64,
+     channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+     replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+     use_global_device_ids} : (tensor<512xf32>) -> tensor<1024xf32>
+  return %out : tensor<1024xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @all_gather_dim_1
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<2x2xf32>) -> tensor<2x4xf32>
+func.func @all_gather_dim_1(%input : tensor<2x2xf32>) -> tensor<2x4xf32> {
+  // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK: tensor.empty() : tensor<2x2xf32>
+  // CHECK: [[TRANSPOSE_ARG:%.+]] = linalg.generic
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<4x2xf32>
+  // CHECK: [[OP:%.+]] = flow.collective.all_gather f32, [[EMPTY]], [[TRANSPOSE_ARG]], [[CHANNEL]] : (tensor<4x2xf32>, tensor<2x2xf32>, !flow.channel) -> [[EMPTY]] as tensor<4x2xf32>
+  // CHECK: tensor.empty() : tensor<2x4xf32>
+  // CHECK: [[TRANSPOSE_OUT:%.+]] = linalg.generic
+  // CHECK: return [[TRANSPOSE_OUT]] : tensor<2x4xf32>
+  %out = "mhlo.all_gather"(%input) {all_gather_dim = 1 : i64,
+     channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+     replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+     use_global_device_ids} : (tensor<2x2xf32>) -> tensor<2x4xf32>
+  return %out : tensor<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_scatter_dim_0
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<4x2xf32>) -> tensor<2x2xf32>
+func.func @reduce_scatter_dim_0(%input : tensor<4x2xf32>) -> tensor<2x2xf32> {
+  // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2x2xf32>
+  // CHECK: [[OP:%.+]] = flow.collective.reduce_scatter sum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2x2xf32>, tensor<4x2xf32>, !flow.channel) -> [[EMPTY]] as tensor<2x2xf32>
+  // CHECK: return [[OP]] : tensor<2x2xf32>
+  %out = "mhlo.reduce_scatter"(%input) ({
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+    %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+    mhlo.return %sum : tensor<f32>
+  }) {scatter_dimension = 0 : i64,
+      channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+      use_global_device_ids} : (tensor<4x2xf32>) -> tensor<2x2xf32>
+  return %out : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_scatter_dim_1
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<2x4xf32>) -> tensor<2x2xf32>
+func.func @reduce_scatter_dim_1(%input : tensor<2x4xf32>) -> tensor<2x2xf32> {
+  // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+  // CHECK: tensor.empty() : tensor<4x2xf32>
+  // CHECK: [[TRANSPOSE_ARG:%.+]] = linalg.generic
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2x2xf32>
+  // CHECK: [[OP:%.+]] = flow.collective.reduce_scatter sum, f32, [[EMPTY]], [[TRANSPOSE_ARG]], [[CHANNEL]] : (tensor<2x2xf32>, tensor<4x2xf32>, !flow.channel) -> [[EMPTY]] as tensor<2x2xf32>
+  // CHECK: [[TRANSPOSE_OUT:%.+]] = linalg.generic
+  // CHECK: return [[TRANSPOSE_OUT]] : tensor<2x2xf32>
+  %out = "mhlo.reduce_scatter"(%input) ({
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+    %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+    mhlo.return %sum : tensor<f32>
+  }) {scatter_dimension = 1 : i64,
+      channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+      use_global_device_ids} : (tensor<2x4xf32>) -> tensor<2x2xf32>
+  return %out : tensor<2x2xf32>
+}
diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h
index c88700e..6e5fd65 100644
--- a/runtime/src/iree/hal/command_buffer.h
+++ b/runtime/src/iree/hal/command_buffer.h
@@ -296,8 +296,10 @@
 
 // Specifies the reduction operator of a collective reduction operation.
 enum iree_hal_collective_reduction_e {
+  // Indicates that no reduction operation is specified.
+  IREE_HAL_COLLECTIVE_REDUCTION_NONE = 0,
   // Specifies that the reduction operation computes a sum (addition).
-  IREE_HAL_COLLECTIVE_REDUCTION_SUM = 0,
+  IREE_HAL_COLLECTIVE_REDUCTION_SUM = 1,
   // Specifies that the reduction operation computes a product (multiplication).
   IREE_HAL_COLLECTIVE_REDUCTION_PRODUCT,
   // Specifies that the reduction operation computes a minimum (min).
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index ddb3c8b..ef0a82c 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -302,7 +302,8 @@
   // We could multiplex channels but it'd be better to surface that to the
   // compiler so that it can emit the right rank math.
   int requested_count = iree_math_count_ones_u64(queue_affinity);
-  if (requested_count != 1) {
+  // TODO(#12206): properly assign affinity in the compiler.
+  if (requested_count != 64 && requested_count != 1) {
     return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                             "exactly one participant is allowed in a "
                             "channel but %d were specified",