Extract scalars when using outside tensors within LinalgExt body. (#6844)

Sort is an unusual operation: its comparator region can capture tensors
defined outside the region. The LinalgExt body only allows scalar
operands, so in that case we must create a tensor.extract op to pull the
scalar value out of the captured (0-D) tensor.

Also adds more op lowerings for the LinalgExt region.

It's a step toward https://github.com/google/iree/issues/6154
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index 6fd50a6..c976b49 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -61,13 +61,19 @@
       ConversionPatternRewriter &rewriter) const final {
     if (!isInBodyOfLinalgExtOps(op)) return failure();
     if (!op.getResult().getType().template isa<TensorType>()) return failure();
-    if (llvm::all_of(args, [](Value arg) {
-          return arg.getType().template isa<TensorType>();
-        })) {
-      return failure();
+    SmallVector<Value> scalarArgs;
+    for (auto arg : args) {
+      if (auto ty = arg.getType().template dyn_cast<TensorType>()) {
+        assert(ty.hasRank() && ty.getRank() == 0 &&
+               "Have non-0D tensors in the region?");
+        scalarArgs.push_back(
+            rewriter.create<tensor::ExtractOp>(op.getLoc(), arg));
+      } else {
+        scalarArgs.push_back(arg);
+      }
     }
     Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
-        op, getElementTypeOrSelf(op.getType()), args, &rewriter);
+        op, getElementTypeOrSelf(op.getType()), scalarArgs, &rewriter);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -396,6 +402,8 @@
         context);
     patterns.insert<LinalgExtRegionHLOOpConversion<mhlo::CompareOp>,
                     LinalgExtRegionHLOOpConversion<mhlo::AddOp>,
+                    LinalgExtRegionHLOOpConversion<mhlo::SubOp>,
+                    LinalgExtRegionHLOOpConversion<mhlo::BitcastConvertOp>,
                     LinalgExtRegionReturnOpConversion>(context,
                                                        PatternBenefit(1000));
 
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
index 163df58..6c5b39a 100644
--- a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -18,6 +18,29 @@
 // CHECK:             linalg_ext.yield %[[CMP]]
 // CHECK:         return %[[SORT]]
 
+// -----
+
+func @sort_with_cst(%arg0: tensor<1x10xi32>) -> tensor<1x10xi32> {
+  %0 = mhlo.constant dense<0> : tensor<i32>
+  %1 = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<i32>, %arg3: tensor<i32>):  // no predecessors
+    %2 = "mhlo.compare"(%arg1, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%2) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<1x10xi32>) -> tensor<1x10xi32>
+  return %1 : tensor<1x10xi32>
+}
+
+// CHECK-LABEL: func @sort_with_cst
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[CST:.+]] = mhlo.constant dense<0> : tensor<i32>
+// CHECK:         %{{.+}} = linalg_ext.sort dimension(1) outs(%[[ARG0]] : tensor<1x10xi32>)  {
+// CHECK:         ^bb0(%[[ARG1:.+]]: i32, %{{.*}}: i32)
+// CHECK:           %[[SCALAR:.+]] = tensor.extract %[[CST]][] : tensor<i32>
+// CHECK:           %[[RES:.+]] = cmpi slt, %[[ARG1]], %[[SCALAR]] : i32
+// CHECK:           linalg_ext.yield %[[RES]] : i1
+// CHECK:         } -> tensor<1x10xi32>
+// CHECK:       }
+
 // -----
 
 func @sort_2d(%arg0: tensor<16x32xi32>) -> (tensor<16x32xi32>) {