[VMVX] Switch to module-scope pipeline. (#24133)
Similar to other backends, VMVX follows module-scope approach during
codegen and only touch HAL variant ops in pre-processing and
post-processing.
- Move configuration pipeline from `Dialect/VMVX/Transforms/` to
Codegen.
- Move tensor lowering code from `Dialect/VMVX/Transforms/` to Codegen
(i.e., new `buildVMVXLoweringPassPipeline`).
- Introduce configuration pipeline and translation pipeline; use them in
the pipeline tests.
- Adapt the outdated hand-authored `hal_executable.mlir` to carry
ordinal on export op, which is the responsiblitly of
MaterializeInterfaces pass, which was introduced later than the test.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/plugins/target/VMVX/BUILD.bazel b/compiler/plugins/target/VMVX/BUILD.bazel
index 286d119..dcb507b 100644
--- a/compiler/plugins/target/VMVX/BUILD.bazel
+++ b/compiler/plugins/target/VMVX/BUILD.bazel
@@ -23,6 +23,7 @@
"VMVXTarget.cpp",
],
deps = [
+ "//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/VMVX",
diff --git a/compiler/plugins/target/VMVX/CMakeLists.txt b/compiler/plugins/target/VMVX/CMakeLists.txt
index 5010b49..569ccd0 100644
--- a/compiler/plugins/target/VMVX/CMakeLists.txt
+++ b/compiler/plugins/target/VMVX/CMakeLists.txt
@@ -27,6 +27,7 @@
MLIRIR
MLIRPass
MLIRSupport
+ iree::compiler::Codegen::Common
iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::VMVX
diff --git a/compiler/plugins/target/VMVX/VMVXTarget.cpp b/compiler/plugins/target/VMVX/VMVXTarget.cpp
index 7b8f4cd..20ba21f 100644
--- a/compiler/plugins/target/VMVX/VMVXTarget.cpp
+++ b/compiler/plugins/target/VMVX/VMVXTarget.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
@@ -122,12 +123,14 @@
void
buildConfigurationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
OpPassManager &passManager) final {
- IREE::VMVX::buildVMVXConfigurationPassPipeline(passManager);
+ buildCodegenConfigurationPreProcessingPassPipeline(passManager);
+ buildVMVXConfigurationPassPipeline(passManager.nest<ModuleOp>());
}
void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
OpPassManager &passManager) final {
- IREE::VMVX::buildVMVXTransformPassPipeline(passManager);
+ IREE::VMVX::buildVMVXTransformPassPipeline(passManager.nest<ModuleOp>());
+ buildCodegenTranslationPostProcessingPassPipeline(passManager);
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
@@ -246,12 +249,14 @@
void
buildConfigurationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
OpPassManager &passManager) final {
- IREE::VMVX::buildVMVXConfigurationPassPipeline(passManager);
+ buildCodegenConfigurationPreProcessingPassPipeline(passManager);
+ buildVMVXConfigurationPassPipeline(passManager.nest<ModuleOp>());
}
void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
OpPassManager &passManager) final {
- IREE::VMVX::buildVMVXTransformPassPipeline(passManager);
+ IREE::VMVX::buildVMVXTransformPassPipeline(passManager.nest<ModuleOp>());
+ buildCodegenTranslationPostProcessingPassPipeline(passManager);
}
private:
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
index 8478d04..9044583 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
@@ -7,6 +7,7 @@
#include "mlir/Transforms/Passes.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
+#include "iree/compiler/Codegen/Common/PassUtils.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
@@ -78,6 +79,27 @@
}
}
+void buildVMVXConfigurationPassPipeline(OpPassManager &modulePassManager) {
+ {
+ FunctionLikeNest funcPassManager(modulePassManager);
+ addCommonTargetExecutablePreprocessingPasses(funcPassManager);
+ }
+ modulePassManager.addPass(createMaterializeUserConfigsPass());
+ FunctionLikeNest(modulePassManager)
+ .addPass(createMaterializeDeviceEncodingPass)
+ // TODO: Remove the following pass the plumb support for
+ // #hal.descriptor_type memory space through the stack.
+ .addPass(createEraseHALDescriptorTypeFromMemRefPass);
+ modulePassManager.addPass(createVMVXSelectLoweringStrategyPass());
+}
+
+void buildVMVXLoweringPassPipeline(OpPassManager &modulePassManager) {
+ FunctionLikeNest(modulePassManager)
+ .addPass(createVMVXLowerExecutableTargetPass);
+ modulePassManager.addPass(createReconcileTranslationInfoPass());
+ modulePassManager.addPass(createResolveWorkgroupCountHintsPass());
+}
+
// NOTE: this runs on the top-level program module containing all
// hal.executable ops.
void buildVMVXLinkingPassPipeline(OpPassManager &modulePassManager) {
@@ -107,6 +129,20 @@
// Generated.
registerPasses();
+ static PassPipelineRegistration<> VMVXConfigPipeline(
+ "iree-codegen-vmvx-configuration-pipeline",
+ "Runs the VMVX codegen configuration pipeline",
+ [](OpPassManager &modulePassManager) {
+ buildVMVXConfigurationPassPipeline(modulePassManager);
+ });
+
+ static PassPipelineRegistration<> VMVXLoweringPipeline(
+ "iree-codegen-vmvx-lowering-pipeline",
+ "Runs the VMVX codegen lowering pipeline",
+ [](OpPassManager &modulePassManager) {
+ buildVMVXLoweringPassPipeline(modulePassManager);
+ });
+
static PassPipelineRegistration<> VMVXLinkingPipeline(
"iree-codegen-vmvx-linking-pipeline",
"Runs the VMVX HAL executable linking pipeline",
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.h b/compiler/src/iree/compiler/Codegen/VMVX/Passes.h
index 98b27f6..b7beac1 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.h
@@ -28,6 +28,19 @@
bool enableUKernels);
//----------------------------------------------------------------------------//
+// VMVX Codegen Pipelines
+//----------------------------------------------------------------------------//
+
+/// Populates passes needed for preprocessing before codegen lowerings, as well
+/// as high level lowering strategy selection.
+void buildVMVXConfigurationPassPipeline(OpPassManager &modulePassManager);
+
+/// Populates passes needed to lower high level ops to VMVX-compatible ops via
+/// the structured ops path. The `modulePassManager` should operate on the
+/// module within the IREE::HAL::ExecutableOp.
+void buildVMVXLoweringPassPipeline(OpPassManager &modulePassManager);
+
+//----------------------------------------------------------------------------//
// VMVX Linking Passes and Pipelines
//----------------------------------------------------------------------------//
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/pipeline.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/pipeline.mlir
index 6035e66..d684cff 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/test/pipeline.mlir
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/pipeline.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(iree-vmvx-select-lowering-strategy, func.func(iree-vmvx-lower-executable-target))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-codegen-vmvx-configuration-pipeline, iree-codegen-vmvx-lowering-pipeline)" --split-input-file %s | FileCheck %s
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/BUILD.bazel
index 38728a2..f4e098f 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/BUILD.bazel
@@ -57,7 +57,6 @@
"//compiler/src/iree/compiler/Codegen/VMVX",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
- "//compiler/src/iree/compiler/Dialect/HAL/Transforms",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
"//compiler/src/iree/compiler/Dialect/Util/Conversion",
"//compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil",
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/CMakeLists.txt
index 88c9748..567737c 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/CMakeLists.txt
@@ -80,7 +80,6 @@
iree::compiler::Codegen::VMVX
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
- iree::compiler::Dialect::HAL::Transforms
iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::Util::Conversion
iree::compiler::Dialect::Util::Conversion::MemRefToUtil
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
index 1011eb7..8377ec5 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
@@ -11,7 +11,6 @@
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
-#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Transforms/Passes.h"
@@ -31,60 +30,19 @@
namespace mlir::iree_compiler::IREE::VMVX {
// ---------------------------------------------------------------------------
-// Variant configuration
-// ---------------------------------------------------------------------------
-
-void buildVMVXConfigurationPassPipeline(OpPassManager &variantPassManager) {
- OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
- {
- FunctionLikeNest funcPassManager(modulePassManager);
- // ---------------------------------------------------------------------------
- // Tensor-level optimization, kernel dispatch and lower to buffers.
- // ---------------------------------------------------------------------------
- addCommonTargetExecutablePreprocessingPasses(funcPassManager);
- }
- modulePassManager.addPass(createMaterializeUserConfigsPass());
- FunctionLikeNest(modulePassManager)
- .addPass(createMaterializeDeviceEncodingPass)
- // TODO: Remove the following pass the plumb support for
- // #hal.descriptor_type memory space through the stack.
- .addPass(createEraseHALDescriptorTypeFromMemRefPass);
- modulePassManager.addPass(createVMVXSelectLoweringStrategyPass());
-}
-
-// ---------------------------------------------------------------------------
-// Variant Translation
+// Module-scope translation
// ---------------------------------------------------------------------------
static void
-buildVectorVMVXTransformPassPipeline(OpPassManager &variantPassManager) {
- variantPassManager.addPass(createCreateDispatchConfigPass());
-
+buildVectorVMVXTransformPassPipeline(OpPassManager &modulePassManager) {
// ---------------------------------------------------------------------------
// Tensor-level optimization, kernel dispatch and lower to buffers.
// ---------------------------------------------------------------------------
- {
- OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
- FunctionLikeNest(modulePassManager)
- .addPass(createVMVXLowerExecutableTargetPass);
-
- // Resolve workgroup distribution before lowering ukernels to calls.
- // CPULowerToUKernelsPass (inside VMVXLowerExecutableTargetPass) lowers
- // iree_codegen.query_tile_sizes to iree_codegen.ukernel.generic which is
- // memory-effect-free at the tensor level (no memref operands). The WAR
- // hack in materializeSliceFromOrdinals replaces it with a constant in
- // the count region. After LowerUKernelOpsToCallsPass it becomes a
- // func.call which is memory-effecting and would be rejected by the
- // backward slice.
- modulePassManager.addPass(createReconcileTranslationInfoPass());
- modulePassManager.addPass(createResolveWorkgroupCountHintsPass());
- }
- variantPassManager.addPass(createPropagateDispatchConfigPass());
+ buildVMVXLoweringPassPipeline(modulePassManager);
// ---------------------------------------------------------------------------
// Linalg -> Vectors
// ---------------------------------------------------------------------------
- OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
modulePassManager.addPass(createLowerUKernelOpsToCallsPass());
FunctionLikeNest(modulePassManager)
@@ -143,18 +101,17 @@
.addPass(createIREELoopInvariantCodeMotionPass);
}
-void buildVMVXTransformPassPipeline(OpPassManager &variantPassManager) {
+void buildVMVXTransformPassPipeline(OpPassManager &modulePassManager) {
// ---------------------------------------------------------------------------
// Linalg -> Scalars/Vectors
// ---------------------------------------------------------------------------
- buildVectorVMVXTransformPassPipeline(variantPassManager);
+ buildVectorVMVXTransformPassPipeline(modulePassManager);
// ---------------------------------------------------------------------------
// Standard/Vector/HAL/etc -> VMVX conversion
// ---------------------------------------------------------------------------
- OpPassManager &modulePassManager = variantPassManager.nest<mlir::ModuleOp>();
modulePassManager.addPass(createMaterializeConstantsPass());
modulePassManager.addPass(createConversionPass());
@@ -190,8 +147,8 @@
static PassPipelineRegistration<> transformPassPipeline(
"iree-vmvx-transformation-pipeline",
"Runs the full IREE VMVX dialect transformation pipeline",
- [](OpPassManager &variantPassManager) {
- buildVMVXTransformPassPipeline(variantPassManager);
+ [](OpPassManager &modulePassManager) {
+ buildVMVXTransformPassPipeline(modulePassManager);
});
}
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.h
index 4f1e2bc..fa3fbd0 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.h
@@ -20,22 +20,19 @@
// Helpers
//===----------------------------------------------------------------------===//
-// Adds a set of passes to the given pass manager that configure the required
-// VMVX transforms and tiling parameters.
-void buildVMVXConfigurationPassPipeline(OpPassManager &variantPassManager);
-
// Adds a set of passes to the given pass manager that run the required VMVX
-// transforms in the canonical order.
+// transforms in the canonical order. The `modulePassManager` should operate
+// on the module within the IREE::HAL::ExecutableOp.
//
// Most translation code should prefer to use this instead of manually adding
// the passes themselves to ensure that expected pass ordering is observed.
//
// The expected usage is:
// <run conversion from TF/HLO/etc to flow>
-// buildVMVXConfigurationPassPipeline & run
+// buildVMVXCodegenConfigurationPassPipeline & run
// buildVMVXTransformPassPipeline & run
// <serialize VM module>
-void buildVMVXTransformPassPipeline(OpPassManager &variantPassManager);
+void buildVMVXTransformPassPipeline(OpPassManager &modulePassManager);
//===----------------------------------------------------------------------===//
// Register all Passes
diff --git a/tests/compiler_driver/hal_executable.mlir b/tests/compiler_driver/hal_executable.mlir
index 532cec9..016c888 100644
--- a/tests/compiler_driver/hal_executable.mlir
+++ b/tests/compiler_driver/hal_executable.mlir
@@ -21,7 +21,7 @@
// Exported functions are declared with the layout they use and may optionally
// contain other information - though when hand-authoring that's usually
// omitted.
- hal.executable.export public @mul layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
+ hal.executable.export public @mul ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
%c1 = arith.constant 1 : index
hal.return %c1, %c1, %c1 : index, index, index
}