Add i8 support to rng_bit_generator (#14483)

Fixes Issue #14131, to add i8 support to `stablehlo.rng_bit_generator`
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp
index 4372d5f..e7fc85f 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp
@@ -727,10 +727,7 @@
   if (bitwidth == 64) {
     return generateLinalgThreeFry64(builder, loc, resultTy, state, result);
   }
-  if (bitwidth == 32) {
-    return generateLinalgThreeFry32(builder, loc, resultTy, state, result);
-  }
-  if (bitwidth == 16) {
+  if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) {
     return generateLinalgThreeFry32(builder, loc, resultTy, state, result);
   }
 
@@ -747,7 +744,7 @@
   }
 
   // The 32 bit implementation trancates to result eTy.
-  if (bitwidth == 32 || bitwidth == 16) {
+  if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) {
     return generateLinalgPhilox32(builder, loc, resultTy, state, result);
   }
 
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir
index 83c96a3..fae5d5d 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir
@@ -326,6 +326,44 @@
 
 // CHECK: return %[[INSERTED]], %[[COLLAPSE]] : tensor<2xi64>, tensor<8xi16>
 
+// -----
+
+// CHECK-LABEL: func.func @three_fry_i8
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2xi64>
+func.func @three_fry_i8(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) {
+  %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>)
+  return %output_state, %output : tensor<2xi64>, tensor<8xi8>
+}
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64
+
+// Check we update state correctly:
+// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
+// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C4]] : i64
+
+// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<4xi8>
+// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<4xi8>
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#map, #map]
+// CHECK-SAME: iterator_types = ["parallel"]}
+// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<4xi8>, tensor<4xi8>)
+
+// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0
+// CHECK-SAME{literal}: [[0, 1]] : tensor<4xi8> into tensor<4x1xi8>
+
+// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1
+// CHECK-SAME{literal}: [[0, 1]] : tensor<4xi8> into tensor<4x1xi8>
+
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xi8>
+// CHECK: %[[CONCAT:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x2xi8>)
+
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]]
+// CHECK-SAME{literal}: [[0, 1]] : tensor<4x2xi8> into tensor<8xi8>
+// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
+
+// CHECK: return %[[INSERTED]], %[[COLLAPSE]] : tensor<2xi64>, tensor<8xi8>
 
 // -----
 
@@ -613,3 +651,34 @@
 // CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
 
 // CHECK: return %[[INSERTED]], %[[COLLAPSE]]
+
+// -----
+
+func.func @philox_i8(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) {
+  %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo<rng_algorithm PHILOX>} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>)
+  return %output_state, %output : tensor<2xi64>, tensor<8xi8>
+}
+
+// CHECK-LABEL: func.func @philox_i8
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2xi64>
+
+ //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+ //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64
+
+// Check we update state correctly:
+// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
+// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C2]] : i64
+
+// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<2xi8>
+// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<2xi8>
+// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<2xi8>
+// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<2xi8>
+// CHECK: %[[GENERIC:.+]]:4 = linalg.generic
+// CHECK-SAME: indexing_maps = [#map, #map, #map, #map]
+// CHECK-SAME: iterator_types = ["parallel"]}
+// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<2xi8>, tensor<2xi8>, tensor<2xi8>, tensor<2xi8>)
+
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4xi8>
+// CHECK: %[[CONCAT:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4xi8>)