Revert "Relax TorchIndexSelectOp to fuse with other ops. (#3682)" (#3731)
This reverts commit 3b9a9f05a94fd20ae9b3ff0a7e774d889e94f115.
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 635875a..176f09d 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -33,8 +33,8 @@
// from this exclusion list eventually.
bool isUnsupportedFusionOp(Operation *op) {
return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
- mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::SliceOp>(
- op);
+ mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::SliceOp,
+ mhlo::TorchIndexSelectOp>(op);
}
// Allowlist of ops that materialize to a an index-permuted copy of some kind
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index a6b7d51..b3c48a0 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -199,7 +199,7 @@
// TODO(b/144530470): replace with tablegen attributes/interfaces.
if (isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp,
mhlo::DotOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
- mhlo::SliceOp>(op)) {
+ mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op)) {
return false;
}
}