Support `num_replicas` and `num_partitions` (#13288)

Support `num_replicas` and `num_partitions` through the cross_replica and
cross_replica_and_partition modes for all_reduce, all_gather, and
reduce_scatter, as well as for replica_id and partition_id.

The channel ID and use_global_device_ids together determine the
communication mode over the 2D grid of replicas and partitions.

See
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective-ops
for more information.

Note that StableHLO still uses the `mhlo.num_replicas` and
`mhlo.num_partitions` attributes to embed this information in the module.
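
In summary, createChannelWithGroupInfo selects the communication mode as
follows:

  channel_id <= 0 && !use_global_device_ids  ->  cross_replica
  channel_id >  0 && !use_global_device_ids  ->  cross_replica_and_partition
  channel_id >  0 &&  use_global_device_ids  ->  flattened_ids

For example, with num_partitions = 2, cross_replica expands
replica_groups = [[0, 1]] into the flattened rank groups [[0, 2], [1, 3]]
(replica r in partition p has flattened ID r * num_partitions + p), while
cross_replica_and_partition expands it into the single group [[0, 2, 1, 3]].
replica_id lowers to flow.channel.rank / num_partitions, and partition_id
lowers to flow.channel.rank % num_partitions.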
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
index 7e71f68..a8604cd 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
@@ -145,10 +145,6 @@
       return rewriter.notifyMatchFailure(
           op, "must not set use_global_device_ids when channel_id <= 0");
     }
-  } else {
-    if (!op.getUseGlobalDeviceIds()) {
-      return rewriter.notifyMatchFailure(op, "must set use_global_device_ids");
-    }
   }
 
   return success();
@@ -212,40 +208,116 @@
   return std::make_pair(color, key);
 }
 
+static DenseIntElementsAttr convertToRankGroupsByCrossReplica(
+    DenseIntElementsAttr replicaGroups, int32_t numPartitions,
+    OpBuilder &builder) {
+  if (numPartitions < 1) {
+    // Treat as a single partition.
+    return replicaGroups;
+  }
+
+  auto groupsType = replicaGroups.getType().cast<RankedTensorType>();
+  assert(groupsType.getRank() == 2);
+  int rows = groupsType.getShape()[0];
+  int cols = groupsType.getShape()[1];
+  auto values = replicaGroups.getValues<int64_t>();
+  SmallVector<Attribute> newValues;
+
+  // The number of groups is (rows * numPartitions).
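+  // For example, replica_groups = [[0, 1]] with numPartitions = 2 expands to
+  // [[0, 2], [1, 3]]: replica r in partition p has flattened ID
+  // (r * numPartitions + p), and each partition forms its own group.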
+  for (int i = 0; i < rows; ++i) {
+    for (int p = 0; p < numPartitions; ++p) {
+      // Each group starts here. The group size is the same as the column size.
+      for (int j = 0; j < cols; ++j) {
+        const int index = i * cols + j;
+        const int64_t replicaId = values[index];
+        const int64_t value =
+            (replicaId == -1) ? -1 : replicaId * numPartitions + p;
+        newValues.push_back(builder.getI64IntegerAttr(value));
+      }
+    }
+  }
+
+  auto type =
+      RankedTensorType::get({rows * numPartitions, cols}, builder.getI64Type());
+  return DenseIntElementsAttr::get(type, newValues);
+}
+
+static DenseIntElementsAttr convertToRankGroupsByCrossReplicaAndPartition(
+    DenseIntElementsAttr replicaGroups, int32_t numPartitions,
+    OpBuilder &builder) {
+  if (numPartitions < 1) {
+    // Treat as a single partition.
+    return replicaGroups;
+  }
+
+  auto groupsType = replicaGroups.getType().cast<RankedTensorType>();
+  assert(groupsType.getRank() == 2);
+  int rows = groupsType.getShape()[0];
+  int cols = groupsType.getShape()[1];
+  auto values = replicaGroups.getValues<int64_t>();
+  SmallVector<Attribute> newValues;
+
+  // The number of groups is the same as the number of rows.
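+  // For example, replica_groups = [[0, 1]] with numPartitions = 2 expands to
+  // [[0, 2, 1, 3]]: each row gathers all partitions of its replicas, with the
+  // partition index as the slowest-varying dimension.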
+  for (int i = 0; i < rows; ++i) {
+    // Each group starts here. The group size is (numPartitions * cols).
+    for (int p = 0; p < numPartitions; ++p) {
+      for (int j = 0; j < cols; ++j) {
+        const int index = i * cols + j;
+        const int64_t replicaId = values[index];
+        const int64_t value =
+            (replicaId == -1) ? -1 : replicaId * numPartitions + p;
+        newValues.push_back(builder.getI64IntegerAttr(value));
+      }
+    }
+  }
+  auto type =
+      RankedTensorType::get({rows, numPartitions * cols}, builder.getI64Type());
+  return DenseIntElementsAttr::get(type, newValues);
+}
+
 /// Creates a channel matching the given |channelHandleAttr| scoped to the
 /// requested group.
 static Value createChannelWithGroupInfo(
     Location loc, mhlo::ChannelHandleAttr channelHandleAttr,
+    int32_t numReplicas, int32_t numPartitions,
     DenseIntElementsAttr replicaGroups, bool useGlobalDeviceIds,
     OpBuilder &builder) {
+  // Set numPartitions to 1 if not set by the user.
+  if (numPartitions == -1) numPartitions = 1;
+
   // Base channel that may be split by the group info.
   Value baseChannel =
       builder.create<IREE::Flow::ChannelDefaultOp>(loc, /*group=*/StringAttr{});
 
-  // TODO(okkwon): Convert replica_groups into flattened IDs.
-  //
-  // Once mhlo exposes `num_replicas` and `num_partitions`,
-  // use the channel ID to determine the collective operation mode, such as
-  // cross_replica, cross_partition, cross_replic_and_partition, and
-  // flattend_ids. Currently, we only supports the flanttend_ids mode.
-  //
-  // int64_t channelId = 0;
-  // if (channelHandleAttr) {
-  //   channelId = channelHandleAttr.getHandle();
-  // }
-
   // No need to split if there is a single group.
   ShapedType replicaGroupType = replicaGroups.getType();
   assert(replicaGroupType.getRank() == 2);
-  if (replicaGroupType.getDimSize(0) == 1) {
+  if (numPartitions == 1 && replicaGroupType.getDimSize(0) == 1) {
     return baseChannel;
   }
 
+  // Convert replica_groups into flattened IDs.
+  DenseIntElementsAttr rankGroups;
+  int64_t channelId = channelHandleAttr ? channelHandleAttr.getHandle() : 0;
+  if (channelId <= 0) {
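+    // cross_replica mode: expand each replica group across all partitions.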
+    assert(!useGlobalDeviceIds);
+    rankGroups = convertToRankGroupsByCrossReplica(replicaGroups, numPartitions,
+                                                   builder);
+  } else {
+    if (useGlobalDeviceIds) {
+      // flattened_ids mode: the groups already contain flattened IDs.
+      rankGroups = replicaGroups;
+    } else {
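+      // cross_replica_and_partition mode: each group spans all partitions of
+      // its replicas.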
+      rankGroups = convertToRankGroupsByCrossReplicaAndPartition(
+          replicaGroups, numPartitions, builder);
+    }
+  }
+
   // Construct lookups for color and key split parameters.
   // Note that `replica_groups` can be interpreted in multiple ways based on the
   // other attributes.
   auto [color, key] =
-      makeSplitColorAndKey(loc, baseChannel, replicaGroups, builder);
+      makeSplitColorAndKey(loc, baseChannel, rankGroups, builder);
 
   // Split the channel. Note that this is an expensive operation.
   return builder.create<IREE::Flow::ChannelSplitOp>(loc, baseChannel, color,
@@ -268,11 +340,69 @@
       permutationAttr);
 }
 
+static int32_t getNumReplicas(ModuleOp moduleOp) {
+  if (!moduleOp) {
+    return -1;
+  }
+  if (auto numReplicasAttr =
+          moduleOp->getAttrOfType<IntegerAttr>("mhlo.num_replicas")) {
+    return numReplicasAttr.getInt();
+  } else {
+    return -1;
+  }
+}
+
+static int32_t getNumPartitions(ModuleOp moduleOp) {
+  if (!moduleOp) {
+    return -1;
+  }
+  if (auto numPartitionsAttr =
+          moduleOp->getAttrOfType<IntegerAttr>("mhlo.num_partitions")) {
+    return numPartitionsAttr.getInt();
+  } else {
+    return -1;
+  }
+}
+
 }  // namespace
 
-/// Converts mhlo.replica_id to flow.channel.default + flow.channel.rank.
-/// TODO(okkwon): this assumes that there is no partition so that there is a 1:1
-/// mapping between the replica ID and the process ID.
+/// Converts mhlo.partition_id to (flow.channel.rank % numPartitions)
+struct PartitionIdOpConversion
+    : public OpConversionPattern<mhlo::PartitionIdOp> {
+  using OpConversionPattern<mhlo::PartitionIdOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::PartitionIdOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    // PartitionId = rank % numPartitions
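+    // e.g., with numPartitions = 2, ranks 0..3 map to partitions 0, 1, 0, 1.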
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numPartitions = getNumPartitions(moduleOp);
+    Value value;
+    if (numPartitions <= 1) {
+      value = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    } else {
+      auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(
+          loc, /*group=*/StringAttr{});
+      Value rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+      auto cst =
+          rewriter.create<arith::ConstantIndexOp>(loc,
+                                                  /*value=*/numPartitions);
+      value = rewriter.create<arith::RemUIOp>(loc, rank, cst);
+    }
+    auto resultType = op.getType().cast<RankedTensorType>();  // tensor<ui32>
+    auto elemType = resultType.getElementType();
+    // index -> ui32
+    auto rankElem = rewriter.create<arith::IndexCastUIOp>(loc, elemType, value);
+    // tensor<ui32>
+    auto rankTensor = rewriter.create<tensor::FromElementsOp>(
+        loc, resultType, rankElem.getResult());
+    rewriter.replaceOp(op, rankTensor.getResult());
+    return success();
+  }
+};
+
+/// Converts mhlo.replica_id to floor_div(flow.channel.rank, numPartitions)
 struct ReplicaIdOpConversion : public OpConversionPattern<mhlo::ReplicaIdOp> {
   using OpConversionPattern<mhlo::ReplicaIdOp>::OpConversionPattern;
 
@@ -282,7 +412,17 @@
     auto loc = op.getLoc();
     auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(
         loc, /*group=*/StringAttr{});
-    auto rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+    Value rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+
+    // ReplicaId = floor_div(rank, numPartitions)
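+    // e.g., with numPartitions = 2, ranks 0..3 map to replicas 0, 0, 1, 1.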
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numPartitions = getNumPartitions(moduleOp);
+    // Only materialize the divisor when it is needed; numPartitions is -1
+    // when the attribute is unset.
+    if (numPartitions > 1) {
+      auto cst = rewriter.create<arith::ConstantIndexOp>(
+          loc, /*value=*/numPartitions);
+      rank = rewriter.create<arith::DivUIOp>(loc, rank, cst);
+    }
+
     auto resultType = op.getType().cast<RankedTensorType>();  // tensor<ui32>
     auto elemType = resultType.getElementType();
     // index -> ui32
@@ -307,10 +447,14 @@
 
     auto loc = op.getLoc();
 
-    // Get the channel used for communication.
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numReplicas = getNumReplicas(moduleOp);
+    int32_t numPartitions = getNumPartitions(moduleOp);
+
+    // Create a channel.
     Value channel = createChannelWithGroupInfo(
-        loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
-        op.getUseGlobalDeviceIds(), rewriter);
+        loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+        op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
 
     // Get the collective element type attribute.
     auto resultType = op.getResult().getType().cast<RankedTensorType>();
@@ -392,10 +536,14 @@
 
     auto loc = op.getLoc();
 
-    // Get the channel used for communication.
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numReplicas = getNumReplicas(moduleOp);
+    int32_t numPartitions = getNumPartitions(moduleOp);
+
+    // Create a channel.
     Value channel = createChannelWithGroupInfo(
-        loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
-        op.getUseGlobalDeviceIds(), rewriter);
+        loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+        op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
 
     // Convert mhlo reduction op into flow reduction op.
     auto reductionOpAttr =
@@ -591,10 +739,14 @@
 
     auto loc = op.getLoc();
 
-    // Get the channel used for communication.
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numReplicas = getNumReplicas(moduleOp);
+    int32_t numPartitions = getNumPartitions(moduleOp);
+
+    // Create a channel.
     Value channel = createChannelWithGroupInfo(
-        loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
-        op.getUseGlobalDeviceIds(), rewriter);
+        loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+        op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
 
     // Get the collective element type attribute.
     auto resultType = op.getResult().getType().cast<RankedTensorType>();
@@ -652,6 +804,7 @@
   patterns.insert<AllGatherOpConversion>(typeConverter, context);
   patterns.insert<AllReduceOpConversion>(typeConverter, context);
   patterns.insert<AllToAllOpConversion>(typeConverter, context);
+  patterns.insert<PartitionIdOpConversion>(typeConverter, context);
   patterns.insert<ReduceScatterOpConversion>(typeConverter, context);
   patterns.insert<ReplicaIdOpConversion>(typeConverter, context);
 }
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
index 2f2e482..e8ab948 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
@@ -14,6 +14,53 @@
 
 // -----
 
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+  // CHECK-LABEL: @replica_id_with_partitions
+  func.func @replica_id_with_partitions() -> tensor<ui32> {
+    // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+    // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index
+    // CHECK-DAG: %[[DIV2:.+]] = arith.divui %[[RANK]], %c2 : index
+    // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[DIV2]] : index to i32
+    // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor<i32>
+    // CHECK-DAG: %[[BITCAST:.+]] = tensor.bitcast %[[TENSOR]] : tensor<i32> to tensor<ui32>
+    // CHECK-DAG: return %[[BITCAST]] : tensor<ui32>
+    %id = mhlo.replica_id : tensor<ui32>
+    return %id : tensor<ui32>
+  }
+}
+
+// -----
+
+// Returns 0 since num_partitions is not set.
+
+// CHECK-LABEL: @partition_id
+func.func @partition_id() -> tensor<ui32> {
+  // CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : tensor<i32>
+  // CHECK-DAG: %[[BITCAST:.+]] = tensor.bitcast %[[CST0]] : tensor<i32> to tensor<ui32>
+  // CHECK-DAG: return %[[BITCAST]] : tensor<ui32>
+  %id = mhlo.partition_id : tensor<ui32>
+  return %id : tensor<ui32>
+}
+
+// -----
+
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+  // CHECK-LABEL: @partition_id_with_partitions
+  func.func @partition_id_with_partitions() -> tensor<ui32> {
+    // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+    // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index
+    // CHECK-DAG: %[[REM2:.+]] = arith.remui %[[RANK]], %c2 : index
+    // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[REM2]] : index to i32
+    // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor<i32>
+    // CHECK-DAG: %[[BITCAST:.+]] = tensor.bitcast %[[TENSOR]] : tensor<i32> to tensor<ui32>
+    // CHECK-DAG: return %[[BITCAST]] : tensor<ui32>
+    %id = mhlo.partition_id : tensor<ui32>
+    return %id : tensor<ui32>
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @all_reduce_sum
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>)
 func.func @all_reduce_sum(%input : tensor<2304xf32>) -> tensor<2304xf32> {
@@ -345,3 +392,81 @@
       replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x2xf32>) -> tensor<2x2xf32>
   return %out : tensor<2x2xf32>
 }
+
+// -----
+
+// flattened_ids: channel_id > 0 && use_global_device_ids = true
+module @jit_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 8 : i32 } {
+  // CHECK-LABEL: @flattened_ids
+  // CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>)
+  func.func @flattened_ids(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+    // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+    // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32>
+    // CHECK: [[ALLREDUCE:%.+]] = flow.collective.all_reduce sum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32>
+    // CHECK: return [[ALLREDUCE]] : tensor<2304xf32>
+    %out = "mhlo.all_reduce"(%input) ({
+      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+        mhlo.return %sum : tensor<f32>
+      }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+          replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+          use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>
+    return %out : tensor<2304xf32>
+  }
+}
+
+// -----
+
+// cross_replica: channel_id <= 0 && use_global_device_ids = false
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+  // CHECK-LABEL: @cross_replica
+  func.func @cross_replica(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+    // Cross replica should form groups (0,2,4,6),(1,3,5,7), where each number represents a cell below.
+    // +---+---+
+    // | 0 | 1 |
+    // | 2 | 3 |
+    // | 4 | 5 |
+    // | 6 | 7 |
+    // +---+---+
+    //                          rank:   0    1    2    3    4    5    6    7
+    // CHECK: util.switch index from [%c0, %c1, %c0, %c1, %c0, %c1, %c0, %c1] at %channel_rank else %c-1 : index
+    // CHECK: util.switch index from [%c0, %c0, %c1, %c1, %c2, %c2, %c3, %c3] at %channel_rank else %c-1 : index
+    %out = "mhlo.all_reduce"(%input) ({
+      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+        mhlo.return %sum : tensor<f32>
+      }) {channel_handle = #mhlo.channel_handle<handle = 0, type = 1>,
+          replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
+         } : (tensor<2304xf32>) -> tensor<2304xf32>
+    return %out : tensor<2304xf32>
+  }
+}
+
+// -----
+
+// cross_replica_and_partition: channel_id > 0 && use_global_device_ids = false
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+  // CHECK-LABEL: @cross_replica_and_partition
+  func.func @cross_replica_and_partition(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+    // cross_replica_and_partition should form groups (0,2,1,3),(4,6,5,7), where each number represents a cell below.
+    // Note that ranks are assigned within a partition first, e.g., ranks 0 and 1 are assigned to cells 0 and 2, respectively.
+    // +---+---+
+    // | 0   1 |
+    // | 2   3 |
+    // |---+---|
+    // | 4   5 |
+    // | 6   7 |
+    // +---+---+
+    //                          rank:   0    1    2    3    4    5    6    7
+    // CHECK: util.switch index from [%c0, %c0, %c0, %c0, %c1, %c1, %c1, %c1] at %channel_rank else %c-1 : index
+    // CHECK: util.switch index from [%c0, %c2, %c1, %c3, %c0, %c2, %c1, %c3] at %channel_rank else %c-1 : index
+    %out = "mhlo.all_reduce"(%input) ({
+      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+        mhlo.return %sum : tensor<f32>
+      }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+          replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>
+         } : (tensor<2304xf32>) -> tensor<2304xf32>
+    return %out : tensor<2304xf32>
+  }
+}
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
index f277519..a1474ae 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
@@ -139,8 +139,6 @@
       return rewriter.notifyMatchFailure(
           op, "must not set use_global_device_ids when channel_id <= 0");
     }
-  } else if (!op.getUseGlobalDeviceIds()) {
-    return rewriter.notifyMatchFailure(op, "must set use_global_device_ids");
   }
 
   return success();
@@ -204,46 +202,146 @@
   return std::make_pair(color, key);
 }
 
+static DenseIntElementsAttr convertToRankGroupsByCrossReplica(
+    DenseIntElementsAttr replicaGroups, int32_t numPartitions,
+    OpBuilder &builder) {
+  if (numPartitions < 1) {
+    // Treat as a single partition.
+    return replicaGroups;
+  }
+
+  auto groupsType = replicaGroups.getType().cast<RankedTensorType>();
+  assert(groupsType.getRank() == 2);
+  int rows = groupsType.getShape()[0];
+  int cols = groupsType.getShape()[1];
+  auto values = replicaGroups.getValues<int64_t>();
+  SmallVector<Attribute> newValues;
+
+  // The number of groups is (rows * numPartitions).
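+  // For example, replica_groups = [[0, 1]] with numPartitions = 2 expands to
+  // [[0, 2], [1, 3]]: replica r in partition p has flattened ID
+  // (r * numPartitions + p), and each partition forms its own group.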
+  for (int i = 0; i < rows; ++i) {
+    for (int p = 0; p < numPartitions; ++p) {
+      // Each group starts here. The group size is the same as the column size.
+      for (int j = 0; j < cols; ++j) {
+        const int index = i * cols + j;
+        const int64_t replicaId = values[index];
+        const int64_t value =
+            (replicaId == -1) ? -1 : replicaId * numPartitions + p;
+        newValues.push_back(builder.getI64IntegerAttr(value));
+      }
+    }
+  }
+
+  auto type =
+      RankedTensorType::get({rows * numPartitions, cols}, builder.getI64Type());
+  return DenseIntElementsAttr::get(type, newValues);
+}
+
+static DenseIntElementsAttr convertToRankGroupsByCrossReplicaAndPartition(
+    DenseIntElementsAttr replicaGroups, int32_t numPartitions,
+    OpBuilder &builder) {
+  if (numPartitions < 1) {
+    // Treat as a single partition.
+    return replicaGroups;
+  }
+
+  auto groupsType = replicaGroups.getType().cast<RankedTensorType>();
+  assert(groupsType.getRank() == 2);
+  int rows = groupsType.getShape()[0];
+  int cols = groupsType.getShape()[1];
+  auto values = replicaGroups.getValues<int64_t>();
+  SmallVector<Attribute> newValues;
+
+  // The number of groups is the same as the number of rows.
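+  // For example, replica_groups = [[0, 1]] with numPartitions = 2 expands to
+  // [[0, 2, 1, 3]]: each row gathers all partitions of its replicas, with the
+  // partition index as the slowest-varying dimension.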
+  for (int i = 0; i < rows; ++i) {
+    // Each group starts here. The group size is (numPartitions * cols).
+    for (int p = 0; p < numPartitions; ++p) {
+      for (int j = 0; j < cols; ++j) {
+        const int index = i * cols + j;
+        const int64_t replicaId = values[index];
+        const int64_t value =
+            (replicaId == -1) ? -1 : replicaId * numPartitions + p;
+        newValues.push_back(builder.getI64IntegerAttr(value));
+      }
+    }
+  }
+  auto type =
+      RankedTensorType::get({rows, numPartitions * cols}, builder.getI64Type());
+  return DenseIntElementsAttr::get(type, newValues);
+}
+
 /// Creates a channel matching the given |channelHandleAttr| scoped to the
 /// requested group.
 static Value createChannelWithGroupInfo(
     Location loc, mlir::stablehlo::ChannelHandleAttr channelHandleAttr,
+    int32_t numReplicas, int32_t numPartitions,
     DenseIntElementsAttr replicaGroups, bool useGlobalDeviceIds,
     OpBuilder &builder) {
+  // Set numPartitions to 1 if not set by the user.
+  if (numPartitions == -1) numPartitions = 1;
+
   // Base channel that may be split by the group info.
   Value baseChannel =
       builder.create<IREE::Flow::ChannelDefaultOp>(loc, /*group=*/StringAttr{});
 
-  // TODO(okkwon): Convert replica_groups into flattened IDs.
-  //
-  // Once stablehlo exposes `num_replicas` and `num_partitions`,
-  // use the channel ID to determine the collective operation mode, such as
-  // cross_replica, cross_partition, cross_replic_and_partition, and
-  // flattend_ids. Currently, we only supports the flanttend_ids mode.
-  //
-  // int64_t channelId = 0;
-  // if (channelHandleAttr) {
-  //   channelId = channelHandleAttr.getHandle();
-  // }
-
   // No need to split if there is a single group.
   ShapedType replicaGroupType = replicaGroups.getType();
   assert(replicaGroupType.getRank() == 2);
-  if (replicaGroupType.getDimSize(0) == 1) {
+  if (numPartitions == 1 && replicaGroupType.getDimSize(0) == 1) {
     return baseChannel;
   }
 
+  // Convert replica_groups into flattened IDs.
+  DenseIntElementsAttr rankGroups;
+  int64_t channelId = channelHandleAttr ? channelHandleAttr.getHandle() : 0;
+  if (channelId <= 0) {
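+    // cross_replica mode: expand each replica group across all partitions.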
+    assert(!useGlobalDeviceIds);
+    rankGroups = convertToRankGroupsByCrossReplica(replicaGroups, numPartitions,
+                                                   builder);
+  } else {
+    if (useGlobalDeviceIds) {
+      // flattened_ids mode: the groups already contain flattened IDs.
+      rankGroups = replicaGroups;
+    } else {
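+      // cross_replica_and_partition mode: each group spans all partitions of
+      // its replicas.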
+      rankGroups = convertToRankGroupsByCrossReplicaAndPartition(
+          replicaGroups, numPartitions, builder);
+    }
+  }
+
   // Construct lookups for color and key split parameters.
   // Note that `replica_groups` can be interpreted in multiple ways based on the
   // other attributes.
   auto [color, key] =
-      makeSplitColorAndKey(loc, baseChannel, replicaGroups, builder);
+      makeSplitColorAndKey(loc, baseChannel, rankGroups, builder);
 
   // Split the channel. Note that this is an expensive operation.
   return builder.create<IREE::Flow::ChannelSplitOp>(loc, baseChannel, color,
                                                     key);
 }
 
+static int32_t getNumReplicas(ModuleOp moduleOp) {
+  if (!moduleOp) {
+    return -1;
+  }
+  if (auto numReplicasAttr =
+          moduleOp->getAttrOfType<IntegerAttr>("mhlo.num_replicas")) {
+    return numReplicasAttr.getInt();
+  } else {
+    return -1;
+  }
+}
+
+static int32_t getNumPartitions(ModuleOp moduleOp) {
+  if (!moduleOp) {
+    return -1;
+  }
+  if (auto numPartitionsAttr =
+          moduleOp->getAttrOfType<IntegerAttr>("mhlo.num_partitions")) {
+    return numPartitionsAttr.getInt();
+  } else {
+    return -1;
+  }
+}
+
 static Value emitTranspose(ConversionPatternRewriter &rewriter, Location loc,
                            Value input, int64_t srcDim, int64_t dstDim) {
   // Creates a transpose op that swaps dimensions srcDim and dstDim in the
@@ -260,22 +358,67 @@
       permutationAttr);
 }
 
-/// Converts stablehlo.replica_id to flow.channel.default + flow.channel.rank.
-/// TODO(okkwon): this assumes that there is no partition so that there is a 1:1
-/// mapping between the replica ID and the process ID.
-struct ReplicaIdOpConversion final
-    : OpConversionPattern<mlir::stablehlo::ReplicaIdOp> {
-  using OpConversionPattern::OpConversionPattern;
+/// Converts stablehlo.partition_id to (flow.channel.rank % numPartitions)
+struct PartitionIdOpConversion final
+    : OpConversionPattern<mlir::stablehlo::PartitionIdOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::stablehlo::PartitionIdOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    // PartitionId = rank % numPartitions
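+    // e.g., with numPartitions = 2, ranks 0..3 map to partitions 0, 1, 0, 1.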
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numPartitions = getNumPartitions(moduleOp);
+    Value value;
+    if (numPartitions <= 1) {
+      value = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    } else {
+      auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(
+          loc, /*group=*/StringAttr{});
+      Value rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+      auto cst =
+          rewriter.create<arith::ConstantIndexOp>(loc,
+                                                  /*value=*/numPartitions);
+      value = rewriter.create<arith::RemUIOp>(loc, rank, cst);
+    }
+    auto resultType = cast<RankedTensorType>(op.getType());  // tensor<ui32>
+    auto elemType = resultType.getElementType();
+    // index -> ui32
+    auto rankElem = rewriter.create<arith::IndexCastUIOp>(loc, elemType, value);
+    // tensor<ui32>
+    auto rankTensor = rewriter.create<tensor::FromElementsOp>(
+        loc, resultType, rankElem.getResult());
+    rewriter.replaceOp(op, rankTensor.getResult());
+    return success();
+  }
+};
+
+/// Converts stablehlo.replica_id to floor_div(flow.channel.rank, numPartitions)
+struct ReplicaIdOpConversion
+    : public OpConversionPattern<mlir::stablehlo::ReplicaIdOp> {
+  using OpConversionPattern<mlir::stablehlo::ReplicaIdOp>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
       mlir::stablehlo::ReplicaIdOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
+    auto loc = op.getLoc();
     auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(
         loc, /*group=*/StringAttr{});
-    auto rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
-    auto resultType = cast<RankedTensorType>(op.getType());  // tensor<ui32>
-    Type elemType = resultType.getElementType();
+    Value rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+
+    // ReplicaId = floor_div(rank, numPartitions)
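+    // e.g., with numPartitions = 2, ranks 0..3 map to replicas 0, 0, 1, 1.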
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numPartitions = getNumPartitions(moduleOp);
+    // Only materialize the divisor when it is needed; numPartitions is -1
+    // when the attribute is unset.
+    if (numPartitions > 1) {
+      auto cst = rewriter.create<arith::ConstantIndexOp>(
+          loc, /*value=*/numPartitions);
+      rank = rewriter.create<arith::DivUIOp>(loc, rank, cst);
+    }
+
+    auto resultType = cast<RankedTensorType>(op.getType());  // tensor<ui32>
+    auto elemType = resultType.getElementType();
     // index -> ui32
     auto rankElem = rewriter.create<arith::IndexCastUIOp>(loc, elemType, rank);
     // tensor<ui32>
@@ -299,10 +442,14 @@
 
     Location loc = op.getLoc();
 
-    // Get the channel used for communication.
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numReplicas = getNumReplicas(moduleOp);
+    int32_t numPartitions = getNumPartitions(moduleOp);
+
+    // Create a channel.
     Value channel = createChannelWithGroupInfo(
-        loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
-        op.getUseGlobalDeviceIds(), rewriter);
+        loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+        op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
 
     // Get the collective element type attribute.
     auto resultType = cast<RankedTensorType>(op.getResult().getType());
@@ -385,10 +532,14 @@
 
     Location loc = op.getLoc();
 
-    // Get the channel used for communication.
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numReplicas = getNumReplicas(moduleOp);
+    int32_t numPartitions = getNumPartitions(moduleOp);
+
+    // Create a channel.
     Value channel = createChannelWithGroupInfo(
-        loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
-        op.getUseGlobalDeviceIds(), rewriter);
+        loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+        op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
 
     // Convert stablehlo reduction op into flow reduction op.
     auto reductionOpAttr =
@@ -575,10 +726,14 @@
 
     Location loc = op.getLoc();
 
-    // Get the channel used for communication.
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    int32_t numReplicas = getNumReplicas(moduleOp);
+    int32_t numPartitions = getNumPartitions(moduleOp);
+
+    // Create a channel.
     Value channel = createChannelWithGroupInfo(
-        loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
-        op.getUseGlobalDeviceIds(), rewriter);
+        loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+        op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
 
     // Get the collective element type attribute.
     auto resultType = cast<RankedTensorType>(op.getResult().getType());
@@ -635,10 +790,10 @@
 void populateStableHloCollectivesConversionPatterns(
     MLIRContext *context, TypeConverter &typeConverter,
     RewritePatternSet *patterns) {
-  patterns
-      ->add<AllGatherOpConversion, AllReduceOpConversion, AllToAllOpConversion,
-            ReduceScatterOpConversion, ReplicaIdOpConversion>(typeConverter,
-                                                              context);
+  patterns->add<AllGatherOpConversion, AllReduceOpConversion,
+                AllToAllOpConversion, PartitionIdOpConversion,
+                ReduceScatterOpConversion, ReplicaIdOpConversion>(typeConverter,
+                                                                  context);
 }
 
 }  // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
index 89e707a..82870f3 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
@@ -14,6 +14,50 @@
 
 // -----
 
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+  // CHECK-LABEL: @replica_id_with_partitions
+  func.func @replica_id_with_partitions() -> tensor<ui32> {
+    // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+    // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index
+    // CHECK-DAG: %[[DIV2:.+]] = arith.divui %[[RANK]], %c2 : index
+    // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[DIV2]] : index to i32
+    // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor<i32>
+    // CHECK-DAG: return %[[TENSOR]] : tensor<i32>
+    %id = stablehlo.replica_id : tensor<ui32>
+    return %id : tensor<ui32>
+  }
+}
+
+// -----
+
+// Returns 0 since num_partitions is not set.
+
+// CHECK-LABEL: @partition_id
+func.func @partition_id() -> tensor<ui32> {
+  // CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : tensor<i32>
+  // CHECK-DAG: return %[[CST0]] : tensor<i32>
+  %id = stablehlo.partition_id : tensor<ui32>
+  return %id : tensor<ui32>
+}
+
+// -----
+
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+  // CHECK-LABEL: @partition_id_with_partitions
+  func.func @partition_id_with_partitions() -> tensor<ui32> {
+    // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+    // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index
+    // CHECK-DAG: %[[REM2:.+]] = arith.remui %[[RANK]], %c2 : index
+    // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[REM2]] : index to i32
+    // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor<i32>
+    // CHECK-DAG: return %[[TENSOR]] : tensor<i32>
+    %id = stablehlo.partition_id : tensor<ui32>
+    return %id : tensor<ui32>
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @all_reduce_sum
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>)
 func.func @all_reduce_sum(%input : tensor<2304xf32>) -> tensor<2304xf32> {
@@ -345,3 +389,81 @@
       replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x2xf32>) -> tensor<2x2xf32>
   return %out : tensor<2x2xf32>
 }
+
+// -----
+
+// flattened_ids: channel_id > 0 && use_global_device_ids = true
+module @jit_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 8 : i32 } {
+  // CHECK-LABEL: @flattened_ids
+  // CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>)
+  func.func @flattened_ids(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+    // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+    // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32>
+    // CHECK: [[ALLREDUCE:%.+]] = flow.collective.all_reduce sum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32>
+    // CHECK: return [[ALLREDUCE]] : tensor<2304xf32>
+    %out = "stablehlo.all_reduce"(%input) ({
+      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
+        stablehlo.return %sum : tensor<f32>
+      }) {channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>,
+          replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+          use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>
+    return %out : tensor<2304xf32>
+  }
+}
+
+// -----
+
+// cross_replica: channel_id <= 0 && use_global_device_ids = false
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+  // CHECK-LABEL: @cross_replica
+  func.func @cross_replica(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+    // Cross replica should form groups (0,2,4,6),(1,3,5,7), where each number represents a cell below.
+    // +---+---+
+    // | 0 | 1 |
+    // | 2 | 3 |
+    // | 4 | 5 |
+    // | 6 | 7 |
+    // +---+---+
+    //                          rank:   0    1    2    3    4    5    6    7
+    // CHECK: util.switch index from [%c0, %c1, %c0, %c1, %c0, %c1, %c0, %c1] at %channel_rank else %c-1 : index
+    // CHECK: util.switch index from [%c0, %c0, %c1, %c1, %c2, %c2, %c3, %c3] at %channel_rank else %c-1 : index
+    %out = "stablehlo.all_reduce"(%input) ({
+      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
+        stablehlo.return %sum : tensor<f32>
+      }) {channel_handle = #stablehlo.channel_handle<handle = 0, type = 1>,
+          replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
+         } : (tensor<2304xf32>) -> tensor<2304xf32>
+    return %out : tensor<2304xf32>
+  }
+}
+
+// -----
+
+// cross_replica_and_partition: channel_id > 0 && use_global_device_ids = false
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+  // CHECK-LABEL: @cross_replica_and_partition
+  func.func @cross_replica_and_partition(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+    // cross_replica_and_partition should form groups (0,2,1,3),(4,6,5,7), where each number represents a cell below.
+    // Note that ranks are assigned within a partition first, e.g., ranks 0 and 1 are assigned to cells 0 and 2, respectively.
+    // +---+---+
+    // | 0   1 |
+    // | 2   3 |
+    // |---+---|
+    // | 4   5 |
+    // | 6   7 |
+    // +---+---+
+    //                          rank:   0    1    2    3    4    5    6    7
+    // CHECK: util.switch index from [%c0, %c0, %c0, %c0, %c1, %c1, %c1, %c1] at %channel_rank else %c-1 : index
+    // CHECK: util.switch index from [%c0, %c2, %c1, %c3, %c0, %c2, %c1, %c3] at %channel_rank else %c-1 : index
+    %out = "stablehlo.all_reduce"(%input) ({
+      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
+        stablehlo.return %sum : tensor<f32>
+      }) {channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>,
+          replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>
+         } : (tensor<2304xf32>) -> tensor<2304xf32>
+    return %out : tensor<2304xf32>
+  }
+}