Apply dispatch.tensor.load canonicalization after materializing launch configs (#5692)
diff --git a/iree/compiler/Conversion/LinalgToLLVM/MaterializeCPULaunchConfigurationPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/MaterializeCPULaunchConfigurationPass.cpp
index 0071519..00c813a 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/MaterializeCPULaunchConfigurationPass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/MaterializeCPULaunchConfigurationPass.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Conversion/Common/Transforms.h"
#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
@@ -131,6 +132,8 @@
OwningRewritePatternList canonicalization(&getContext());
AffineMinOp::getCanonicalizationPatterns(canonicalization, context);
populateAffineMinSCFCanonicalizationPattern(canonicalization);
+ IREE::Flow::populateFlowDispatchCanonicalizationPatterns(canonicalization,
+ context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalization));
}
}
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 72b300b..51f409e 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1684,6 +1684,15 @@
return newOp;
}
+//===----------------------------------------------------------------------===//
+// Public methods
+//===----------------------------------------------------------------------===//
+
+void populateFlowDispatchCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ DispatchTensorLoadOp::getCanonicalizationPatterns(results, context);
+}
+
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.h b/iree/compiler/Dialect/Flow/IR/FlowOps.h
index 1d52091..933cff1 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.h
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.h
@@ -36,4 +36,18 @@
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h.inc"
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+// Populates flow.dispatch.* canonicalization patterns.
+void populateFlowDispatchCanonicalizationPatterns(
+ ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context);
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
#endif // IREE_COMPILER_DIALECT_FLOW_IR_FLOWOPS_H_