Fix mhlo.all_reduce and stablehlo.all_reduce for uint (#13899)
The all-reduce lowerings were written assuming no type conversion, so they
built the result tensor and the collective op from the original operands,
whose unsigned element types have already been rewritten to signless ones by
the type converter. Use the adaptor values (and their converted element
types) instead of the original operands.
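
For context, a minimal sketch of the pattern at issue, not the exact IREE
code (the class name, includes, and surrounding setup are illustrative): in
an OpConversionPattern, `op.getOperand()` still has the pre-conversion type
(e.g. tensor<...xui32>), while `adaptor.getOperand()` carries the value the
type converter already rewrote to signless i32, which is what the lowering
must consume.

    // Sketch only; mhlo/IREE dialect headers assumed to be available.
    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "mlir/IR/TypeUtilities.h"
    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    struct AllReduceLoweringSketch final
        : OpConversionPattern<mhlo::AllReduceOp> {
      using OpConversionPattern::OpConversionPattern;

      LogicalResult matchAndRewrite(
          mhlo::AllReduceOp op, OpAdaptor adaptor,
          ConversionPatternRewriter &rewriter) const override {
        Location loc = op.getLoc();
        // The shape is unchanged by the type converter, but the element type
        // is not: take it from the converted operand, not from the op.
        auto inputType = adaptor.getOperand().getType().cast<RankedTensorType>();
        Value target = rewriter.create<tensor::EmptyOp>(
            loc, inputType.getShape(),
            getElementTypeOrSelf(adaptor.getOperand().getType()));
        // ... create flow.collective.all_reduce with `target` and
        // adaptor.getOperand() (not op.getOperand()), then replace `op`
        // with its result.
        return success();
      }
    };
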
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
index 7910b5a..b294a7c 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
@@ -656,11 +656,11 @@
// Create an empty tensor for the result.
ArrayRef<int64_t> inputShape = inputType.getShape();
- Value target = rewriter.create<tensor::EmptyOp>(loc, inputShape,
- inputType.getElementType());
+ Value target = rewriter.create<tensor::EmptyOp>(
+ loc, inputShape, getElementTypeOrSelf(adaptor.getOperand().getType()));
auto allReduceOp = rewriter.create<IREE::Flow::CollectiveAllReduceOp>(
- op.getLoc(), reductionOpAttr, elementTypeAttr, target, op.getOperand(),
- channel);
+ op.getLoc(), reductionOpAttr, elementTypeAttr, target,
+ adaptor.getOperand(), channel);
rewriter.replaceOp(op, allReduceOp.getResult());
return success();
}
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
index 63fe03c..9098b7c 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
@@ -76,6 +76,25 @@
// -----
+// CHECK-LABEL: @all_reduce_sum_uint
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xi32>
+func.func @all_reduce_sum_uint(%input : tensor<2304xui32>) -> tensor<2304xui32> {
+ // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xi32>
+ // CHECK: %[[OP:.+]] = flow.collective.all_reduce sum, ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2304xi32>, tensor<2304xi32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xi32>
+ // CHECK: return %[[OP]] : tensor<2304xi32>
+ %out = "mhlo.all_reduce"(%input) ({
+ ^bb0(%arg0: tensor<ui32>, %arg1: tensor<ui32>):
+ %sum = mhlo.add %arg0, %arg1 : tensor<ui32>
+ mhlo.return %sum : tensor<ui32>
+ }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+ replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+ use_global_device_ids} : (tensor<2304xui32>) -> tensor<2304xui32>
+ return %out : tensor<2304xui32>
+}
+
+// -----
+
// CHECK-LABEL: @all_reduce_product
// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>)
func.func @all_reduce_product(%input : tensor<2304xf32>) -> tensor<2304xf32> {
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
index 2796ac6..6367621 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
@@ -651,11 +651,11 @@
// Create an empty tensor for the result.
ArrayRef<int64_t> inputShape = inputType.getShape();
- Value target = rewriter.create<tensor::EmptyOp>(loc, inputShape,
- inputType.getElementType());
+ Value target = rewriter.create<tensor::EmptyOp>(
+ loc, inputShape, getElementTypeOrSelf(adaptor.getOperand().getType()));
auto allReduceOp = rewriter.create<IREE::Flow::CollectiveAllReduceOp>(
- op.getLoc(), reductionOpAttr, elementTypeAttr, target, op.getOperand(),
- channel);
+ op.getLoc(), reductionOpAttr, elementTypeAttr, target,
+ adaptor.getOperand(), channel);
rewriter.replaceOp(op, allReduceOp.getResult());
return success();
}
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
index 863c58f..4e92d69 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
@@ -77,6 +77,25 @@
// -----
+// CHECK-LABEL: @all_reduce_sum_uint
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xi32>
+func.func @all_reduce_sum_uint(%input : tensor<2304xui32>) -> tensor<2304xui32> {
+ // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xi32>
+ // CHECK: %[[OP:.+]] = flow.collective.all_reduce sum, ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2304xi32>, tensor<2304xi32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xi32>
+ // CHECK: return %[[OP]] : tensor<2304xi32>
+ %out = "stablehlo.all_reduce"(%input) ({
+ ^bb0(%arg0: tensor<ui32>, %arg1: tensor<ui32>):
+ %sum = stablehlo.add %arg0, %arg1 : tensor<ui32>
+ stablehlo.return %sum : tensor<ui32>
+ }) {channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>,
+ replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+ use_global_device_ids} : (tensor<2304xui32>) -> tensor<2304xui32>
+ return %out : tensor<2304xui32>
+}
+
+// -----
+
// CHECK-LABEL: @all_reduce_product
// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>)
func.func @all_reduce_product(%input : tensor<2304xf32>) -> tensor<2304xf32> {