Add i8 support to rng_bit_generator (#14483)
Fixes #14131 by adding i8 result support to `stablehlo.rng_bit_generator`.
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp
index 4372d5f..e7fc85f 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp
@@ -727,10 +727,7 @@
if (bitwidth == 64) {
return generateLinalgThreeFry64(builder, loc, resultTy, state, result);
}
- if (bitwidth == 32) {
- return generateLinalgThreeFry32(builder, loc, resultTy, state, result);
- }
- if (bitwidth == 16) {
+ if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) {
return generateLinalgThreeFry32(builder, loc, resultTy, state, result);
}
@@ -747,7 +744,7 @@
}
- // The 32 bit implementation trancates to result eTy.
+ // The 32 bit implementation truncates to result eTy.
- if (bitwidth == 32 || bitwidth == 16) {
+ if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) {
return generateLinalgPhilox32(builder, loc, resultTy, state, result);
}
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir
index 83c96a3..fae5d5d 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir
@@ -326,6 +326,44 @@
// CHECK: return %[[INSERTED]], %[[COLLAPSE]] : tensor<2xi64>, tensor<8xi16>
+// -----
+
+// CHECK-LABEL: func.func @three_fry_i8
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
+func.func @three_fry_i8(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) {
+ %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>)
+ return %output_state, %output : tensor<2xi64>, tensor<8xi8>
+}
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64
+
+// Check we update state correctly: the value at index 1 advances by 4.
+// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
+// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C4]] : i64
+
+// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<4xi8>
+// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<4xi8>
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#map, #map]
+// CHECK-SAME: iterator_types = ["parallel"]}
+// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<4xi8>, tensor<4xi8>)
+
+// CHECK: %[[EXPANDED0:.+]] = tensor.expand_shape %[[GENERIC]]#0
+// CHECK-SAME{literal}: [[0, 1]] : tensor<4xi8> into tensor<4x1xi8>
+
+// CHECK: %[[EXPANDED1:.+]] = tensor.expand_shape %[[GENERIC]]#1
+// CHECK-SAME{literal}: [[0, 1]] : tensor<4xi8> into tensor<4x1xi8>
+
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xi8>
+// CHECK: %[[CONCAT:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x2xi8>)
+
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]]
+// CHECK-SAME{literal}: [[0, 1]] : tensor<4x2xi8> into tensor<8xi8>
+// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
+
+// CHECK: return %[[INSERTED]], %[[COLLAPSE]] : tensor<2xi64>, tensor<8xi8>
// -----
@@ -613,3 +651,34 @@
// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
// CHECK: return %[[INSERTED]], %[[COLLAPSE]]
+
+// -----
+
+func.func @philox_i8(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) {
+ %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo<rng_algorithm PHILOX>} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>)
+ return %output_state, %output : tensor<2xi64>, tensor<8xi8>
+}
+
+// CHECK-LABEL: func.func @philox_i8
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64
+
+// Check we update state correctly: the value at index 1 advances by 2.
+// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
+// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C2]] : i64
+
+// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<2xi8>
+// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<2xi8>
+// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<2xi8>
+// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<2xi8>
+// CHECK: %[[GENERIC:.+]]:4 = linalg.generic
+// CHECK-SAME: indexing_maps = [#map, #map, #map, #map]
+// CHECK-SAME: iterator_types = ["parallel"]}
+// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<2xi8>, tensor<2xi8>, tensor<2xi8>, tensor<2xi8>)
+
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4xi8>
+// CHECK: %[[CONCAT:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4xi8>)