Relax SliceOp so that it can be a root op. (#3876)

- Move mergeability checks to DispatchConfig.
- Add a new trait for root-only ops.
- Add slice + add tests (lowering, region folding, and e2e).
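
For context, a minimal sketch of the rule this change introduces. The
helper below is illustrative only (its name and boolean flag are made up
for this note); the two OpDispatchPolicy predicates are the real API
added in DispatchConfig.h:

  #include "iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h"

  // Decides whether `op`, inside the consumer dispatch region, allows a
  // producer region to be merged in. `consumesProducerResult` stands in
  // for the use-def walk that rhsHasRootOnlyOp performs.
  static bool allowsMergingProducer(mlir::Operation *op,
                                    bool consumesProducerResult) {
    using mlir::iree_compiler::IREE::Flow::OpDispatchPolicy;
    if (OpDispatchPolicy::isRootOnlyOp(op)) {
      // Root-only ops (currently just mhlo.slice) may absorb consumers,
      // but no producer may be merged in front of them.
      return !consumesProducerResult;
    }
    return !OpDispatchPolicy::isUnsupportedFusionOp(op);
  }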
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 9f33bd1..7381818 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -665,44 +665,55 @@
 
 namespace {
 /// Converts mhlo.slice operation to linalg.subview + linalg.copy
-struct SliceOpConversion
-    : public ConvertToLinalgBufferOp<SliceOpConversion, mhlo::SliceOp> {
-  using ConvertToLinalgBufferOp<SliceOpConversion,
-                                mhlo::SliceOp>::ConvertToLinalgBufferOp;
+struct SliceOpConversion : public OpConversionPattern<mhlo::SliceOp> {
+  SliceOpConversion(MLIRContext *context,
+                    TensorToBufferMap const &resultTensorToBufferMap,
+                    PatternBenefit benefit = 1)
+      : OpConversionPattern<mhlo::SliceOp>(context, benefit),
+        resultTensorToBufferMap(resultTensorToBufferMap) {}
 
-  LogicalResult apply(mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
-                      ArrayRef<Value> resultBuffers,
-                      ConversionPatternRewriter &rewriter) const;
+  LogicalResult matchAndRewrite(
+      mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto argType = inputBuffers[0].getType().dyn_cast<ShapedType>();
+    if (!argType || !argType.hasStaticShape()) {
+      return op.emitError("expected static shape");
+    }
+
+    auto resultShape = op.getResult().getType().cast<ShapedType>().getShape();
+    SmallVector<Value, 3> offsets, sizes, strides;
+    for (int i = 0, e = argType.getRank(); i < e; ++i) {
+      Value startIndex = rewriter.create<ConstantIndexOp>(
+          loc, op.start_indices().getValue<int64_t>(i));
+      offsets.push_back(startIndex);
+      Value size = rewriter.create<ConstantIndexOp>(loc, resultShape[i]);
+      sizes.push_back(size);
+      Value stride = rewriter.create<ConstantIndexOp>(
+          loc, op.strides().getValue<int64_t>(i));
+      strides.push_back(stride);
+    }
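+    // All offsets, sizes, and strides above are constants, so the subview
+    // below gets a fully static memref type.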
+    auto subViewOp = rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets,
+                                                sizes, strides);
+
+    // If the op's result tensor is already mapped to a buffer, copy the
+    // subview into that buffer; otherwise the subview replaces the op.
+    if (Value bufferForResult =
+            resultTensorToBufferMap.lookup(op.getResult())) {
+      rewriter.create<linalg::CopyOp>(loc, subViewOp, bufferForResult);
+      rewriter.replaceOp(op, bufferForResult);
+    } else {
+      rewriter.replaceOp(op, subViewOp.getResult());
+    }
+
+    return success();
+  }
+
+ private:
+  TensorToBufferMap const &resultTensorToBufferMap;
 };
 }  // namespace
 
-LogicalResult SliceOpConversion::apply(
-    mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
-    ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
-  if (!argType || !argType.hasRank()) {
-    return op.emitError("expected known-rank args");
-  }
-
-  SmallVector<Value, 3> offsets, sizes, strides;
-  for (int i = 0, e = argType.getRank(); i < e; ++i) {
-    Value startIndex = rewriter.create<ConstantIndexOp>(
-        loc, op.start_indices().getValue<int64_t>(i));
-    offsets.push_back(startIndex);
-    Value size = rewriter.create<DimOp>(loc, resultBuffers[0], i);
-    sizes.push_back(size);
-    Value stride = rewriter.create<ConstantIndexOp>(
-        loc, op.strides().getValue<int64_t>(i));
-    strides.push_back(stride);
-  }
-  auto subViewOp =
-      rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets, sizes, strides);
-  rewriter.create<linalg::CopyOp>(loc, subViewOp, resultBuffers[0]);
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // mhlo.reduce_window conversion patterns and utility functions.
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
index 406a62b..0de43b3 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-pipeline -canonicalize %s | IreeFileCheck %s
 
 module {
   // CHECK-LABEL: @slice_whole_buffer
@@ -25,19 +25,10 @@
 // -----
 
 module {
-  //      CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+  //      CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 4)>
   //      CHECK: @slice_whole_stride
-  //  CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x4xi32>
   //  CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
-  //  CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
-  //  CHECK-DAG: %[[ONE:.+]] = constant 1 : index
-  //  CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x4xi32>
-  //  CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x4xi32>
-  //      CHECK: subview %[[IN]]
-  // CHECK-SAME:   [%[[ONE]], %[[ZERO]]]
-  // CHECK-SAME:   [%[[DIM0]], %[[DIM1]]]
-  // CHECK-SAME:   [%[[ONE]], %[[ONE]]]
-  // CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #[[MAP]]>
+  //      CHECK: subview %[[IN]][1, 0] [1, 4] [1, 1]  : memref<3x4xi32> to memref<1x4xi32, #[[MAP]]>
   //      CHECK: linalg.copy
   func @slice_whole_stride()
     attributes {signature = (tensor<3x4xi32>) -> (tensor<1x4xi32>)} {
@@ -60,19 +51,10 @@
 // -----
 
 module {
-  //      CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+  //      CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
   //      CHECK: @slice_stride_part
-  //  CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x2xi32>
   //  CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
-  //  CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
-  //  CHECK-DAG: %[[ONE:.+]] = constant 1 : index
-  //  CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x2xi32>
-  //  CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x2xi32>
-  //      CHECK: subview %[[IN]]
-  // CHECK-SAME:   [%[[ONE]], %[[ONE]]]
-  // CHECK-SAME:   [%[[DIM0]], %[[DIM1]]]
-  // CHECK-SAME:   [%[[ONE]], %[[ONE]]]
-  // CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #map>
+  //       CHECK: subview %[[IN]][1, 1] [1, 2] [1, 1]  : memref<3x4xi32> to memref<1x2xi32, #[[MAP]]>
   //       CHECK: linalg.copy
   func @slice_stride_part()
     attributes {signature = (tensor<3x4xi32>) -> (tensor<1x2xi32>)} {
@@ -91,3 +73,32 @@
     hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
   }
 }
+
+// -----
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//      CHECK: func @slice_stride_part
+//      CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<1x2xi32>
+//      CHECK: %[[IN0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4xi32>
+//      CHECK: %[[SUBVIEW:.+]] = subview %[[IN0]][1, 1] [1, 2] [1, 1]  : memref<3x4xi32> to memref<1x2xi32, #[[MAP0]]>
+//      CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<1x2xi32>
+//      CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SUBVIEW]], %[[IN1]]
+// CHECK-SAME: outs(%[[OUT]]
+module {
+  func @slice_stride_part() {
+    %c0 = constant 0 : index
+    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 {operand_result_index = 0 : i32} : tensor<3x4xi32>
+    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 {operand_result_index = 1 : i32} : tensor<1x2xi32>
+    %2 = "mhlo.slice"(%0) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+    %3 = mhlo.add %2, %1 : tensor<1x2xi32>
+    hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 {operand_result_index = 2 : i32} : tensor<1x2xi32>
+    return
+  }
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 176f09d..9964eae 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -31,12 +31,6 @@
 namespace {
 // TODO(laurenzo): Every one of these should have better support and removed
 // from this exclusion list eventually.
-bool isUnsupportedFusionOp(Operation *op) {
-  return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
-             mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::SliceOp,
-             mhlo::TorchIndexSelectOp>(op);
-}
-
 // Allowlist of ops that materialize to an index-permuted copy of some kind
 // if they exist standalone. Generally we try to avoid anchoring on these,
 // letting them fuse into more meaningful ops as possible.
@@ -182,6 +176,18 @@
   return FusionType::DISABLED;
 }
 
+// TODO(b/144530470): replace with tablegen attributes/interfaces.
+bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
+  return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
+             mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
+             mhlo::TorchIndexSelectOp>(op) ||
+         isRootOnlyOp(op);
+}
+
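+// mhlo.slice lowers to a subview + copy of its input buffer (see
+// HLOToLinalgOnBuffers.cpp), so it can anchor a dispatch region as the
+// root but cannot be fused behind a producer.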
+bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
+  return isa<mhlo::SliceOp>(op);
+}
+
 }  // namespace Flow
 }  // namespace IREE
 }  // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
index 22666b3..ee9299f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
@@ -41,6 +41,12 @@
   OpDispatchPolicy(Dispatchability &dispatchability)
       : dispatchability(dispatchability) {}
 
+  // Returns true if |op| cannot fuse with either producers or consumers.
+  static bool isUnsupportedFusionOp(Operation *op);
+
+  // Returns true if |op| can only be a root op.
+  static bool isRootOnlyOp(Operation *op);
+
   // Returns true if the given |op| can be dispatched in all cases.
   // Other passes may handle special cases of these ops but this initial
   // identification is conservative.
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 984b8c7..9565747 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
@@ -200,10 +201,9 @@
   // that substituting library calls is easier.
   for (auto &block : regionOp.body().getBlocks()) {
     for (auto &op : block) {
-      // TODO(b/144530470): replace with tablegen attributes/interfaces.
-      if (isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp,
-              mhlo::DotOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
-              mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op)) {
+      // A root-only op is still mergeable (see rhsHasRootOnlyOp below).
+      if (OpDispatchPolicy::isUnsupportedFusionOp(&op) &&
+          !OpDispatchPolicy::isRootOnlyOp(&op)) {
         return false;
       }
     }
@@ -211,6 +211,24 @@
   return regionOp.body().getBlocks().size() == 1;
 }
 
+// Returns true if |rhs| contains a root-only op that consumes a result of
+// |lhs|; merging the two regions would place a producer in front of an op
+// that must remain the root of its region.
+bool rhsHasRootOnlyOp(DispatchRegionOp &lhs, DispatchRegionOp &rhs) {
+  auto &rhsBlock = rhs.body().front();
+  auto rhsArgs = llvm::to_vector<8>(rhs.args());
+  for (int rhsArgIdx = 0, e = rhsArgs.size(); rhsArgIdx < e; ++rhsArgIdx) {
+    for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults();
+         ++lhsResultIdx) {
+      // Only rhs operands fed directly by an lhs result are relevant.
+      if (rhsArgs[rhsArgIdx] != lhs.getResult(lhsResultIdx)) continue;
+      for (auto *user : rhsBlock.getArgument(rhsArgIdx).getUsers()) {
+        if (OpDispatchPolicy::isRootOnlyOp(user)) return true;
+      }
+    }
+  }
+  return false;
+}
+
 // Merges |rhs| into |lhs| and returns the new |lhs| op.
 // Precondition: !areDispatchRegionsTransitivelyDependent
 DispatchRegionOp mergeDispatchRegions(DispatchRegionOp &lhs,
@@ -345,6 +363,10 @@
         LLVM_DEBUG(llvm::dbgs()
                    << "   -REGION CONTAINS NON-TRIVIAL CONTROL FLOW-\n");
       }
+      if (rhsHasRootOnlyOp(lhs, rhs)) {
+        LLVM_DEBUG(llvm::dbgs() << "   -RHS REGION HAS ROOT-ONLY OP-\n");
+        continue;
+      }
       mergableRegions[i] = mergeDispatchRegions(lhs, rhs);
       if (!mergableRegions[i]) {
         return failure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
index 938c92a..c324ef7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -132,3 +132,33 @@
 // CHECK-LABEL: func @dominate
 //       CHECK: flow.dispatch.region
 //   CHECK-NOT: flow.dispatch.region
+
+// -----
+
+// Tests that an op that can only be a root op fuses with its consumer but
+// not with its producer. A dummy workload is used to exercise the
+// root-only op handling.
+module {
+  func @rootOnlyOp(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
+    %c0 = constant 0 : index
+    %0 = flow.dispatch.region[%c0 : index](%arg2 = %arg0 : tensor<3x4xi32>) -> tensor<3x4xi32> {
+      %3 = mhlo.add %arg2, %arg2 : tensor<3x4xi32>
+      flow.return %3 : tensor<3x4xi32>
+    }
+    %1 = flow.dispatch.region[%c0 : index](%arg2 = %0 : tensor<3x4xi32>) -> tensor<1x2xi32> {
+      %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+      flow.return %3 : tensor<1x2xi32>
+    }
+    %2 = flow.dispatch.region[%c0 : index](%arg2 = %1 : tensor<1x2xi32>, %arg3 = %arg1 : tensor<1x2xi32>) -> tensor<1x2xi32> {
+      %3 = mhlo.multiply %arg2, %arg3 : tensor<1x2xi32>
+      flow.return %3 : tensor<1x2xi32>
+    }
+    return %2 : tensor<1x2xi32>
+  }
+}
+// CHECK-LABEL: func @rootOnlyOp
+//       CHECK: flow.dispatch.region
+//  CHECK-NEXT:   mhlo.add
+//       CHECK: flow.dispatch.region
+//  CHECK-NEXT:   mhlo.slice
+//  CHECK-NEXT:   mhlo.multiply
diff --git a/iree/test/e2e/regression/slice_add.mlir b/iree/test/e2e/regression/slice_add.mlir
new file mode 100644
index 0000000..762d979
--- /dev/null
+++ b/iree/test/e2e/regression/slice_add.mlir
@@ -0,0 +1,15 @@
+// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s)
+
+// CHECK: EXEC @slice_stride_part
+// CHECK: 1x2xi32=[16 17]
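+// The slice takes row 1, columns [1, 3) of %arg0, i.e. [6 7]; %arg1 is a
+// 1x2 splat of 10, so the add yields [16 17].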
+func @slice_stride_part(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
+  %1 = "mhlo.slice"(%arg0) {
+    start_indices = dense<[1, 1]> : tensor<2xi64>,
+    limit_indices = dense<[2, 3]> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } : (tensor<3x4xi32>) -> tensor<1x2xi32>
+  %2 = mhlo.add %1, %arg1 : tensor<1x2xi32>
+  return %2 : tensor<1x2xi32>
+}