Extract scalars when using outside tensors within LinalgExt body. (#6844)
The comparator region of a sort op sometimes captures tensors that are
defined outside the region. The body of a LinalgExt op is only allowed to
operate on scalars, so in this case a tensor.extract op is created to read
the scalar out of each such (0-D) tensor operand.
Also adds lowerings for more ops (mhlo::SubOp and mhlo::BitcastConvertOp)
inside LinalgExt regions.
It's a step toward https://github.com/google/iree/issues/6154
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index 6fd50a6..c976b49 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -61,13 +61,19 @@
ConversionPatternRewriter &rewriter) const final {
if (!isInBodyOfLinalgExtOps(op)) return failure();
if (!op.getResult().getType().template isa<TensorType>()) return failure();
- if (llvm::all_of(args, [](Value arg) {
- return arg.getType().template isa<TensorType>();
- })) {
- return failure();
+ SmallVector<Value> scalarArgs;
+ for (auto arg : args) {
+ if (auto ty = arg.getType().template dyn_cast<TensorType>()) {
+ assert(ty.hasRank() && ty.getRank() == 0 &&
+ "Have non-0D tensors in the region?");
+ scalarArgs.push_back(
+ rewriter.create<tensor::ExtractOp>(op.getLoc(), arg));
+ } else {
+ scalarArgs.push_back(arg);
+ }
}
Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
- op, getElementTypeOrSelf(op.getType()), args, &rewriter);
+ op, getElementTypeOrSelf(op.getType()), scalarArgs, &rewriter);
rewriter.replaceOp(op, result);
return success();
}
@@ -396,6 +402,8 @@
context);
patterns.insert<LinalgExtRegionHLOOpConversion<mhlo::CompareOp>,
LinalgExtRegionHLOOpConversion<mhlo::AddOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::SubOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::BitcastConvertOp>,
LinalgExtRegionReturnOpConversion>(context,
PatternBenefit(1000));
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
index 163df58..6c5b39a 100644
--- a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -18,6 +18,29 @@
// CHECK: linalg_ext.yield %[[CMP]]
// CHECK: return %[[SORT]]
+// -----
+
+func @sort_with_cst(%arg0: tensor<1x10xi32>) -> tensor<1x10xi32> {
+ %0 = mhlo.constant dense<0> : tensor<i32>
+ %1 = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<i32>, %arg3: tensor<i32>): // no predecessors
+ %2 = "mhlo.compare"(%arg1, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "mhlo.return"(%2) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = true} : (tensor<1x10xi32>) -> tensor<1x10xi32>
+ return %1 : tensor<1x10xi32>
+}
+
+// CHECK-LABEL: func @sort_with_cst
+// CHECK: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[CST:.+]] = mhlo.constant dense<0> : tensor<i32>
+// CHECK: %{{.+}} = linalg_ext.sort dimension(1) outs(%[[ARG0]] : tensor<1x10xi32>) {
+// CHECK: ^bb0(%[[ARG1:.+]]: i32, %{{.*}}: i32)
+// CHECK: %[[SCALAR:.+]] = tensor.extract %[[CST]][] : tensor<i32>
+// CHECK: %[[RES:.+]] = cmpi slt, %[[ARG1]], %[[SCALAR]] : i32
+// CHECK: linalg_ext.yield %[[RES]] : i1
+// CHECK: } -> tensor<1x10xi32>
+// CHECK: }
+
// -----
func @sort_2d(%arg0: tensor<16x32xi32>) -> (tensor<16x32xi32>) {