Fix mhlo.all_reduce and stablehlo.reduce for uint (#13899)

All reduce lowerings are written asumming no type conversion. Made sure
to use the adaptor values instead of the original operands.
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
index 7910b5a..b294a7c 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp
@@ -656,11 +656,11 @@
 
     // Create an empty tensor for the result.
     ArrayRef<int64_t> inputShape = inputType.getShape();
-    Value target = rewriter.create<tensor::EmptyOp>(loc, inputShape,
-                                                    inputType.getElementType());
+    Value target = rewriter.create<tensor::EmptyOp>(
+        loc, inputShape, getElementTypeOrSelf(adaptor.getOperand().getType()));
     auto allReduceOp = rewriter.create<IREE::Flow::CollectiveAllReduceOp>(
-        op.getLoc(), reductionOpAttr, elementTypeAttr, target, op.getOperand(),
-        channel);
+        op.getLoc(), reductionOpAttr, elementTypeAttr, target,
+        adaptor.getOperand(), channel);
     rewriter.replaceOp(op, allReduceOp.getResult());
     return success();
   }
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
index 63fe03c..9098b7c 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
@@ -76,6 +76,25 @@
 
 // -----
 
+// CHECK-LABEL: @all_reduce_sum_uint
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xi32>
+func.func @all_reduce_sum_uint(%input : tensor<2304xui32>) -> tensor<2304xui32> {
+  // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xi32>
+  // CHECK: %[[OP:.+]] = flow.collective.all_reduce sum, ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]]  : (tensor<2304xi32>, tensor<2304xi32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xi32>
+  // CHECK: return %[[OP]] : tensor<2304xi32>
+  %out = "mhlo.all_reduce"(%input) ({
+    ^bb0(%arg0: tensor<ui32>, %arg1: tensor<ui32>):
+      %sum = mhlo.add %arg0, %arg1 : tensor<ui32>
+      mhlo.return %sum : tensor<ui32>
+    }) {channel_handle = #mhlo.channel_handle<handle = 1, type = 1>,
+        replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+        use_global_device_ids} : (tensor<2304xui32>) -> tensor<2304xui32>
+  return %out : tensor<2304xui32>
+}
+
+// -----
+
 // CHECK-LABEL: @all_reduce_product
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>)
 func.func @all_reduce_product(%input : tensor<2304xf32>) -> tensor<2304xf32> {
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
index 2796ac6..6367621 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/ConvertCollectives.cpp
@@ -651,11 +651,11 @@
 
     // Create an empty tensor for the result.
     ArrayRef<int64_t> inputShape = inputType.getShape();
-    Value target = rewriter.create<tensor::EmptyOp>(loc, inputShape,
-                                                    inputType.getElementType());
+    Value target = rewriter.create<tensor::EmptyOp>(
+        loc, inputShape, getElementTypeOrSelf(adaptor.getOperand().getType()));
     auto allReduceOp = rewriter.create<IREE::Flow::CollectiveAllReduceOp>(
-        op.getLoc(), reductionOpAttr, elementTypeAttr, target, op.getOperand(),
-        channel);
+        op.getLoc(), reductionOpAttr, elementTypeAttr, target,
+        adaptor.getOperand(), channel);
     rewriter.replaceOp(op, allReduceOp.getResult());
     return success();
   }
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
index 863c58f..4e92d69 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/convert_collectives.mlir
@@ -77,6 +77,25 @@
 
 // -----
 
+// CHECK-LABEL: @all_reduce_sum_uint
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xi32>
+func.func @all_reduce_sum_uint(%input : tensor<2304xui32>) -> tensor<2304xui32> {
+  // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xi32>
+  // CHECK: %[[OP:.+]] = flow.collective.all_reduce sum, ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]]  : (tensor<2304xi32>, tensor<2304xi32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xi32>
+  // CHECK: return %[[OP]] : tensor<2304xi32>
+  %out = "stablehlo.all_reduce"(%input) ({
+    ^bb0(%arg0: tensor<ui32>, %arg1: tensor<ui32>):
+      %sum = stablehlo.add %arg0, %arg1 : tensor<ui32>
+      stablehlo.return %sum : tensor<ui32>
+    }) {channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>,
+        replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>,
+        use_global_device_ids} : (tensor<2304xui32>) -> tensor<2304xui32>
+  return %out : tensor<2304xui32>
+}
+
+// -----
+
 // CHECK-LABEL: @all_reduce_product
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>)
 func.func @all_reduce_product(%input : tensor<2304xf32>) -> tensor<2304xf32> {