Route demotion flag to Input options (#13993)
Demotion should be configuration via the shared object file. The
currentl flags are frontend specific. Rerouted the passes so it is
configurable via `setFlags` for the libIREECompile.so file.
diff --git a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
index 68a08b5..42a36c9 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
@@ -32,8 +32,13 @@
namespace {
struct AutoInputConversionPipelinePass final
: AutoInputConversionPipelineBase<AutoInputConversionPipelinePass> {
+ AutoInputConversionPipelinePass(
+ const AutoInputConversionPipelineOptions& inputOptions)
+ : options(inputOptions) {}
void runOnOperation() override;
void getDependentDialects(DialectRegistry& registry) const override;
+
+ AutoInputConversionPipelineOptions options;
};
// All the features seen that should be handled during input conversion.
@@ -154,10 +159,14 @@
OpPassManager::Nesting::Explicit);
#ifdef IREE_HAVE_MHLO_INPUT
if (features.hasStableHLO && !features.hasMHLO) {
+ stablehlo::StableHloOptions options;
+ options.demoteI64ToI32 = demoteI64ToI32;
+ options.demoteF64ToF32 = demoteF64ToF32;
+ options.promoteBF16ToF32 = promoteBF16ToF32;
if (features.hasTuples) {
- stablehlo::buildStableHLOXLAInputConversionPassPipeline(pm);
+ stablehlo::buildStableHLOXLAInputConversionPassPipeline(pm, options);
} else {
- stablehlo::buildStableHLOInputConversionPassPipeline(pm);
+ stablehlo::buildStableHLOInputConversionPassPipeline(pm, options);
}
}
if (features.hasMHLO) {
@@ -201,8 +210,19 @@
};
#ifdef IREE_HAVE_MHLO_INPUT
- appendPipelineDialects(stablehlo::buildStableHLOInputConversionPassPipeline);
- appendPipelineDialects(
+ auto appendStablehloPipelineDialects =
+ [®istry](function_ref<void(OpPassManager&,
+ const stablehlo::StableHloOptions& options)>
+ buildFn) {
+ const stablehlo::StableHloOptions options;
+ OpPassManager pm;
+ buildFn(pm, options);
+ pm.getDependentDialects(registry);
+ };
+
+ appendStablehloPipelineDialects(
+ stablehlo::buildStableHLOInputConversionPassPipeline);
+ appendStablehloPipelineDialects(
stablehlo::buildStableHLOXLAInputConversionPassPipeline);
appendPipelineDialects(MHLO::buildMHLOInputConversionPassPipeline);
@@ -224,7 +244,13 @@
std::unique_ptr<OperationPass<ModuleOp>>
createAutoInputConversionPipelinePass() {
- return std::make_unique<AutoInputConversionPipelinePass>();
+ AutoInputConversionPipelineOptions options;
+ return std::make_unique<AutoInputConversionPipelinePass>(options);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createAutoInputConversionPipelinePass(
+ const AutoInputConversionPipelineOptions& options) {
+ return std::make_unique<AutoInputConversionPipelinePass>(options);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/InputConversion/Common/Passes.h b/compiler/src/iree/compiler/InputConversion/Common/Passes.h
index dfbff2f..797a29c 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/Passes.h
+++ b/compiler/src/iree/compiler/InputConversion/Common/Passes.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES_H_
#define IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES_H_
+#include "iree/compiler/InputConversion/Common/PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -14,6 +15,9 @@
namespace mlir {
namespace iree_compiler {
+#define GEN_PASS_DECL
+#include "iree/compiler/InputConversion/Common/Passes.h.inc"
+
//===----------------------------------------------------------------------===//
// Pipelines
//===----------------------------------------------------------------------===//
@@ -28,6 +32,8 @@
std::unique_ptr<OperationPass<ModuleOp>>
createAutoInputConversionPipelinePass();
+std::unique_ptr<OperationPass<ModuleOp>> createAutoInputConversionPipelinePass(
+ const AutoInputConversionPipelineOptions& options);
std::unique_ptr<OperationPass<ModuleOp>> createIREEImportPublicPass();
std::unique_ptr<OperationPass<ModuleOp>> createImportMLProgramPass();
std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/compiler/src/iree/compiler/InputConversion/Common/Passes.td b/compiler/src/iree/compiler/InputConversion/Common/Passes.td
index f42e5ff..950cce6 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/Passes.td
+++ b/compiler/src/iree/compiler/InputConversion/Common/Passes.td
@@ -53,6 +53,14 @@
conversion to run, then run that conversion.
}];
let constructor = "mlir::iree_compiler::createAutoInputConversionPipelinePass()";
+ let options = [
+ Option<"demoteI64ToI32", "iree-autoinput-demote-i64-to-i32", "bool",
+ /*default=*/"true", "Convert I64 to I32 equivalents">,
+ Option<"demoteF64ToF32", "iree-autoinput-demote-f64-to-f32", "bool",
+ /*default=*/"false", "Convert F64 to F32 equivalents">,
+ Option<"promoteBF16ToF32", "iree-autoinput-demote-bf16-to-f32", "bool",
+ /*default=*/"false", "Convert BF16 to F32 equivalents">,
+ ];
}
#endif // IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
index d82ac75..6eef697 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
@@ -26,36 +26,19 @@
} // namespace
namespace {
-// TODO(#8745): remove these flags when the -iree-flow-demote-* flags can be
-// used without tripping upstream verifier issues.
-llvm::cl::opt<bool> clDemoteI64ToI32(
- "iree-stablehlo-demote-i64-to-i32",
- llvm::cl::desc(
- "Converts all StableHLO i64 ops and values into i32 counterparts."),
- llvm::cl::init(true));
-llvm::cl::opt<bool> clDemoteF64ToF32(
- "iree-stablehlo-demote-f64-to-f32",
- llvm::cl::desc(
- "Converts all StableHLO f64 ops and values into f32 counterparts."),
- llvm::cl::init(true));
-llvm::cl::opt<bool> clPromoteBF16ToF32(
- "iree-stablehlo-promote-bf16-to-f32",
- llvm::cl::desc(
- "Converts all StableHLO bf16 ops and values into f32 counterparts."),
- llvm::cl::init(false));
void registerStableHLOConversionPassPipeline() {
- PassPipelineRegistration<> stablehlo(
+ PassPipelineRegistration<StableHloOptions> stablehlo(
"iree-stablehlo-input-transformation-pipeline",
"Runs the StableHLO IREE flow dialect transformation pipeline",
- [](OpPassManager &passManager) {
- buildStableHLOInputConversionPassPipeline(passManager);
+ [](OpPassManager& passManager, const StableHloOptions& options) {
+ buildStableHLOInputConversionPassPipeline(passManager, options);
});
}
// Prepare HLO for use as an input to the Flow dialect.
-void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
- bool detuple) {
+void buildStableHLOInputConversionPassPipelineImpl(
+ OpPassManager& passManager, const StableHloOptions& options, bool detuple) {
passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
passManager.addNestedPass<func::FuncOp>(mlir::createCSEPass());
@@ -88,13 +71,13 @@
// stack. This is often required because of implicit i64 insertion by JAX/HLO
// that we don't want forcing 32-bit embedded devices to support.
// TODO(#8745): remove these and prefer the flow pipeline options instead.
- if (clDemoteI64ToI32) {
+ if (options.demoteI64ToI32) {
passManager.addPass(IREE::Util::createDemoteI64ToI32Pass());
}
- if (clDemoteF64ToF32) {
+ if (options.demoteF64ToF32) {
passManager.addPass(IREE::Util::createDemoteF64ToF32Pass());
}
- if (clPromoteBF16ToF32) {
+ if (options.promoteBF16ToF32) {
passManager.addPass(IREE::Util::createPromoteBF16ToF32Pass());
}
@@ -123,12 +106,16 @@
}
} // namespace
-void buildStableHLOInputConversionPassPipeline(OpPassManager &passManager) {
- buildStableHLOInputConversionPassPipelineImpl(passManager, /*detuple=*/false);
+void buildStableHLOInputConversionPassPipeline(
+ OpPassManager& passManager, const StableHloOptions& options) {
+ buildStableHLOInputConversionPassPipelineImpl(passManager, options,
+ /*detuple=*/false);
}
-void buildStableHLOXLAInputConversionPassPipeline(OpPassManager &passManager) {
- buildStableHLOInputConversionPassPipelineImpl(passManager, /*detuple=*/true);
+void buildStableHLOXLAInputConversionPassPipeline(
+ OpPassManager& passManager, const StableHloOptions& options) {
+ buildStableHLOInputConversionPassPipelineImpl(passManager, options,
+ /*detuple=*/true);
}
void registerStableHLOConversionPasses() {
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h
index 29cfc95..be8af77 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h
@@ -16,15 +16,23 @@
std::unique_ptr<TypeConverter> createStableHloToLinalgTypeConverter();
+struct StableHloOptions : public PassPipelineOptions<StableHloOptions> {
+ bool demoteI64ToI32 = true;
+ bool demoteF64ToF32 = false;
+ bool promoteBF16ToF32 = false;
+};
+
//===----------------------------------------------------------------------===//
// Pipelines
//===----------------------------------------------------------------------===//
-void buildStableHLOInputConversionPassPipeline(OpPassManager &passManager);
+void buildStableHLOInputConversionPassPipeline(OpPassManager& passManager,
+ const StableHloOptions& options);
// Performs input legalization on programs that may have originated from an XLA
// import (or made to interop with it).
-void buildStableHLOXLAInputConversionPassPipeline(OpPassManager &passManager);
+void buildStableHLOXLAInputConversionPassPipeline(
+ OpPassManager& passManager, const StableHloOptions& options);
//===----------------------------------------------------------------------===//
// Register all Passes
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index 5e0c309..904dc31 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -68,6 +68,23 @@
),
// clang-format on
llvm::cl::cat(category));
+
+#ifdef IREE_HAVE_MHLO_INPUT
+ binder.opt<bool>(
+ "iree-input-demote-i64-to-i32", demoteI64ToI32,
+ llvm::cl::desc("Converts all i64 ops and values into i32 counterparts."),
+ llvm::cl::cat(category));
+
+ binder.opt<bool>(
+ "iree-input-demote-f64-to-f32", demoteF64ToF32,
+ llvm::cl::desc("Converts all f64 ops and values into f32 counterparts."),
+ llvm::cl::cat(category));
+
+ binder.opt<bool>(
+ "iree-input-promote-bf16-to-f32", promoteBF16ToF32,
+ llvm::cl::desc("Converts all bf16 ops and values into f32 counterparts."),
+ llvm::cl::cat(category));
+#endif
}
void HighLevelOptimizationOptions::bindOptions(OptionsBinder &binder) {
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h
index 5005240..40b9504 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.h
+++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -59,6 +59,10 @@
};
Type type = Type::auto_detect;
+ bool demoteI64ToI32 = true;
+ bool demoteF64ToF32 = true;
+ bool promoteBF16ToF32 = true;
+
void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<InputDialectOptions>;
};
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
index f9b7514..cc11538 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
@@ -51,18 +51,28 @@
hooks.pipelineExtensions->extendInputConversionPreprocessingPassPipeline(
passManager, inputOptions.type);
}
+ AutoInputConversionPipelineOptions autoOptions;
+
+#ifdef IREE_HAVE_MHLO_INPUT
+ stablehlo::StableHloOptions stablehloOptions;
+ stablehloOptions.demoteI64ToI32 = inputOptions.demoteI64ToI32;
+ stablehloOptions.demoteF64ToF32 = inputOptions.demoteF64ToF32;
+ stablehloOptions.promoteBF16ToF32 = inputOptions.promoteBF16ToF32;
+#endif
switch (inputOptions.type) {
case InputDialectOptions::Type::none:
break;
case InputDialectOptions::Type::auto_detect:
- passManager.addPass(createAutoInputConversionPipelinePass());
+ passManager.addPass(createAutoInputConversionPipelinePass(autoOptions));
break;
#ifdef IREE_HAVE_MHLO_INPUT
case InputDialectOptions::Type::stablehlo:
- stablehlo::buildStableHLOInputConversionPassPipeline(passManager);
+ stablehlo::buildStableHLOInputConversionPassPipeline(passManager,
+ stablehloOptions);
break;
case InputDialectOptions::Type::stablehlo_xla:
- stablehlo::buildStableHLOXLAInputConversionPassPipeline(passManager);
+ stablehlo::buildStableHLOXLAInputConversionPassPipeline(passManager,
+ stablehloOptions);
break;
case InputDialectOptions::Type::mhlo_legacy:
MHLO::buildMHLOInputConversionPassPipeline(passManager);
diff --git a/tests/e2e/vulkan_specific/BUILD.bazel b/tests/e2e/vulkan_specific/BUILD.bazel
index f917032..9111af4 100644
--- a/tests/e2e/vulkan_specific/BUILD.bazel
+++ b/tests/e2e/vulkan_specific/BUILD.bazel
@@ -53,7 +53,7 @@
],
compiler_flags = [
"--iree-input-type=stablehlo",
- "--iree-stablehlo-demote-i64-to-i32=false",
+ "--iree-input-demote-i64-to-i32=false",
"--iree-vulkan-target-triple=valhall-unknown-android31",
],
driver = "vulkan",
diff --git a/tests/e2e/vulkan_specific/CMakeLists.txt b/tests/e2e/vulkan_specific/CMakeLists.txt
index 3d0d9b3..65623d2 100644
--- a/tests/e2e/vulkan_specific/CMakeLists.txt
+++ b/tests/e2e/vulkan_specific/CMakeLists.txt
@@ -55,7 +55,7 @@
"vulkan"
COMPILER_FLAGS
"--iree-input-type=stablehlo"
- "--iree-stablehlo-demote-i64-to-i32=false"
+ "--iree-input-demote-i64-to-i32=false"
"--iree-vulkan-target-triple=valhall-unknown-android31"
LABELS
"manual"