Correct root op to be leaf op. (#4172)
As discussed in https://github.com/google/iree/issues/3924
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index c3c3415..3e16e93 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -227,10 +227,10 @@
op) ||
(!clEnableConsumerOnlyFusion &&
isa<mhlo::DotOp, mhlo::DotGeneralOp>(op)) ||
- isRootOnlyOp(op);
+ isLeafOnlyOp(op);
}
-bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
+bool OpDispatchPolicy::isLeafOnlyOp(Operation *op) {
return isa<mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op);
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
index e5726fb..6d82d22 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
@@ -56,8 +56,8 @@
// Returns true if |op| is not able to fuse with either producer or consumer.
static bool isUnsupportedFusionOp(Operation *op);
- // Returns true if |op| can only be a root op.
- static bool isRootOnlyOp(Operation *op);
+ // Returns true if |op| can only be a leaf op.
+ static bool isLeafOnlyOp(Operation *op);
// Returns true if the given |op| can be dispatched in all cases.
// Other passes may handle special cases of these ops but this initial
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 2cbfaba..107b711 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -201,10 +201,10 @@
// that substituting library calls is easier.
for (auto &block : regionOp.body().getBlocks()) {
for (auto &op : block) {
- // A root only op is mergable.
+ // A leaf only op is mergable.
if ((OpDispatchPolicy::isUnsupportedFusionOp(&op) ||
OpDispatchPolicy::isFusableWithConsumersOnly(&op)) &&
- !OpDispatchPolicy::isRootOnlyOp(&op)) {
+ !OpDispatchPolicy::isLeafOnlyOp(&op)) {
return false;
}
}
@@ -212,9 +212,9 @@
return regionOp.body().getBlocks().size() == 1;
}
-// Returns true if rhs has ops that can only be root op and will lose the
+// Returns true if rhs has ops that can only be leaf op and will lose the
// characteristic if merge two dispatch regions.
-bool rhsHasRootOnlyOp(DispatchRegionOp &lhs, DispatchRegionOp &rhs) {
+bool rhsHasLeafOnlyOp(DispatchRegionOp &lhs, DispatchRegionOp &rhs) {
auto &rhsBlock = rhs.body().front();
auto lhsArgs = llvm::to_vector<8>(lhs.args());
auto rhsArgs = llvm::to_vector<8>(rhs.args());
@@ -223,7 +223,7 @@
++lhsResultIdx) {
if (rhsArgs[rhsOpIdx] != lhs.getResult(lhsResultIdx)) continue;
for (auto *user : rhsBlock.getArgument(rhsOpIdx).getUsers()) {
- if (OpDispatchPolicy::isRootOnlyOp(user)) return true;
+ if (OpDispatchPolicy::isLeafOnlyOp(user)) return true;
}
}
}
@@ -364,8 +364,8 @@
LLVM_DEBUG(llvm::dbgs()
<< " -REGION CONTAINS NON-TRIVIAL CONTROL FLOW-\n");
}
- if (rhsHasRootOnlyOp(lhs, rhs)) {
- LLVM_DEBUG(llvm::dbgs() << " -RHS REGION HAS ROOT OP-\n");
+ if (rhsHasLeafOnlyOp(lhs, rhs)) {
+ LLVM_DEBUG(llvm::dbgs() << " -RHS REGION HAS LEAF OP-\n");
continue;
}
mergableRegions[i] = mergeDispatchRegions(lhs, rhs);
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
index d551d13..e0ba42b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -135,11 +135,11 @@
// -----
-// Test if the op that only can be a root op fuse with consumer but not
-// producer. This test use a dummy workload to test on root only op
+// Test if the op that only can be a leaf op fuse with consumer but not
+// producer. This test use a dummy workload to test on leaf only op
// functionality.
module {
- func @rootOnlyOp(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
+ func @leafOnlyOp(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%c0 = constant 0 : index
%0 = flow.dispatch.region[%c0 : index](%arg2 = %arg0 : tensor<3x4xi32>) -> tensor<3x4xi32> {
%3 = mhlo.add %arg2, %arg2 : tensor<3x4xi32>
@@ -156,7 +156,7 @@
return %2 : tensor<1x2xi32>
}
}
-// CHECK-LABEL: func @rootOnlyOp
+// CHECK-LABEL: func @leafOnlyOp
// CHECK: flow.dispatch.region
// CHECK-NEXT: mhlo.add
// CHECK: flow.dispatch.region