Relax SliceOp to being able to be a root op. (#3876)
- Move mergable checks to DispatchConfig.
- Add a new trait for root only ops.
- Add slice + add tests.
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 9f33bd1..7381818 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -665,44 +665,55 @@
namespace {
/// Converts mhlo.slice operation to linalg.subview + linalg.copy
-struct SliceOpConversion
- : public ConvertToLinalgBufferOp<SliceOpConversion, mhlo::SliceOp> {
- using ConvertToLinalgBufferOp<SliceOpConversion,
- mhlo::SliceOp>::ConvertToLinalgBufferOp;
+struct SliceOpConversion : public OpConversionPattern<mhlo::SliceOp> {
+ SliceOpConversion(MLIRContext *context,
+ TensorToBufferMap const &resultTensorToBufferMap,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<mhlo::SliceOp>(context, benefit),
+ resultTensorToBufferMap(resultTensorToBufferMap) {}
- LogicalResult apply(mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers,
- ConversionPatternRewriter &rewriter) const;
+ LogicalResult matchAndRewrite(
+ mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
+ if (!argType || !argType.hasStaticShape()) {
+ return op.emitError("expected static shape");
+ }
+
+ auto resultShape = op.getResult().getType().cast<ShapedType>().getShape();
+ SmallVector<Value, 3> offsets, sizes, strides;
+ for (int i = 0, e = argType.getRank(); i < e; ++i) {
+ Value startIndex = rewriter.create<ConstantIndexOp>(
+ loc, op.start_indices().getValue<int64_t>(i));
+ offsets.push_back(startIndex);
+ Value size = rewriter.create<ConstantIndexOp>(loc, resultShape[i]);
+ sizes.push_back(size);
+ Value stride = rewriter.create<ConstantIndexOp>(
+ loc, op.strides().getValue<int64_t>(i));
+ strides.push_back(stride);
+ }
+ auto subViewOp = rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets,
+ sizes, strides);
+
+ // If the result of the subview is already mapped to a buffer, a copy is
+ // required from the buffer above into the mapped buffer.
+ if (Value bufferForResult =
+ resultTensorToBufferMap.lookup(op.getResult())) {
+ rewriter.create<linalg::CopyOp>(loc, subViewOp, bufferForResult);
+ rewriter.replaceOp(op, bufferForResult);
+ } else {
+ rewriter.replaceOp(op, subViewOp.getResult());
+ }
+
+ return success();
+ }
+
+ private:
+ TensorToBufferMap const &resultTensorToBufferMap;
};
} // namespace
-LogicalResult SliceOpConversion::apply(
- mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
- auto loc = op.getLoc();
- auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
- if (!argType || !argType.hasRank()) {
- return op.emitError("expected known-rank args");
- }
-
- SmallVector<Value, 3> offsets, sizes, strides;
- for (int i = 0, e = argType.getRank(); i < e; ++i) {
- Value startIndex = rewriter.create<ConstantIndexOp>(
- loc, op.start_indices().getValue<int64_t>(i));
- offsets.push_back(startIndex);
- Value size = rewriter.create<DimOp>(loc, resultBuffers[0], i);
- sizes.push_back(size);
- Value stride = rewriter.create<ConstantIndexOp>(
- loc, op.strides().getValue<int64_t>(i));
- strides.push_back(stride);
- }
- auto subViewOp =
- rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets, sizes, strides);
- rewriter.create<linalg::CopyOp>(loc, subViewOp, resultBuffers[0]);
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// mhlo.reduce_window conversion patterns and utility functions.
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
index 406a62b..0de43b3 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-pipeline -canonicalize %s | IreeFileCheck %s
module {
// CHECK_LABEL: @slice_whole_buffer
@@ -25,19 +25,10 @@
// -----
module {
- // CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+ // CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 4)>
// CHECK: @slice_whole_stride
- // CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x4xi32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
- // CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
- // CHECK-DAG: %[[ONE:.+]] = constant 1 : index
- // CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x4xi32>
- // CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x4xi32>
- // CHECK: subview %[[IN]]
- // CHECK-SAME: [%[[ONE]], %[[ZERO]]]
- // CHECK-SAME: [%[[DIM0]], %[[DIM1]]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #[[MAP]]>
+ // CHECK: subview %[[IN]][1, 0] [1, 4] [1, 1] : memref<3x4xi32> to memref<1x4xi32, #[[MAP]]>
// CHECK: linalg.copy
func @slice_whole_stride()
attributes {signature = (tensor<3x4xi32>) -> (tensor<1x4xi32>)} {
@@ -60,19 +51,10 @@
// -----
module {
- // CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+ // CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
// CHECK: @slice_stride_part
- // CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x2xi32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
- // CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
- // CHECK-DAG: %[[ONE:.+]] = constant 1 : index
- // CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x2xi32>
- // CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x2xi32>
- // CHECK: subview %[[IN]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: [%[[DIM0]], %[[DIM1]]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #map>
+ // CHECK: subview %[[IN]][1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #[[MAP]]>
// CHECK: linalg.copy
func @slice_stride_part()
attributes {signature = (tensor<3x4xi32>) -> (tensor<1x2xi32>)} {
@@ -91,3 +73,32 @@
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
}
}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @slice_stride_part
+// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<1x2xi32>
+// CHECK: %[[IN0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4xi32>
+// CHECK: %[[SUBVIEW:.+]] = subview %[[IN0]][1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #[[MAP0]]>
+// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<1x2xi32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SUBVIEW]], %[[IN1]]
+// CHECK-SAME: outs(%[[OUT]]
+module {
+ func @slice_stride_part() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 {operand_result_index = 0 : i32} : tensor<3x4xi32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 {operand_result_index = 1 : i32} : tensor<1x2xi32>
+ %2 = "mhlo.slice"(%0) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ %3 = mhlo.add %2, %1 : tensor<1x2xi32>
+ hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 {operand_result_index = 2 : i32} : tensor<1x2xi32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 176f09d..9964eae 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -31,12 +31,6 @@
namespace {
// TODO(laurenzo): Every one of these should have better support and removed
// from this exclusion list eventually.
-bool isUnsupportedFusionOp(Operation *op) {
- return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
- mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::SliceOp,
- mhlo::TorchIndexSelectOp>(op);
-}
-
// Allowlist of ops that materialize to a an index-permuted copy of some kind
// if they exist standalone. Generally we try to avoid anchoring on these,
// letting them fuse into more meaningful ops as possible.
@@ -182,6 +176,18 @@
return FusionType::DISABLED;
}
+// TODO(b/144530470): replace with tablegen attributes/interfaces.
+bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
+ return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
+ mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
+ mhlo::TorchIndexSelectOp>(op) ||
+ isRootOnlyOp(op);
+}
+
+bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
+ return isa<mhlo::SliceOp>(op);
+}
+
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
index 22666b3..ee9299f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
@@ -41,6 +41,12 @@
OpDispatchPolicy(Dispatchability &dispatchability)
: dispatchability(dispatchability) {}
+ // Returns true if |op| is not able to fuse with either producer or consumer.
+ static bool isUnsupportedFusionOp(Operation *op);
+
+ // Returns true if |op| can only be a root op.
+ static bool isRootOnlyOp(Operation *op);
+
// Returns true if the given |op| can be dispatched in all cases.
// Other passes may handle special cases of these ops but this initial
// identification is conservative.
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 984b8c7..9565747 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
@@ -200,10 +201,9 @@
// that substituting library calls is easier.
for (auto &block : regionOp.body().getBlocks()) {
for (auto &op : block) {
- // TODO(b/144530470): replace with tablegen attributes/interfaces.
- if (isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp,
- mhlo::DotOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
- mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op)) {
+ // A root only op is mergable.
+ if (OpDispatchPolicy::isUnsupportedFusionOp(&op) &&
+ !OpDispatchPolicy::isRootOnlyOp(&op)) {
return false;
}
}
@@ -211,6 +211,24 @@
return regionOp.body().getBlocks().size() == 1;
}
+// Returns true if rhs has ops that can only be root op and will lose the
+// characteristic if merge two dispatch regions.
+bool rhsHasRootOnlyOp(DispatchRegionOp &lhs, DispatchRegionOp &rhs) {
+ auto &rhsBlock = rhs.body().front();
+ auto lhsArgs = llvm::to_vector<8>(lhs.args());
+ auto rhsArgs = llvm::to_vector<8>(rhs.args());
+ for (int rhsOpIdx = 0; rhsOpIdx < rhsArgs.size(); ++rhsOpIdx) {
+ for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults();
+ ++lhsResultIdx) {
+ if (rhsArgs[rhsOpIdx] != lhs.getResult(lhsResultIdx)) continue;
+ for (auto *user : rhsBlock.getArgument(rhsOpIdx).getUsers()) {
+ if (OpDispatchPolicy::isRootOnlyOp(user)) return true;
+ }
+ }
+ }
+ return false;
+}
+
// Merges |rhs| into |lhs| and returns the new |lhs| op.
// Precondition: !areDispatchRegionsTransitivelyDependent
DispatchRegionOp mergeDispatchRegions(DispatchRegionOp &lhs,
@@ -345,6 +363,10 @@
LLVM_DEBUG(llvm::dbgs()
<< " -REGION CONTAINS NON-TRIVIAL CONTROL FLOW-\n");
}
+ if (rhsHasRootOnlyOp(lhs, rhs)) {
+ LLVM_DEBUG(llvm::dbgs() << " -RHS REGION HAS ROOT OP-\n");
+ continue;
+ }
mergableRegions[i] = mergeDispatchRegions(lhs, rhs);
if (!mergableRegions[i]) {
return failure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
index 938c92a..c324ef7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -132,3 +132,33 @@
// CHECK-LABEL: func @dominate
// CHECK: flow.dispatch.region
// CHECK-NOT: flow.dispatch.region
+
+// -----
+
+// Test if the op that only can be a root op fuse with consumer but not
+// producer. This test use a dummy workload to test on root only op
+// functionality.
+module {
+ func @rootOnlyOp(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %c0 = constant 0 : index
+ %0 = flow.dispatch.region[%c0 : index](%arg2 = %arg0 : tensor<3x4xi32>) -> tensor<3x4xi32> {
+ %3 = mhlo.add %arg2, %arg2 : tensor<3x4xi32>
+ flow.return %3 : tensor<3x4xi32>
+ }
+ %1 = flow.dispatch.region[%c0 : index](%arg2 = %0 : tensor<3x4xi32>) -> tensor<1x2xi32> {
+ %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ flow.return %3 : tensor<1x2xi32>
+ }
+ %2 = flow.dispatch.region[%c0 : index](%arg2 = %1 : tensor<1x2xi32>, %arg3 = %arg1 : tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %3 = mhlo.multiply %arg2, %arg3 : tensor<1x2xi32>
+ flow.return %3 : tensor<1x2xi32>
+ }
+ return %2 : tensor<1x2xi32>
+ }
+}
+// CHECK-LABEL: func @rootOnlyOp
+// CHECK: flow.dispatch.region
+// CHECK-NEXT: mhlo.add
+// CHECK: flow.dispatch.region
+// CHECK-NEXT: mhlo.slice
+// CHECK-NEXT: mhlo.multiply
diff --git a/iree/test/e2e/regression/slice_add.mlir b/iree/test/e2e/regression/slice_add.mlir
new file mode 100644
index 0000000..762d979
--- /dev/null
+++ b/iree/test/e2e/regression/slice_add.mlir
@@ -0,0 +1,15 @@
+// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s)
+
+// CHECK: EXEC @slice_stride_part
+// CHECK: 1x2xi32=[16 17]
+func @slice_stride_part(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %1 = "mhlo.slice"(%arg0) {
+ start_indices = dense<[1, 1]> : tensor<2xi64>,
+ limit_indices = dense<[2, 3]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ %2 = mhlo.add %1, %arg1 : tensor<1x2xi32>
+ return %2 : tensor<1x2xi32>
+}