Support `num_replicas` and `num_partitions` (#13288)
Support `num_replicas` and `num_partitions` for all_reduce, all_gather,
reduce_scatter, replica_id, and partition_id, adding the cross_replica
and cross_replica_and_partition modes alongside the existing
flattened_ids mode.
The channel ID and use_global_device_ids together determine the
communication mode over the 2D (replica, partition) device grid.
See
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective-ops
for more information.
Note that stablehlo still uses the `mhlo.num_replicas` and
`mhlo.num_partitions` attributes to embed this information in the
module.
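As an example, with `mhlo.num_replicas = 4` and `mhlo.num_partitions = 2`
(the configuration used in the added tests), the flattened rank is
replica_id * num_partitions + partition_id, and replica_groups expand to
rank groups as follows:
  cross_replica (channel_id <= 0, no use_global_device_ids):
    [[0, 1, 2, 3]] -> (0, 2, 4, 6), (1, 3, 5, 7)   one group per partition
  cross_replica_and_partition (channel_id > 0, no use_global_device_ids):
    [[0, 1], [2, 3]] -> (0, 2, 1, 3), (4, 6, 5, 7)  each group spans all partitions
  flattened_ids (channel_id > 0, use_global_device_ids):
    replica_groups already hold flattened ranks and are used as-is.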
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
index 7e71f68..a8604cd 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
@@ -145,10 +145,6 @@
return rewriter.notifyMatchFailure(
op, "must not set use_global_device_ids when channel_id <= 0");
}
- } else {
- if (!op.getUseGlobalDeviceIds()) {
- return rewriter.notifyMatchFailure(op, "must set use_global_device_ids");
- }
}
return success();
@@ -212,40 +208,116 @@
return std::make_pair(color, key);
}
+static DenseIntElementsAttr convertToRankGroupsByCrossReplica(
+ DenseIntElementsAttr replicaGroups, int32_t numPartitions,
+ OpBuilder &builder) {
+ if (numPartitions < 1) {
+ // Treat as a single partition.
+ return replicaGroups;
+ }
+
+ auto groupsType = replicaGroups.getType().cast<RankedTensorType>();
+ assert(groupsType.getRank() == 2);
+ int rows = groupsType.getShape()[0];
+ int cols = groupsType.getShape()[1];
+ auto values = replicaGroups.getValues<int64_t>();
+ SmallVector<Attribute> newValues;
+
+ // The number of groups is (rows * numPartitions).
+ for (int i = 0; i < rows; ++i) {
+ for (int p = 0; p < numPartitions; ++p) {
+ // Each group starts here. The group size is the same as the column size.
+ for (int j = 0; j < cols; ++j) {
+ const int index = i * cols + j;
+ const int64_t replicaId = values[index];
+ const int64_t value =
+ (replicaId == -1) ? -1 : replicaId * numPartitions + p;
+ newValues.push_back(builder.getI64IntegerAttr(value));
+ }
+ }
+ }
+
+ auto type =
+ RankedTensorType::get({rows * numPartitions, cols}, builder.getI64Type());
+ return DenseIntElementsAttr::get(type, newValues);
+}
+
+static DenseIntElementsAttr convertToRankGroupsByCrossReplicaAndPartition(
+ DenseIntElementsAttr replicaGroups, int32_t numPartitions,
+ OpBuilder &builder) {
+ if (numPartitions < 1) {
+ // Treat as a single partition.
+ return replicaGroups;
+ }
+
+ auto groupsType = replicaGroups.getType().cast<RankedTensorType>();
+ assert(groupsType.getRank() == 2);
+ int rows = groupsType.getShape()[0];
+ int cols = groupsType.getShape()[1];
+ auto values = replicaGroups.getValues<int64_t>();
+ SmallVector<Attribute> newValues;
+
+ // The number of groups is the same as the number of rows.
+ for (int i = 0; i < rows; ++i) {
+ // Each group starts here. The group size is (numPartitions * cols).
+ for (int p = 0; p < numPartitions; ++p) {
+ for (int j = 0; j < cols; ++j) {
+ const int index = i * cols + j;
+ const int64_t replicaId = values[index];
+ const int64_t value =
+ (replicaId == -1) ? -1 : replicaId * numPartitions + p;
+ newValues.push_back(builder.getI64IntegerAttr(value));
+ }
+ }
+ }
+ auto type =
+ RankedTensorType::get({rows, numPartitions * cols}, builder.getI64Type());
+ return DenseIntElementsAttr::get(type, newValues);
+}
+
/// Creates a channel matching the given |channelHandleAttr| scoped to the
/// requested group.
static Value createChannelWithGroupInfo(
Location loc, mhlo::ChannelHandleAttr channelHandleAttr,
+ int32_t numReplicas, int32_t numPartitions,
DenseIntElementsAttr replicaGroups, bool useGlobalDeviceIds,
OpBuilder &builder) {
+ // Set numPartitions to 1 if not set by the user.
+ if (numPartitions == -1) numPartitions = 1;
+
// Base channel that may be split by the group info.
Value baseChannel =
builder.create<IREE::Flow::ChannelDefaultOp>(loc, /*group=*/StringAttr{});
- // TODO(okkwon): Convert replica_groups into flattened IDs.
- //
- // Once mhlo exposes `num_replicas` and `num_partitions`,
- // use the channel ID to determine the collective operation mode, such as
- // cross_replica, cross_partition, cross_replic_and_partition, and
- // flattend_ids. Currently, we only supports the flanttend_ids mode.
- //
- // int64_t channelId = 0;
- // if (channelHandleAttr) {
- // channelId = channelHandleAttr.getHandle();
- // }
-
// No need to split if there is a single group.
ShapedType replicaGroupType = replicaGroups.getType();
assert(replicaGroupType.getRank() == 2);
- if (replicaGroupType.getDimSize(0) == 1) {
+ if (numPartitions == 1 && replicaGroupType.getDimSize(0) == 1) {
return baseChannel;
}
+ // Convert replica_groups into flattened IDs.
+ DenseIntElementsAttr rankGroups;
+ int64_t channelId = channelHandleAttr ? channelHandleAttr.getHandle() : 0;
+ if (channelId <= 0) {
+ assert(!useGlobalDeviceIds);
+ rankGroups = convertToRankGroupsByCrossReplica(replicaGroups, numPartitions,
+ builder);
+ } else {
+ if (useGlobalDeviceIds) {
+ // already flattened.
+ rankGroups = replicaGroups;
+ } else {
+ rankGroups = convertToRankGroupsByCrossReplicaAndPartition(
+ replicaGroups, numPartitions, builder);
+ }
+ }
+
// Construct lookups for color and key split parameters.
// Note that `replica_groups` can be interpreted in multiple ways based on the
// other attributes.
auto [color, key] =
- makeSplitColorAndKey(loc, baseChannel, replicaGroups, builder);
+ makeSplitColorAndKey(loc, baseChannel, rankGroups, builder);
// Split the channel. Note that this is an expensive operation.
return builder.create<IREE::Flow::ChannelSplitOp>(loc, baseChannel, color,
@@ -268,11 +340,69 @@
permutationAttr);
}
+static int32_t getNumReplicas(ModuleOp moduleOp) {
+ if (!moduleOp) {
+ return -1;
+ }
+ if (auto numReplicasAttr =
+ moduleOp->getAttrOfType<IntegerAttr>("mhlo.num_replicas")) {
+ return numReplicasAttr.getInt();
+ } else {
+ return -1;
+ }
+}
+
+static int32_t getNumPartitions(ModuleOp moduleOp) {
+ if (!moduleOp) {
+ return -1;
+ }
+ if (auto numPartitionsAttr =
+ moduleOp->getAttrOfType<IntegerAttr>("mhlo.num_partitions")) {
+ return numPartitionsAttr.getInt();
+ } else {
+ return -1;
+ }
+}
+
} // namespace
-/// Converts mhlo.replica_id to flow.channel.default + flow.channel.rank.
-/// TODO(okkwon): this assumes that there is no partition so that there is a 1:1
-/// mapping between the replica ID and the process ID.
+/// Converts mhlo.partition_id to (flow.channel.rank % numPartitions)
+struct PartitionIdOpConversion
+ : public OpConversionPattern<mhlo::PartitionIdOp> {
+ using OpConversionPattern<mhlo::PartitionIdOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::PartitionIdOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ // PartitionId = rank % numPartitions
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numPartitions = getNumPartitions(moduleOp);
+ Value value;
+ if (numPartitions <= 1) {
+ value = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ } else {
+ auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(
+ loc, /*group=*/StringAttr{});
+ Value rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+ auto cst =
+ rewriter.create<arith::ConstantIndexOp>(loc,
+ /*value=*/numPartitions);
+ value = rewriter.create<arith::RemUIOp>(loc, rank, cst);
+ }
+ auto resultType = op.getType().cast<RankedTensorType>(); // tensor<ui32>
+ auto elemType = resultType.getElementType();
+ // index -> ui32
+ auto rankElem = rewriter.create<arith::IndexCastUIOp>(loc, elemType, value);
+ // tensor<ui32>
+ auto rankTensor = rewriter.create<tensor::FromElementsOp>(
+ loc, resultType, rankElem.getResult());
+ rewriter.replaceOp(op, rankTensor.getResult());
+ return success();
+ }
+};
+
+/// Converts mhlo.replica_id to floor_div(flow.channel.rank, numPartitions)
struct ReplicaIdOpConversion : public OpConversionPattern<mhlo::ReplicaIdOp> {
using OpConversionPattern<mhlo::ReplicaIdOp>::OpConversionPattern;
@@ -282,7 +412,17 @@
auto loc = op.getLoc();
auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(
loc, /*group=*/StringAttr{});
- auto rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+ Value rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+
+ // ReplicaId = floor_div(rank, numPartitions)
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numPartitions = getNumPartitions(moduleOp);
+ auto cst = rewriter.create<arith::ConstantIndexOp>(loc,
+ /*value=*/numPartitions);
+ if (numPartitions > 1) {
+ rank = rewriter.create<arith::DivUIOp>(loc, rank, cst);
+ }
+
auto resultType = op.getType().cast<RankedTensorType>(); // tensor<ui32>
auto elemType = resultType.getElementType();
// index -> ui32
@@ -307,10 +447,14 @@
auto loc = op.getLoc();
- // Get the channel used for communication.
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numReplicas = getNumReplicas(moduleOp);
+ int32_t numPartitions = getNumPartitions(moduleOp);
+
+ // Create a channel.
Value channel = createChannelWithGroupInfo(
- loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
- op.getUseGlobalDeviceIds(), rewriter);
+ loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+ op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
// Get the collective element type attribute.
auto resultType = op.getResult().getType().cast<RankedTensorType>();
@@ -392,10 +536,14 @@
auto loc = op.getLoc();
- // Get the channel used for communication.
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numReplicas = getNumReplicas(moduleOp);
+ int32_t numPartitions = getNumPartitions(moduleOp);
+
+ // Create a channel.
Value channel = createChannelWithGroupInfo(
- loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
- op.getUseGlobalDeviceIds(), rewriter);
+ loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+ op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
// Convert mhlo reduction op into flow reduction op.
auto reductionOpAttr =
@@ -591,10 +739,14 @@
auto loc = op.getLoc();
- // Get the channel used for communication.
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numReplicas = getNumReplicas(moduleOp);
+ int32_t numPartitions = getNumPartitions(moduleOp);
+
+ // Create a channel.
Value channel = createChannelWithGroupInfo(
- loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
- op.getUseGlobalDeviceIds(), rewriter);
+ loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+ op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
// Get the collective element type attribute.
auto resultType = op.getResult().getType().cast<RankedTensorType>();
@@ -652,6 +804,7 @@
patterns.insert<AllGatherOpConversion>(typeConverter, context);
patterns.insert<AllReduceOpConversion>(typeConverter, context);
patterns.insert<AllToAllOpConversion>(typeConverter, context);
+ patterns.insert<PartitionIdOpConversion>(typeConverter, context);
patterns.insert<ReduceScatterOpConversion>(typeConverter, context);
patterns.insert<ReplicaIdOpConversion>(typeConverter, context);
}
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
index 2f2e482..e8ab948 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
@@ -14,6 +14,53 @@
// -----
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+ // CHECK-LABEL: @replica_id_with_partitions
+ func.func @replica_id_with_partitions() -> tensor<ui32> {
+ // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+ // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index
+ // CHECK-DAG: %[[DIV2:.+]] = arith.divui %[[RANK]], %c2 : index
+ // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[DIV2]] : index to i32
+ // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor<i32>
+ // CHECK-DAG: %[[BITCAST:.+]] = tensor.bitcast %[[TENSOR]] : tensor<i32> to tensor<ui32>
+ // CHECK-DAG: return %[[BITCAST]] : tensor<ui32>
+ %id = mhlo.replica_id : tensor<ui32>
+ return %id : tensor<ui32>
+ }
+}
+
+// -----
+
+// Returns 0 since num_partitions is not set.
+
+// CHECK-LABEL: @partition_id
+func.func @partition_id() -> tensor<ui32> {
+ // CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : tensor<i32>
+ // CHECK-DAG: %[[BITCAST:.+]] = tensor.bitcast %[[CST0]] : tensor<i32> to tensor<ui32>
+ // CHECK-DAG: return %[[BITCAST]] : tensor<ui32>
+ %id = mhlo.partition_id : tensor<ui32>
+ return %id : tensor<ui32>
+}
+
+// -----
+
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+ // CHECK-LABEL: @partition_id_with_partitions
+ func.func @partition_id_with_partitions() -> tensor<ui32> {
+ // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+ // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index
+ // CHECK-DAG: %[[REM2:.+]] = arith.remui %[[RANK]], %c2 : index
+ // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[REM2]] : index to i32
+ // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor<i32>
+ // CHECK-DAG: %[[BITCAST:.+]] = tensor.bitcast %[[TENSOR]] : tensor<i32> to tensor<ui32>
+ // CHECK-DAG: return %[[BITCAST]] : tensor<ui32>
+ %id = mhlo.partition_id : tensor<ui32>
+ return %id : tensor<ui32>
+ }
+}
+
+// -----
+
// CHECK-LABEL: @all_reduce_sum
// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>)
func.func @all_reduce_sum(%input : tensor<2304xf32>) -> tensor<2304xf32> {
@@ -345,3 +392,81 @@
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x2xf32>) -> tensor<2x2xf32>
return %out : tensor<2x2xf32>
}
+
+// -----
+
+// flattened_ids: channel_id > 0 && use_global_device_ids = true
+module @jit_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 8 : i32 } {
+ // CHECK-LABEL: @flattened_ids
+ // CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>)
+ func.func @flattened_ids(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+ // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+ // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32>
+ // CHECK: [[ALLREDUCE:%.+]] = flow.collective.all_reduce sum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32>
+ // CHECK: return [[ALLREDUCE]] : tensor<2304xf32>
+ %out = "mhlo.all_reduce"(%input) ({
+ ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+ %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+ mhlo.return %sum : tensor<f32>
+ }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+ replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+ use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>
+ return %out : tensor<2304xf32>
+ }
+}
+
+// -----
+
+// cross-replica: channel_id <= 0 && use_global_device_ids = false
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+ // CHECK-LABEL: @cross_replica
+ func.func @cross_replica(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+ // Cross replica should form groups (0,2,4,6),(1,3,5,7), where each number represents a cell below.
+ // +---+---+
+ // | 0 | 1 |
+ // | 2 | 3 |
+ // | 4 | 5 |
+ // | 6 | 7 |
+ // +---+---+
+ // rank: 0 1 2 3 4 5 6 7
+ // CHECK: util.switch index from [%c0, %c1, %c0, %c1, %c0, %c1, %c0, %c1] at %channel_rank else %c-1 : index
+ // CHECK: util.switch index from [%c0, %c0, %c1, %c1, %c2, %c2, %c3, %c3] at %channel_rank else %c-1 : index
+ %out = "mhlo.all_reduce"(%input) ({
+ ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+ %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+ mhlo.return %sum : tensor<f32>
+ }) {channel_handle = #mhlo.channel_handle<handle = 0, type = 1>,
+ replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
+ } : (tensor<2304xf32>) -> tensor<2304xf32>
+ return %out : tensor<2304xf32>
+ }
+}
+
+// -----
+
+// cross_replica_and_partition: channel_id > 0 && use_global_device_ids = false
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+ // CHECK-LABEL: @cross_replica_and_partition
+ func.func @cross_replica_and_partition(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+ // Cross replica_and_partition should form groups (0,2,1,3),(4,6,5,7), where each number represents a cell below.
+    // Note that ranks are assigned within a partition first, e.g., ranks 0 and 1 are assigned to cells 0 and 2, respectively.
+ // +---+---+
+ // | 0 1 |
+ // | 2 3 |
+ // |---+---|
+ // | 4 5 |
+ // | 6 7 |
+ // +---+---+
+ // rank: 0 1 2 3 4 5 6 7
+ // CHECK: util.switch index from [%c0, %c0, %c0, %c0, %c1, %c1, %c1, %c1] at %channel_rank else %c-1 : index
+ // CHECK: util.switch index from [%c0, %c2, %c1, %c3, %c0, %c2, %c1, %c3] at %channel_rank else %c-1 : index
+ %out = "mhlo.all_reduce"(%input) ({
+ ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+ %sum = mhlo.add %arg0, %arg1 : tensor<f32>
+ mhlo.return %sum : tensor<f32>
+ }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+ replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>
+ } : (tensor<2304xf32>) -> tensor<2304xf32>
+ return %out : tensor<2304xf32>
+ }
+}
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
index f277519..a1474ae 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
@@ -139,8 +139,6 @@
return rewriter.notifyMatchFailure(
op, "must not set use_global_device_ids when channel_id <= 0");
}
- } else if (!op.getUseGlobalDeviceIds()) {
- return rewriter.notifyMatchFailure(op, "must set use_global_device_ids");
}
return success();
@@ -204,46 +202,146 @@
return std::make_pair(color, key);
}
+static DenseIntElementsAttr convertToRankGroupsByCrossReplica(
+ DenseIntElementsAttr replicaGroups, int32_t numPartitions,
+ OpBuilder &builder) {
+ if (numPartitions < 1) {
+ // Treat as a single partition.
+ return replicaGroups;
+ }
+
+ auto groupsType = replicaGroups.getType().cast<RankedTensorType>();
+ assert(groupsType.getRank() == 2);
+ int rows = groupsType.getShape()[0];
+ int cols = groupsType.getShape()[1];
+ auto values = replicaGroups.getValues<int64_t>();
+ SmallVector<Attribute> newValues;
+
+ // The number of groups is (rows * numPartitions).
+ for (int i = 0; i < rows; ++i) {
+ for (int p = 0; p < numPartitions; ++p) {
+ // Each group starts here. The group size is the same as the column size.
+ for (int j = 0; j < cols; ++j) {
+ const int index = i * cols + j;
+ const int64_t replicaId = values[index];
+ const int64_t value =
+ (replicaId == -1) ? -1 : replicaId * numPartitions + p;
+ newValues.push_back(builder.getI64IntegerAttr(value));
+ }
+ }
+ }
+
+ auto type =
+ RankedTensorType::get({rows * numPartitions, cols}, builder.getI64Type());
+ return DenseIntElementsAttr::get(type, newValues);
+}
+
+static DenseIntElementsAttr convertToRankGroupsByCrossReplicaAndPartition(
+ DenseIntElementsAttr replicaGroups, int32_t numPartitions,
+ OpBuilder &builder) {
+ if (numPartitions < 1) {
+ // Treat as a single partition.
+ return replicaGroups;
+ }
+
+ auto groupsType = replicaGroups.getType().cast<RankedTensorType>();
+ assert(groupsType.getRank() == 2);
+ int rows = groupsType.getShape()[0];
+ int cols = groupsType.getShape()[1];
+ auto values = replicaGroups.getValues<int64_t>();
+ SmallVector<Attribute> newValues;
+
+ // The number of groups is the same as the number of rows.
+ for (int i = 0; i < rows; ++i) {
+ // Each group starts here. The group size is (numPartitions * cols).
+ for (int p = 0; p < numPartitions; ++p) {
+ for (int j = 0; j < cols; ++j) {
+ const int index = i * cols + j;
+ const int64_t replicaId = values[index];
+ const int64_t value =
+ (replicaId == -1) ? -1 : replicaId * numPartitions + p;
+ newValues.push_back(builder.getI64IntegerAttr(value));
+ }
+ }
+ }
+ auto type =
+ RankedTensorType::get({rows, numPartitions * cols}, builder.getI64Type());
+ return DenseIntElementsAttr::get(type, newValues);
+}
+
/// Creates a channel matching the given |channelHandleAttr| scoped to the
/// requested group.
static Value createChannelWithGroupInfo(
Location loc, mlir::stablehlo::ChannelHandleAttr channelHandleAttr,
+ int32_t numReplicas, int32_t numPartitions,
DenseIntElementsAttr replicaGroups, bool useGlobalDeviceIds,
OpBuilder &builder) {
+ // Set numPartitions to 1 if not set by the user.
+ if (numPartitions == -1) numPartitions = 1;
+
// Base channel that may be split by the group info.
Value baseChannel =
builder.create<IREE::Flow::ChannelDefaultOp>(loc, /*group=*/StringAttr{});
- // TODO(okkwon): Convert replica_groups into flattened IDs.
- //
- // Once stablehlo exposes `num_replicas` and `num_partitions`,
- // use the channel ID to determine the collective operation mode, such as
- // cross_replica, cross_partition, cross_replic_and_partition, and
- // flattend_ids. Currently, we only supports the flanttend_ids mode.
- //
- // int64_t channelId = 0;
- // if (channelHandleAttr) {
- // channelId = channelHandleAttr.getHandle();
- // }
-
// No need to split if there is a single group.
ShapedType replicaGroupType = replicaGroups.getType();
assert(replicaGroupType.getRank() == 2);
- if (replicaGroupType.getDimSize(0) == 1) {
+ if (numPartitions == 1 && replicaGroupType.getDimSize(0) == 1) {
return baseChannel;
}
+ // Convert replica_groups into flattened IDs.
+ DenseIntElementsAttr rankGroups;
+ int64_t channelId = channelHandleAttr ? channelHandleAttr.getHandle() : 0;
+ if (channelId <= 0) {
+ assert(!useGlobalDeviceIds);
+ rankGroups = convertToRankGroupsByCrossReplica(replicaGroups, numPartitions,
+ builder);
+ } else {
+ if (useGlobalDeviceIds) {
+ // already flattened.
+ rankGroups = replicaGroups;
+ } else {
+ rankGroups = convertToRankGroupsByCrossReplicaAndPartition(
+ replicaGroups, numPartitions, builder);
+ }
+ }
+
// Construct lookups for color and key split parameters.
// Note that `replica_groups` can be interpreted in multiple ways based on the
// other attributes.
auto [color, key] =
- makeSplitColorAndKey(loc, baseChannel, replicaGroups, builder);
+ makeSplitColorAndKey(loc, baseChannel, rankGroups, builder);
// Split the channel. Note that this is an expensive operation.
return builder.create<IREE::Flow::ChannelSplitOp>(loc, baseChannel, color,
key);
}
+static int32_t getNumReplicas(ModuleOp moduleOp) {
+ if (!moduleOp) {
+ return -1;
+ }
+ if (auto numReplicasAttr =
+ moduleOp->getAttrOfType<IntegerAttr>("mhlo.num_replicas")) {
+ return numReplicasAttr.getInt();
+ } else {
+ return -1;
+ }
+}
+
+static int32_t getNumPartitions(ModuleOp moduleOp) {
+ if (!moduleOp) {
+ return -1;
+ }
+ if (auto numPartitionsAttr =
+ moduleOp->getAttrOfType<IntegerAttr>("mhlo.num_partitions")) {
+ return numPartitionsAttr.getInt();
+ } else {
+ return -1;
+ }
+}
+
static Value emitTranspose(ConversionPatternRewriter &rewriter, Location loc,
Value input, int64_t srcDim, int64_t dstDim) {
// Creates a transpose op that swaps dimensions srcDim and dstDim in the
@@ -260,22 +358,67 @@
permutationAttr);
}
-/// Converts stablehlo.replica_id to flow.channel.default + flow.channel.rank.
-/// TODO(okkwon): this assumes that there is no partition so that there is a 1:1
-/// mapping between the replica ID and the process ID.
-struct ReplicaIdOpConversion final
- : OpConversionPattern<mlir::stablehlo::ReplicaIdOp> {
- using OpConversionPattern::OpConversionPattern;
+/// Converts stablehlo.partition_id to (flow.channel.rank % numPartitions)
+struct PartitionIdOpConversion
+ : public OpConversionPattern<mlir::stablehlo::PartitionIdOp> {
+ using OpConversionPattern<
+ mlir::stablehlo::PartitionIdOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mlir::stablehlo::PartitionIdOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ // PartitionId = rank % numPartitions
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numPartitions = getNumPartitions(moduleOp);
+ Value value;
+ if (numPartitions <= 1) {
+ value = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ } else {
+ auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(
+ loc, /*group=*/StringAttr{});
+ Value rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+ auto cst =
+ rewriter.create<arith::ConstantIndexOp>(loc,
+ /*value=*/numPartitions);
+ value = rewriter.create<arith::RemUIOp>(loc, rank, cst);
+ }
+ auto resultType = op.getType().cast<RankedTensorType>(); // tensor<ui32>
+ auto elemType = resultType.getElementType();
+ // index -> ui32
+ auto rankElem = rewriter.create<arith::IndexCastUIOp>(loc, elemType, value);
+ // tensor<ui32>
+ auto rankTensor = rewriter.create<tensor::FromElementsOp>(
+ loc, resultType, rankElem.getResult());
+ rewriter.replaceOp(op, rankTensor.getResult());
+ return success();
+ }
+};
+
+/// Converts stablehlo.replica_id to floor_div(flow.channel.rank, numPartitions)
+struct ReplicaIdOpConversion
+ : public OpConversionPattern<mlir::stablehlo::ReplicaIdOp> {
+ using OpConversionPattern<mlir::stablehlo::ReplicaIdOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mlir::stablehlo::ReplicaIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
+ auto loc = op.getLoc();
auto channel = rewriter.create<IREE::Flow::ChannelDefaultOp>(
loc, /*group=*/StringAttr{});
- auto rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
- auto resultType = cast<RankedTensorType>(op.getType()); // tensor<ui32>
- Type elemType = resultType.getElementType();
+ Value rank = rewriter.create<IREE::Flow::ChannelRankOp>(loc, channel);
+
+ // ReplicaId = floor_div(rank, numPartitions)
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numPartitions = getNumPartitions(moduleOp);
+ auto cst = rewriter.create<arith::ConstantIndexOp>(loc,
+ /*value=*/numPartitions);
+ if (numPartitions > 1) {
+ rank = rewriter.create<arith::DivUIOp>(loc, rank, cst);
+ }
+
+ auto resultType = op.getType().cast<RankedTensorType>(); // tensor<ui32>
+ auto elemType = resultType.getElementType();
// index -> ui32
auto rankElem = rewriter.create<arith::IndexCastUIOp>(loc, elemType, rank);
// tensor<ui32>
@@ -299,10 +442,14 @@
Location loc = op.getLoc();
- // Get the channel used for communication.
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numReplicas = getNumReplicas(moduleOp);
+ int32_t numPartitions = getNumPartitions(moduleOp);
+
+ // Create a channel.
Value channel = createChannelWithGroupInfo(
- loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
- op.getUseGlobalDeviceIds(), rewriter);
+ loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+ op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
// Get the collective element type attribute.
auto resultType = cast<RankedTensorType>(op.getResult().getType());
@@ -385,10 +532,14 @@
Location loc = op.getLoc();
- // Get the channel used for communication.
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numReplicas = getNumReplicas(moduleOp);
+ int32_t numPartitions = getNumPartitions(moduleOp);
+
+ // Create a channel.
Value channel = createChannelWithGroupInfo(
- loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
- op.getUseGlobalDeviceIds(), rewriter);
+ loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+ op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
// Convert stablehlo reduction op into flow reduction op.
auto reductionOpAttr =
@@ -575,10 +726,14 @@
Location loc = op.getLoc();
- // Get the channel used for communication.
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ int32_t numReplicas = getNumReplicas(moduleOp);
+ int32_t numPartitions = getNumPartitions(moduleOp);
+
+ // Create a channel.
Value channel = createChannelWithGroupInfo(
- loc, op.getChannelHandleAttr(), op.getReplicaGroups(),
- op.getUseGlobalDeviceIds(), rewriter);
+ loc, op.getChannelHandleAttr(), numReplicas, numPartitions,
+ op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
// Get the collective element type attribute.
auto resultType = cast<RankedTensorType>(op.getResult().getType());
@@ -635,10 +790,10 @@
void populateStableHloCollectivesConversionPatterns(
MLIRContext *context, TypeConverter &typeConverter,
RewritePatternSet *patterns) {
- patterns
- ->add<AllGatherOpConversion, AllReduceOpConversion, AllToAllOpConversion,
- ReduceScatterOpConversion, ReplicaIdOpConversion>(typeConverter,
- context);
+ patterns->add<AllGatherOpConversion, AllReduceOpConversion,
+ AllToAllOpConversion, PartitionIdOpConversion,
+ ReduceScatterOpConversion, ReplicaIdOpConversion>(typeConverter,
+ context);
}
} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
index 89e707a..82870f3 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
@@ -14,6 +14,50 @@
// -----
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+ // CHECK-LABEL: @replica_id_with_partitions
+ func.func @replica_id_with_partitions() -> tensor<ui32> {
+ // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+ // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index
+ // CHECK-DAG: %[[DIV2:.+]] = arith.divui %[[RANK]], %c2 : index
+ // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[DIV2]] : index to i32
+ // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor<i32>
+ // CHECK-DAG: return %[[TENSOR]] : tensor<i32>
+ %id = stablehlo.replica_id : tensor<ui32>
+ return %id : tensor<ui32>
+ }
+}
+
+// -----
+
+// Returns 0 since num_partitions is not set.
+
+// CHECK-LABEL: @partition_id
+func.func @partition_id() -> tensor<ui32> {
+ // CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : tensor<i32>
+ // CHECK-DAG: return %[[CST0]] : tensor<i32>
+ %id = stablehlo.partition_id : tensor<ui32>
+ return %id : tensor<ui32>
+}
+
+// -----
+
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+ // CHECK-LABEL: @partition_id_with_partitions
+ func.func @partition_id_with_partitions() -> tensor<ui32> {
+ // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+ // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index
+ // CHECK-DAG: %[[REM2:.+]] = arith.remui %[[RANK]], %c2 : index
+ // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[REM2]] : index to i32
+ // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor<i32>
+ // CHECK-DAG: return %[[TENSOR]] : tensor<i32>
+ %id = stablehlo.partition_id : tensor<ui32>
+ return %id : tensor<ui32>
+ }
+}
+
+// -----
+
// CHECK-LABEL: @all_reduce_sum
// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>)
func.func @all_reduce_sum(%input : tensor<2304xf32>) -> tensor<2304xf32> {
@@ -345,3 +389,81 @@
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x2xf32>) -> tensor<2x2xf32>
return %out : tensor<2x2xf32>
}
+
+// -----
+
+// flattened_ids: channel_id > 0 && use_global_device_ids = true
+module @jit_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 8 : i32 } {
+ // CHECK-LABEL: @flattened_ids
+ // CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>)
+ func.func @flattened_ids(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+ // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel
+ // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32>
+ // CHECK: [[ALLREDUCE:%.+]] = flow.collective.all_reduce sum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32>
+ // CHECK: return [[ALLREDUCE]] : tensor<2304xf32>
+ %out = "stablehlo.all_reduce"(%input) ({
+ ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+ %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
+ stablehlo.return %sum : tensor<f32>
+ }) {channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>,
+ replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+ use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32>
+ return %out : tensor<2304xf32>
+ }
+}
+
+// -----
+
+// cross-replica: channel_id <= 0 && use_global_device_ids = false
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+ // CHECK-LABEL: @cross_replica
+ func.func @cross_replica(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+ // Cross replica should form groups (0,2,4,6),(1,3,5,7), where each number represents a cell below.
+ // +---+---+
+ // | 0 | 1 |
+ // | 2 | 3 |
+ // | 4 | 5 |
+ // | 6 | 7 |
+ // +---+---+
+ // rank: 0 1 2 3 4 5 6 7
+ // CHECK: util.switch index from [%c0, %c1, %c0, %c1, %c0, %c1, %c0, %c1] at %channel_rank else %c-1 : index
+ // CHECK: util.switch index from [%c0, %c0, %c1, %c1, %c2, %c2, %c3, %c3] at %channel_rank else %c-1 : index
+ %out = "stablehlo.all_reduce"(%input) ({
+ ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+ %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
+ stablehlo.return %sum : tensor<f32>
+ }) {channel_handle = #stablehlo.channel_handle<handle = 0, type = 1>,
+ replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
+ } : (tensor<2304xf32>) -> tensor<2304xf32>
+ return %out : tensor<2304xf32>
+ }
+}
+
+// -----
+
+// cross_replica_and_partition: channel_id > 0 && use_global_device_ids = false
+module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } {
+ // CHECK-LABEL: @cross_replica_and_partition
+ func.func @cross_replica_and_partition(%input : tensor<2304xf32>) -> tensor<2304xf32> {
+ // Cross replica_and_partition should form groups (0,2,1,3),(4,6,5,7), where each number represents a cell below.
+    // Note that ranks are assigned within a partition first, e.g., ranks 0 and 1 are assigned to cells 0 and 2, respectively.
+ // +---+---+
+ // | 0 1 |
+ // | 2 3 |
+ // |---+---|
+ // | 4 5 |
+ // | 6 7 |
+ // +---+---+
+ // rank: 0 1 2 3 4 5 6 7
+ // CHECK: util.switch index from [%c0, %c0, %c0, %c0, %c1, %c1, %c1, %c1] at %channel_rank else %c-1 : index
+ // CHECK: util.switch index from [%c0, %c2, %c1, %c3, %c0, %c2, %c1, %c3] at %channel_rank else %c-1 : index
+ %out = "stablehlo.all_reduce"(%input) ({
+ ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+ %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
+ stablehlo.return %sum : tensor<f32>
+ }) {channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>,
+ replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>
+ } : (tensor<2304xf32>) -> tensor<2304xf32>
+ return %out : tensor<2304xf32>
+ }
+}