[StableHLO] Port missing pointwise op tests (#13052)

I went over the list of all pointwise ops that we currently handle.
These fell through the cracks during the initial import (#12957).

Two ops are not tested: `ComplexOp` and `NotOp`, but these do not have
any tests in mlir-hlo either.

Next, I plan to split the linalg lowering tests into a few smaller test
files.

Issue: https://github.com/openxla/iree/issues/12678
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg.mlir
index 1da754c..223916d 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg.mlir
@@ -604,6 +604,16 @@
 
 // -----
 
+// CHECK-LABEL: unsigned_convert
+func.func @unsigned_convert(%in: tensor<2x2xui32>) -> tensor<2x2xui64> {
+  // CHECK: linalg.generic
+  // CHECK: arith.extui
+  %0 = "stablehlo.convert"(%in) : (tensor<2x2xui32>) -> tensor<2x2xui64>
+  func.return %0 : tensor<2x2xui64>
+}
+
+// -----
+
 // CHECK-LABEL: func @float_cmp
 // CHECK-PRIMITIVE-LABEL: func @float_cmp
 func.func @float_cmp(%lhs: tensor<2x2xf32>,
@@ -1098,6 +1108,117 @@
 
 // -----
 
+// CHECK-LABEL: signed_divide
+func.func @signed_divide(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  // CHECK-DAG:   %[[VAL_7:.*]] = arith.constant -1 : i32
+  // CHECK-DAG:   %[[VAL_8:.*]] = arith.constant -2147483648 : i32
+  // CHECK-DAG:   %[[VAL_9:.*]] = arith.constant 0 : i32
+  // CHECK-DAG:   %[[VAL_10:.*]] = arith.constant 1 : i32
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32):
+  // CHECK:   %[[VAL_11:.*]] = arith.cmpi eq, %[[VAL_5]], %[[VAL_9]] : i32
+  // CHECK:   %[[VAL_13:.*]] = arith.cmpi eq, %[[VAL_4]], %[[VAL_8]] : i32
+  // CHECK:   %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_5]], %[[VAL_7]] : i32
+  // CHECK:   %[[VAL_16:.*]] = arith.andi %[[VAL_13]], %[[VAL_15]] : i1
+  // CHECK:   %[[VAL_17:.*]] = arith.ori %[[VAL_11]], %[[VAL_16]] : i1
+  // CHECK:   %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_10]], %[[VAL_5]] : i32
+  // CHECK:   %[[VAL_19:.*]] = arith.divsi %[[VAL_4]], %[[VAL_18]] : i32
+  // CHECK:   %[[VAL_20:.*]] = arith.select %[[VAL_16]], %[[VAL_8]], %[[VAL_19]] : i32
+  // CHECK:   %[[VAL_21:.*]] = arith.select %[[VAL_11]], %[[VAL_7]], %[[VAL_20]] : i32
+  // CHECK:   linalg.yield %[[VAL_21]] : i32
+  %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  func.return %0 : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: unsigned_divide
+func.func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
+  // CHECK-DAG:   %[[VAL_9:.*]] = arith.constant -1 : i32
+  // CHECK-DAG:   %[[VAL_11:.*]] = arith.constant 0 : i32
+  // CHECK-DAG:   %[[VAL_12:.*]] = arith.constant 1 : i32
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32):
+  // CHECK:   %[[VAL_13:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_11]] : i32
+  // CHECK:   %[[VAL_14:.*]] = arith.select %[[VAL_13]], %[[VAL_12]], %[[VAL_7]] : i32
+  // CHECK:   %[[VAL_15:.*]] = arith.divui %[[VAL_6]], %[[VAL_14]] : i32
+  // CHECK:   %[[VAL_16:.*]] = arith.select %[[VAL_13]], %[[VAL_9]], %[[VAL_15]] : i32
+  // CHECK:   linalg.yield %[[VAL_16]] : i32
+  %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
+  func.return %0 : tensor<2x2xui32>
+}
+
+// -----
+
+// CHECK-LABEL: complex_divide
+func.func @complex_divide(%lhs: tensor<2xcomplex<f32>>,
+                     %rhs: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
+  // CHECK: linalg.generic
+  // CHECK: complex.div
+  %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  func.return %0 : tensor<2xcomplex<f32>>
+}
+
+// -----
+
+func.func @shift_left(%lhs: tensor<2x2xi32>,
+                 %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  %result = "stablehlo.shift_left"(%lhs, %rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  func.return %result : tensor<2x2xi32>
+}
+// CHECK-LABEL: func @shift_left
+// CHECK-DAG:    %[[ZERO:.*]] = arith.constant 0
+// CHECK-DAG:    %[[BITS:.*]] = arith.constant 32
+// CHECK: tensor.empty
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
+// CHECK-DAG:    %[[SHIFT:.*]] = arith.shli %[[LHS]], %[[RHS]] : i32
+// CHECK-DAG:    %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]]
+// CHECK-NEXT:   %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[ZERO]]
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// -----
+
+func.func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
+                             %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  %result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  func.return %result : tensor<2x2xi32>
+}
+// CHECK-LABEL: func @shift_right_arithmetic
+// CHECK-DAG:    %[[BITS:.*]] = arith.constant 32
+// CHECK-DAG:    %[[MAX_SHIFT:.*]] = arith.constant 31
+// CHECK: tensor.empty
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
+// CHECK-DAG:    %[[SHIFT:.*]] = arith.shrsi %[[LHS]], %[[RHS]] : i32
+// CHECK-DAG:    %[[MAX_SHIFTED:.*]] = arith.shrsi %[[LHS]], %[[MAX_SHIFT]] : i32
+// CHECK-DAG:    %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]]
+// CHECK-NEXT:   %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[MAX_SHIFTED]]
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// -----
+
+func.func @shift_right_logical(%lhs: tensor<2x2xi32>,
+                          %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  %result = "stablehlo.shift_right_logical"(%lhs, %rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  func.return %result : tensor<2x2xi32>
+}
+// CHECK-LABEL: func @shift_right_logical
+// CHECK-DAG:    %[[ZERO:.*]] = arith.constant 0
+// CHECK-DAG:    %[[BITS:.*]] = arith.constant 32
+// CHECK: tensor.empty
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
+// CHECK-DAG:    %[[SHIFT:.*]] = arith.shrui %[[LHS]], %[[RHS]] : i32
+// CHECK-DAG:    %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]]
+// CHECK-NEXT:   %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[ZERO]]
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// -----
+
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -1530,6 +1651,78 @@
 
 // -----
 
+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @float_pow
+func.func @float_pow(%lhs: tensor<2x2xf32>,
+                %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // CHECK: linalg.generic
+  // CHECK: ^{{[a-z0-9_]*}}
+  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
+  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
+  // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = math.powf %[[ARG0]], %[[ARG1]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
+                                   tensor<2x2xf32>) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @complex_pow
+func.func @complex_pow(%lhs: tensor<2x2xcomplex<f32>>,
+                %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
+  // CHECK: linalg.generic
+  // CHECK: ^{{[a-z0-9_]*}}
+  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: complex<f32>
+  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: complex<f32>
+  // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = complex.pow %[[ARG0]], %[[ARG1]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
+                                   tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
+  func.return %0 : tensor<2x2xcomplex<f32>>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @integer_pow
+func.func @integer_pow(%lhs: tensor<2x2xi32>,
+                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+                    // CHECK: linalg.generic
+  // CHECK: ^{{[a-z0-9_]*}}
+  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
+  // CHECK: %[[FOR_RESULT:[a-zA-Z0-9_]*]]:3 = scf.for {{.*}} to %c6 step %c1
+  // CHECK-SAME: iter_args(
+  // CHECK-SAME:   %[[ITER0:.*]] = %c1
+  // CHECK-SAME:   %[[ITER1:.*]] = %[[ARG0]],
+  // CHECK-SAME:   %[[ITER2:.*]] = %[[ARG1]]
+  // CHECK-SAME: ) -> (i32, i32, i32) {
+  //   CHECK: %[[AND:[a-zA-Z0-9_]*]] = arith.andi %[[ITER2]], %c1
+  //   CHECK: %[[COND:[a-zA-Z0-9_]*]] = arith.cmpi eq, %[[AND]], %c1
+  //   CHECK: %[[MUL:[a-zA-Z0-9_]*]] = arith.muli %[[ITER0]], %[[ITER1]]
+  //   CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = arith.select %[[COND]], %[[MUL]], %[[ITER0]]
+  //   CHECK: %[[BASE:[a-zA-Z0-9_]*]] = arith.muli %[[ITER1]], %[[ITER1]]
+  //   CHECK: %[[EXP:[a-zA-Z0-9_]*]] = arith.shrui %[[ITER2]], %c1
+  //   CHECK: scf.yield %[[ACCUM]], %[[BASE]], %[[EXP]]
+  // CHECK: %[[RHS_PARITY:.*]] = arith.remsi %[[ARG1]], %c2
+  // CHECK: %[[RHS_EVEN:.*]] = arith.cmpi eq, %[[RHS_PARITY]], %c0
+  // CHECK: %[[RHS_NEG:.*]] = arith.cmpi slt, %[[ARG1]], %c0
+  // CHECK: %[[LHS_ONE:.*]] = arith.cmpi eq, %[[ARG0]], %c1
+  // CHECK: %[[LHS_NEG_ONE:.*]] = arith.cmpi eq, %[[ARG0]], %c-1
+  // CHECK: %[[VAL5:.*]] = arith.extui %[[LHS_ONE]] : i1 to i32
+  // CHECK: %[[VAL6:.*]] = arith.select %[[RHS_EVEN]], %c1{{.*}}, %c-1
+  // CHECK: %[[VAL7:.*]] = arith.select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]]
+  // CHECK: %[[RESULT:.*]] = arith.select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]#0
+  // CHECK: linalg.yield %[[RESULT]]
+  %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
+                                   tensor<2x2xi32>) -> tensor<2x2xi32>
+  func.return %0 : tensor<2x2xi32>
+}
+
+
+// -----
+
 func.func @map_mixed(%arg0: tensor<?xf32>,
                      %arg1: tensor<4xf32>) -> tensor<?xf32> {
   %0 = "stablehlo.map"(%arg0, %arg1) ({
@@ -1703,6 +1896,26 @@
 //      CHECK:   %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
 //      CHECK:   linalg.yield %[[VAL2]] : i32
 
+// -----
+
+// CHECK-LABEL: @real_real
+// CHECK-SAME: (%[[ARG0:.*]]:
+func.func @real_real(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %1 = "stablehlo.real"(%arg0) : (tensor<?xf32>) -> (tensor<?xf32>)
+  // CHECK: return %[[ARG0]]
+  func.return %1 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @imag_real
+func.func @imag_real(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %1 = "stablehlo.imag"(%arg0) : (tensor<?xf32>) -> (tensor<?xf32>)
+  // CHECK: %[[CST:.*]] = arith.constant 0
+  // CHECK: linalg.generic
+  // CHECK: yield %[[CST]]
+  func.return %1 : tensor<?xf32>
+}
 
 // -----
 
@@ -1820,6 +2033,47 @@
 
 // -----
 
+// CHECK-LABEL: func @reduce_precision(
+// CHECK-DAG: %[[C2:.*]] = arith.constant 1048576 : i32
+// CHECK-DAG: %[[C_21:.*]] = arith.constant 20 : i32
+// CHECK-DAG: %[[C3:.*]] = arith.constant 524287 : i32
+// CHECK-DAG: %[[C4:.*]] = arith.constant -1048576 : i32
+// CHECK-DAG: %[[C5:.*]] = arith.constant 2139095040 : i32
+// CHECK-DAG: %[[C6:.*]] = arith.constant 1090519040 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 1040187392 : i32
+// CHECK-DAG: %[[C8:.*]] = arith.constant -2147483648 : i32
+// CHECK-DAG: %[[C9:.*]] = arith.constant 2147483647 : i32
+// CHECK: linalg.generic
+// CHECK: %[[X_AS_INT:.*]] = arith.bitcast %[[IN:.*]] : f32 to i32
+// CHECK: %[[ABS_X:.*]] = arith.andi %[[X_AS_INT]], %[[C9]]
+// CHECK: %[[IS_NAN:.*]] = arith.cmpi ugt, %[[ABS_X]], %[[C5]]
+// CHECK: %[[MASKED:.*]] = arith.andi %[[X_AS_INT]], %[[C2]] : i32
+// CHECK: %[[V0:.*]] = arith.shrui %[[MASKED]], %[[C_21]] : i32
+// CHECK: %[[V1:.*]] = arith.addi %[[V0]], %[[C3]] : i32
+// CHECK: %[[V2:.*]] = arith.addi %[[X_AS_INT]], %[[V1]] : i32
+// CHECK: %[[V3:.*]] = arith.andi %[[V2]], %[[C4]] : i32
+// CHECK: %[[V4:.*]] = arith.andi %[[V3]], %[[C5]] : i32
+// CHECK: %[[V5:.*]] = arith.cmpi ugt, %[[V4]], %[[C6]] : i32
+// CHECK: %[[V6:.*]] = arith.cmpi ule, %[[V4]], %[[C7]] : i32
+// CHECK: %[[V7:.*]] = arith.andi %[[V3]], %[[C8]] : i32
+// CHECK: %[[V8:.*]] = arith.ori %[[V7]], %[[C5]] : i32
+// CHECK: %[[V9:.*]] = arith.select %[[V5]], %[[V8]], %[[V3]] : i32
+// CHECK: %[[V10:.*]] = arith.select %[[V6]], %[[V7]], %[[V9]] : i32
+// CHECK: %[[CONVERTED:.*]] = arith.bitcast %[[V10]] : i32 to f32
+// CHECK: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[IN]], %[[CONVERTED]]
+// CHECK: linalg.yield %[[RESULT]]
+
+// CHECK-PRIMITIVE-LABEL: func @reduce_precision(
+// CHECK-PRIMITIVE: linalg.map
+func.func @reduce_precision(%arg0: tensor<1x2x3x4xf32>)
+                            -> tensor<1x2x3x4xf32> {
+  %0 = "stablehlo.reduce_precision"(%arg0) {exponent_bits=3:i32, mantissa_bits=3:i32} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
+  return %0 : tensor<1x2x3x4xf32>
+}
+
+
+// -----
+
 // CHECK-LABEL: set_dimension_size
 // CHECK-SAME: %[[VALUE:.*]]: tensor<2x?xf32, #stablehlo.bounds<?, 2>
 func.func @set_dimension_size(
@@ -1897,6 +2151,75 @@
 
 // -----
 
+// CHECK-LABEL: func @minf
+func.func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %0 = "stablehlo.minimum"(%lhs, %rhs) {someattr}
+          : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
+// CHECK: tensor.empty() : tensor<2x2xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: {someattr}
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
+// CHECK-NEXT:   %[[RESULT:.*]] = arith.minf %[[LHS_IN]], %[[RHS_IN]] : f32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
+
+// CHECK-PRIMITIVE: linalg.map
+// CHECK-PRIMITIVE: arith.minf
+
+// -----
+
+// CHECK-LABEL: func @maxi
+func.func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  %0 = "stablehlo.maximum"(%lhs, %rhs)
+          : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  func.return %0 : tensor<2x2xi32>
+}
+// CHECK: tensor.empty() : tensor<2x2xi32>
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32):
+// CHECK-NEXT:   %[[RESULT:.*]] = arith.maxsi %[[LHS_IN]], %[[RHS_IN]] : i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// CHECK-PRIMITIVE: linalg.map
+// CHECK-PRIMITIVE: arith.maxsi
+
+// -----
+
+// CHECK-LABEL: func @maxu
+func.func @maxu(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
+  %0 = "stablehlo.maximum"(%lhs, %rhs)
+          : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
+  func.return %0 : tensor<2x2xui32>
+}
+// CHECK: tensor.empty() : tensor<2x2xi32>
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32):
+// CHECK-NEXT:   %[[RESULT:.*]] = arith.maxui %[[LHS_IN]], %[[RHS_IN]] : i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// CHECK-PRIMITIVE: linalg.map
+// CHECK-PRIMITIVE: arith.maxui
+
+// -----
+
+// CHECK-LABEL: func @maxi1
+func.func @maxi1(%lhs: tensor<?x?xi1>, %rhs: tensor<?x?xi1>) -> tensor<?x?xi1> {
+  %0 = "stablehlo.maximum"(%lhs, %rhs)
+          : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+  func.return %0 : tensor<?x?xi1>
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i1, %[[RHS_IN:.*]]: i1, %{{.*}}: i1):
+// CHECK-NEXT:   %[[RESULT:.*]] = arith.maxui %[[LHS_IN]], %[[RHS_IN]] : i1
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i1
+
+// CHECK-PRIMITIVE: linalg.map
+// CHECK-PRIMITIVE: arith.maxui
+
+
+// -----
+
 func.func @dot_matmul(%arg0: tensor<2x3xf32>,
                  %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
   %0 = "stablehlo.dot"(%arg0, %arg1) {someattr}
@@ -2052,3 +2375,111 @@
 // CHECK: linalg.dot
 // CHECK-SAME: ins(%{{.*}} : tensor<?xi32>, tensor<?xi32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<i32>)
+
+// -----
+
+// CHECK-LABEL: @clamp_static
+// CHECK-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32>
+func.func @clamp_static(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>)
+    -> tensor<4xf32> {
+  // CHECK: %[[INIT:.*]] = tensor.empty
+  // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
+  // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
+  // CHECK:   %[[MAX:.*]] = arith.maxf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
+  // CHECK:   %[[MIN:.*]] = arith.minf %[[MAX]], %[[SCALAR_UB]] : f32
+  // CHECK:   linalg.yield %[[MIN]]
+  // CHECK: } -> tensor<4xf32>
+  // CHECK: return %[[RESULT]] : tensor<4xf32>
+  %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor<4xf32>,
+      tensor<4xf32>) -> tensor<4xf32>
+  func.return %0 : tensor<4xf32>
+}
+
+// CHECK-PRIMITIVE-LABEL: @clamp_static
+// CHECK-PRIMITIVE-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32>
+
+// CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty
+// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
+// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32)
+// CHECK-PRIMITIVE:   %[[MAX:.*]] = arith.maxf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
+// CHECK-PRIMITIVE:   %[[MIN:.*]] = arith.minf %[[MAX]], %[[SCALAR_UB]] : f32
+// CHECK-PRIMITIVE:   linalg.yield %[[MIN]]
+// CHECK-PRIMITIVE: return %[[RESULT]] : tensor<4xf32>
+
+// -----
+
+// CHECK-LABEL: @clamp_dynamic
+// CHECK-SAME: %[[LB:.*]]: tensor<?xf32>, %[[X:.*]]: tensor<?xf32>, %[[UB:.*]]: tensor<?xf32>
+func.func @clamp_dynamic(%lb : tensor<?xf32>, %x : tensor<?xf32>, %ub : tensor<?xf32>)
+    -> tensor<?xf32> {
+  // CHECK: %[[INIT:.*]] = tensor.empty
+  // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>)
+  // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
+  // CHECK:   %[[MAX:.*]] = arith.maxf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
+  // CHECK:   %[[MIN:.*]] = arith.minf %[[MAX]], %[[SCALAR_UB]] : f32
+  // CHECK:   linalg.yield %[[MIN]]
+  // CHECK: } -> tensor<?xf32>
+  // CHECK: return %[[RESULT]] : tensor<?xf32>
+  %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<?xf32>, tensor<?xf32>,
+      tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK-PRIMITIVE-LABEL: @clamp_dynamic
+// CHECK-PRIMITIVE: linalg.map
+
+// -----
+
+func.func @clamp_mixed(%lb : tensor<4xf32>, %x : tensor<?xf32>, %ub : tensor<?xf32>)
+    -> tensor<?xf32> {
+  %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor<?xf32>,
+      tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @clamp_mixed
+// CHECK: linalg.generic
+
+// CHECK-PRIMITIVE-LABEL: @clamp_mixed
+// CHECK-PRIMITIVE: linalg.map
+
+// -----
+
+func.func @clamp_scalar(%lb : tensor<f32>, %x : tensor<?xf32>, %ub : tensor<f32>)
+    -> tensor<?xf32> {
+  %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<f32>, tensor<?xf32>,
+      tensor<f32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @clamp_scalar
+// CHECK: linalg.generic
+
+// CHECK-PRIMITIVE-LABEL: @clamp_scalar
+// CHECK-PRIMITIVE-SAME: %[[LB:.*]]: tensor<f32>, %[[X:.*]]: tensor<?xf32>, %[[UB:.*]]: tensor<f32>
+
+// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.empty
+// CHECK-PRIMITIVE-DAG: %[[SCALAR_LB:.*]] = tensor.extract %[[LB]]
+// CHECK-PRIMITIVE-DAG: %[[SCALAR_UB:.*]] = tensor.extract %[[UB]]
+// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[X]] : tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>)
+// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32)
+// CHECK-PRIMITIVE:   %[[MAX:.*]] = arith.maxf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
+// CHECK-PRIMITIVE:   %[[MIN:.*]] = arith.minf %[[MAX]], %[[SCALAR_UB]] : f32
+// CHECK-PRIMITIVE:   linalg.yield %[[MIN]]
+// CHECK-PRIMITIVE: return %[[RESULT]]
+
+
+// -----
+
+func.func @clamp_scalar_mixed(%lb : tensor<f32>, %x : tensor<?xf32>, %ub : tensor<?xf32>)
+    -> tensor<?xf32> {
+  %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<f32>, tensor<?xf32>,
+      tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @clamp_scalar_mixed
+// CHECK: linalg.generic
+
+// CHECK-PRIMITIVE-LABEL: @clamp_scalar_mixed
+// CHECK-PRIMITIVE: linalg.map