Remove IREE usage of the Global Dialect Registry (#3036)
Mostly solves #2958, but there are some TODOs left because this doesn't
handle our pipeline-within-a-pipeline workaround for the absence of
dynamic pass registration very gracefully. Hopefully #1036 will solve
some of these issues.
diff --git a/bindings/python/pyiree/compiler/BUILD b/bindings/python/pyiree/compiler/BUILD
index dca46ed..8e676af 100644
--- a/bindings/python/pyiree/compiler/BUILD
+++ b/bindings/python/pyiree/compiler/BUILD
@@ -90,6 +90,7 @@
"//iree/tools:init_iree_passes_and_dialects",
"//iree/tools:init_mlir_passes_and_dialects",
"//iree/tools:init_targets",
+ "//iree/tools:init_xla_dialects",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
diff --git a/bindings/python/pyiree/compiler/CMakeLists.txt b/bindings/python/pyiree/compiler/CMakeLists.txt
index b47cb6f..4634b52 100644
--- a/bindings/python/pyiree/compiler/CMakeLists.txt
+++ b/bindings/python/pyiree/compiler/CMakeLists.txt
@@ -62,6 +62,7 @@
iree::tools::init_iree_passes_and_dialects
iree::tools::init_mlir_passes_and_dialects
iree::tools::init_targets
+ iree::tools::init_xla_dialects
LLVMSupport
MLIRIR
MLIRSCFTransforms
diff --git a/bindings/python/pyiree/compiler/compiler.cc b/bindings/python/pyiree/compiler/compiler.cc
index 2bae670..747b43b 100644
--- a/bindings/python/pyiree/compiler/compiler.cc
+++ b/bindings/python/pyiree/compiler/compiler.cc
@@ -31,6 +31,7 @@
#include "iree/tools/init_mlir_dialects.h"
#include "iree/tools/init_mlir_passes.h"
#include "iree/tools/init_targets.h"
+#include "iree/tools/init_xla_dialects.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/PrettyStackTrace.h"
@@ -38,6 +39,7 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
@@ -73,13 +75,6 @@
llvm::sys::DefaultOneShotPipeSignalHandler);
llvm::sys::PrintStackTraceOnErrorSignal("pyiree");
- mlir::enableGlobalDialectRegistry(true);
- // Register built-in MLIR dialects.
- mlir::registerMlirDialects();
-
- // Register IREE dialects, compiler module dialects, and HAL target backends.
- mlir::iree_compiler::registerIreeDialects();
- mlir::iree_compiler::registerIreeCompilerModuleDialects();
mlir::iree_compiler::registerHALTargetBackends();
mlir::iree_compiler::registerVMTargets();
@@ -98,6 +93,13 @@
return true;
}
+void registerDialects(DialectRegistry& registry) {
+ mlir::registerMlirDialects(registry);
+ mlir::registerXLADialects(registry);
+ mlir::iree_compiler::registerIreeDialects(registry);
+ mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
+}
+
void SetupLLVMModule(pybind11::module m) {
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
m.def(
@@ -286,7 +288,7 @@
CompilerContextBundle::CompilerContextBundle()
: default_capture_(&mlir_context_, nullptr) {
- mlir_context_.loadAllGloballyRegisteredDialects();
+ registerDialects(mlir_context_.getDialectRegistry());
}
CompilerContextBundle::~CompilerContextBundle() = default;
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index 2a2fb6b..dc54aac 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -72,7 +72,7 @@
const int height = 4;
const int width = 4;
StringLiteral funcName = "kernel_matmul";
- MLIRContext context;
+ MLIRContext context(/*loadAllDialects=*/false);
ModelBuilder modelBuilder;
auto typeA = modelBuilder.getMemRefType({width, height}, modelBuilder.f32);
auto typeB = modelBuilder.getMemRefType({width, height}, modelBuilder.f32);
diff --git a/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp b/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
index 9c43915..7859673 100644
--- a/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
@@ -43,7 +43,7 @@
template <unsigned vecSize>
void testVectorAdd1d() {
- MLIRContext context;
+ MLIRContext context(/*loadAllDialects=*/false);
ModelBuilder modelBuilder;
constexpr int workgroupSize = 32;
auto typeA = modelBuilder.getMemRefType(vecSize, modelBuilder.f32);
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index c7df737..6ad02c1 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -89,7 +89,7 @@
// Simple test a single warp.
const int width = warpSize;
StringLiteral funcName = "kernel_vecadd";
- MLIRContext context;
+ MLIRContext context(/*loadAllDialects=*/false);
ModelBuilder modelBuilder;
auto nVectorType = modelBuilder.getVectorType(width, modelBuilder.f32);
auto typeA = modelBuilder.getMemRefType({width}, modelBuilder.f32);
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index 2ddfc9f..ce83ed5 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -56,9 +56,25 @@
cc_binary(
name = "iree-tf-opt",
+ srcs = ["tf_opt_main.cc"],
deps = [
":tensorflow",
- "//iree/tools:iree_opt_main",
+ "//integrations/tensorflow/compiler/dialect/tf_strings/ir:dialect",
+ "//integrations/tensorflow/compiler/dialect/tf_tensorlist/ir:tf_tensorlist_dialect",
+ "//iree/compiler/Conversion:init_conversions",
+ "//iree/compiler/Conversion/HLOToLinalg",
+ "//iree/compiler/Dialect/HAL/Conversion:Passes",
+ "//iree/tools:init_compiler_modules",
+ "//iree/tools:init_iree_passes_and_dialects",
+ "//iree/tools:init_mlir_passes_and_dialects",
+ "//iree/tools:init_targets",
+ "//iree/tools:init_xla_dialects",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MlirOptLib",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
diff --git a/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp b/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
index c97c52b..40aad21 100644
--- a/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
+++ b/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
@@ -16,10 +16,12 @@
#include "iree/base/signature_mangle.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/SymbolTable.h"
@@ -164,6 +166,10 @@
: public PassWrapper<TFSavedModelLowerGlobalTensors,
OperationPass<ModuleOp>> {
public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Flow::FlowDialect, IREEDialect>();
+ }
+
void runOnOperation() override {
if (failed(importTfSavedModelGlobalTensorsToIREEFlow(getOperation()))) {
signalPassFailure();
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_flow_to_hal.h b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_flow_to_hal.h
index 4a79ef5..1d623d6 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_flow_to_hal.h
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_flow_to_hal.h
@@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
#include "iree/compiler/Dialect/Modules/Strings/IR/Dialect.h"
#include "iree/compiler/Dialect/Modules/Strings/IR/Types.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -37,7 +38,11 @@
// use tensor types.
class TfStringsToHALConversionInterface : public HALConversionDialectInterface {
public:
- using HALConversionDialectInterface::HALConversionDialectInterface;
+ TfStringsToHALConversionInterface(Dialect *dialect)
+ : HALConversionDialectInterface(dialect) {
+ dialect->getContext()->loadDialect<IREE::Strings::StringsDialect>();
+ }
+
void setupConversionTarget(ConversionTarget &target,
OwningRewritePatternList &patterns,
TypeConverter &typeConverter) const override {
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
index f70fcb2..e5c1a3e 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
@@ -21,6 +21,7 @@
#include "integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
@@ -89,6 +90,10 @@
: public PassWrapper<LowerTensorflowToStringsPass,
OperationPass<ModuleOp>> {
public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<TFStringsDialect>();
+ }
+
void runOnOperation() override {
if (failed(run())) {
signalPassFailure();
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.h b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.h
index 478479b..dafeb8d 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.h
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.h
@@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListDialect.h"
#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -36,7 +37,10 @@
class TfTensorListToHALConversionInterface
: public HALConversionDialectInterface {
public:
- using HALConversionDialectInterface::HALConversionDialectInterface;
+ TfTensorListToHALConversionInterface(Dialect *dialect)
+ : HALConversionDialectInterface(dialect) {
+ dialect->getContext()->loadDialect<IREE::TensorList::TensorListDialect>();
+ }
void setupConversionTarget(ConversionTarget &target,
OwningRewritePatternList &patterns,
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
index da966a3..472b890 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
@@ -14,6 +14,7 @@
#include "integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_dialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -27,6 +28,9 @@
class ConvertTfToTfTensorList
: public PassWrapper<ConvertTfToTfTensorList, OperationPass<FuncOp>> {
public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<TfTensorListDialect>();
+ }
void runOnOperation() override;
};
diff --git a/integrations/tensorflow/compiler/tf_opt_main.cc b/integrations/tensorflow/compiler/tf_opt_main.cc
new file mode 100644
index 0000000..eb3efc2
--- /dev/null
+++ b/integrations/tensorflow/compiler/tf_opt_main.cc
@@ -0,0 +1,163 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Main entry function for iree-tf-opt and derived binaries.
+//
+// Based on iree-opt with the addition of TF dialects and passes
+
+#include "integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.h"
+#include "integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_dialect.h"
+#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Conversion/init_conversions.h"
+#include "iree/compiler/Dialect/HAL/Conversion/Passes.h"
+#include "iree/tools/init_compiler_modules.h"
+#include "iree/tools/init_iree_dialects.h"
+#include "iree/tools/init_iree_passes.h"
+#include "iree/tools/init_mlir_dialects.h"
+#include "iree/tools/init_mlir_passes.h"
+#include "iree/tools/init_targets.h"
+#include "iree/tools/init_xla_dialects.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/MlirOptMain.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
+
+#ifdef IREE_HAVE_EMITC_DIALECT
+#include "emitc/InitDialect.h"
+#endif // IREE_HAVE_EMITC_DIALECT
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+ llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"));
+
+static llvm::cl::opt<std::string> outputFilename(
+ "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+ llvm::cl::init("-"));
+
+static llvm::cl::opt<bool> splitInputFile(
+ "split-input-file",
+ llvm::cl::desc("Split the input file into pieces and process each "
+ "chunk independently"),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> verifyDiagnostics(
+ "verify-diagnostics",
+ llvm::cl::desc("Check that emitted diagnostics match "
+ "expected-* lines on the corresponding line"),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> verifyPasses(
+ "verify-each",
+ llvm::cl::desc("Run the verifier after each transformation pass"),
+ llvm::cl::init(true));
+
+static llvm::cl::opt<bool> allowUnregisteredDialects(
+ "allow-unregistered-dialect",
+ llvm::cl::desc("Allow operation with no registered dialects"),
+ llvm::cl::init(true));
+
+static llvm::cl::opt<bool> showDialects(
+ "show-dialects", llvm::cl::desc("Print the list of registered dialects"),
+ llvm::cl::init(false));
+
+void registerTFDialects(mlir::DialectRegistry ®istry) {
+ registry.insert<mlir::TF::TensorFlowDialect,
+ mlir::tf_executor::TensorFlowExecutorDialect,
+ mlir::tf_device::TensorFlowDeviceDialect,
+ mlir::tf_saved_model::TensorFlowSavedModelDialect>();
+}
+
+void registerExtensionDialects(mlir::DialectRegistry ®istry) {
+ registry.insert<mlir::iree_compiler::tf_strings::TFStringsDialect,
+ mlir::tf_tensorlist::TfTensorListDialect>();
+}
+
+int main(int argc, char **argv) {
+ // TODO(#2958): There's a lot of duplication with iree-opt here. Factor out
+ // the common functionality.
+ llvm::InitLLVM y(argc, argv);
+
+ mlir::DialectRegistry registry;
+ mlir::registerMlirDialects(registry);
+ mlir::registerMlirPasses();
+#ifdef IREE_HAVE_EMITC_DIALECT
+ mlir::registerEmitCDialect(registry);
+#endif // IREE_HAVE_EMITC_DIALECT
+ mlir::registerXLADialects(registry);
+ mlir::iree_compiler::registerIreeDialects(registry);
+ mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
+ registerTFDialects(registry);
+ registerExtensionDialects(registry);
+
+ mlir::iree_compiler::registerAllIreePasses();
+ mlir::iree_compiler::registerHALConversionPasses();
+ mlir::iree_compiler::registerHALTargetBackends();
+ mlir::iree_compiler::registerLinalgToSPIRVPasses();
+ mlir::iree_compiler::registerHLOToLinalgPasses();
+ mlir::iree_compiler::registerLinalgToLLVMPasses();
+
+ // Register MLIRContext command-line options like
+ // -mlir-print-op-on-diagnostic.
+ mlir::registerMLIRContextCLOptions();
+ // Register assembly printer command-line options like
+ // -mlir-print-op-generic.
+ mlir::registerAsmPrinterCLOptions();
+ // Register pass manager command-line options like -print-ir-*.
+ mlir::registerPassManagerCLOptions();
+
+ mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
+
+ // Parse pass names in main to ensure static initialization completed.
+ llvm::cl::ParseCommandLineOptions(argc, argv,
+ "IREE modular optimizer driver\n");
+
+ if (showDialects) {
+ llvm::outs() << "Available Dialects:\n";
+ interleave(
+ registry, llvm::outs(),
+ [](auto ®istryEntry) { llvm::outs() << registryEntry.first; }, "\n");
+ return 0;
+ }
+
+ // Set up the input file.
+ std::string errorMessage;
+ auto file = mlir::openInputFile(inputFilename, &errorMessage);
+ if (!file) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+
+ auto output = mlir::openOutputFile(outputFilename, &errorMessage);
+ if (!output) {
+ llvm::errs() << errorMessage << "\n";
+ exit(1);
+ }
+
+ if (failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
+ registry, splitInputFile, verifyDiagnostics,
+ verifyPasses, allowUnregisteredDialects,
+ /*preloadDialectsInContext=*/false))) {
+ return 1;
+ }
+}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 6aa0836..71f550f 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -206,8 +206,12 @@
LinalgTileAndFusePass(const LinalgTileAndFusePass &pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
+ // clang-format off
+ registry.insert<AffineDialect,
+ gpu::GPUDialect,
+ linalg::LinalgDialect,
scf::SCFDialect>();
+ // clang-format on
}
void runOnFunction() override;
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index c6aa174..7721f41 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -52,8 +52,12 @@
struct ConvertVectorToGPUPass
: public PassWrapper<ConvertVectorToGPUPass, OperationPass<FuncOp>> {
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<AffineDialect, gpu::GPUDialect, scf::SCFDialect,
+ // clang-format off
+ registry.insert<AffineDialect,
+ gpu::GPUDialect,
+ scf::SCFDialect,
vector::VectorDialect>();
+ // clang-format on
}
void runOnOperation() override;
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 3392705..e3ce6a7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -56,9 +56,9 @@
"//iree/compiler/Dialect/Shape/Utils:TypeConversion",
"//iree/compiler/Utils",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 97164dc..4cdf878 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -39,9 +39,9 @@
"RematerializeDispatchConstants.cpp"
DEPS
LLVMSupport
- MLIRAnalysis
MLIRIR
MLIRPass
+ MLIRShape
MLIRShapeOpsTransforms
MLIRStandardOps
MLIRSupport
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index 0425ee0..496f22a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -14,6 +14,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
@@ -374,6 +375,10 @@
struct HLOToHLOPreprocessing
: public PassWrapper<HLOToHLOPreprocessing, FunctionPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
+ }
+
void runOnFunction() override {
MLIRContext *context = &getContext();
OwningRewritePatternList patterns;
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
index bbb82b0..3728d92 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
@@ -44,8 +44,11 @@
"//iree/schemas:llvmir_executable_def_cc_fbs",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
- # TODO(ataei): Link with native target dep.
- "@llvm-project//llvm:X86CodeGen",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TargetLLVMIR",
+ "@llvm-project//mlir:VectorOps",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
index 6b1604f..339f128 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
@@ -28,8 +28,12 @@
DEPS
LLVMCore
LLVMSupport
- LLVMX86CodeGen
+ MLIRAffineOps
+ MLIRLLVMIR
+ MLIRLinalgOps
+ MLIRSCF
MLIRTargetLLVMIR
+ MLIRVector
iree::compiler::Conversion::LinalgToLLVM
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
index 87f32fd..26037b2 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
@@ -22,6 +22,11 @@
#include "llvm/IR/Module.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/TargetSelect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Target/LLVMIR.h"
namespace mlir {
@@ -37,6 +42,16 @@
// NOTE: we could vary this based on the options, such as by arch/etc.
std::string name() const override { return "llvm-ir*"; }
+ void getDependentDialects(DialectRegistry& registry) const override {
+ // clang-format off
+ registry.insert<AffineDialect,
+ linalg::LinalgDialect,
+ LLVM::LLVMDialect,
+ scf::SCFDialect,
+ vector::VectorDialect>();
+ // clang-format on
+ }
+
void buildTranslationPassPipeline(ExecutableTargetOp targetOp,
OpPassManager& passManager) override {
buildLLVMTransformPassPipeline(passManager);
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
index 67be939..5b25c7c 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
@@ -17,6 +17,7 @@
#include <algorithm>
#include "llvm/Support/CommandLine.h"
+#include "mlir/IR/Dialect.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 59c6c8c..37d74a3 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -25,6 +25,7 @@
#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/Pass/PassManager.h"
namespace mlir {
@@ -149,6 +150,16 @@
virtual void declareTargetOps(IREE::Flow::ExecutableOp sourceOp,
IREE::HAL::ExecutableOp executableOp);
+ // Register dependent dialects for the TargetBackend.
+ // Mirrors the method on mlir::Pass of the same name. A TargetBackend is
+ // expected to register the dialects it will create entities for (Operations,
+ // Types, Attributes), other than dialects that exist in the input. These are
+ // the dialects that will be used in |declareTargetOps| and
+ // |buildTranslationPassPipeline|.
+ // TODO(#1036): We might be able to get rid of this with dynamic pass
+ // registration.
+ virtual void getDependentDialects(DialectRegistry ®istry) const {}
+
// Captured state from the point at which a dispatch is to be recorded.
struct DispatchState {
// The original flow.dispatch op.
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
index 99bcc9f..f3f0f08 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
@@ -40,8 +40,10 @@
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/VM/Conversion",
+ "//iree/compiler/Dialect/VM/IR",
"//iree/compiler/Dialect/VM/Target/Bytecode",
"//iree/compiler/Dialect/VM/Transforms",
+ "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"//iree/compiler/Dialect/VMLA/Transforms",
"//iree/schemas:vmla_executable_def_cc_fbs",
"@com_github_google_flatbuffers//:flatbuffers",
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
index a58e645..43de52b 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
@@ -34,8 +34,10 @@
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::VM::Conversion
+ iree::compiler::Dialect::VM::IR
iree::compiler::Dialect::VM::Target::Bytecode
iree::compiler::Dialect::VM::Transforms
+ iree::compiler::Dialect::VMLA::IR::VMLADialect
iree::compiler::Dialect::VMLA::Transforms
iree::schemas::vmla_executable_def_cc_fbs
PUBLIC
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index 22660d8..3db3986 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -18,8 +18,10 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
#include "iree/schemas/vmla_executable_def_generated.h"
#include "llvm/ADT/ScopeExit.h"
@@ -47,6 +49,10 @@
std::string name() const override { return "vmla"; }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<VM::VMDialect, VMLA::VMLADialect>();
+ }
+
void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetOp targetOp,
OpPassManager &passManager) override {
IREE::VMLA::buildVMLATransformPassPipeline(passManager);
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
index f4bdf08..18006a7 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -41,13 +41,15 @@
"//iree/compiler/Conversion/LinalgToSPIRV",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/Target",
- "//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Vulkan/IR",
"//iree/compiler/Dialect/Vulkan/Utils",
"//iree/schemas:spirv_executable_def_cc_fbs",
"@com_github_google_flatbuffers//:flatbuffers",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SPIRVDialect",
@@ -55,6 +57,7 @@
"@llvm-project//mlir:SPIRVSerialization",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorOps",
"@org_tensorflow//tensorflow/compiler/mlir/hlo",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
index 5764276..05d2710 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -27,7 +27,10 @@
"VulkanSPIRVTarget.cpp"
DEPS
LLVMSupport
+ MLIRAffineOps
+ MLIRGPU
MLIRIR
+ MLIRLinalgOps
MLIRParser
MLIRPass
MLIRSPIRV
@@ -35,12 +38,12 @@
MLIRSPIRVTransforms
MLIRSupport
MLIRTransforms
+ MLIRVector
flatbuffers
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Conversion::LinalgToSPIRV
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::Vulkan::IR
iree::compiler::Dialect::Vulkan::Utils
iree::schemas::spirv_executable_def_cc_fbs
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index a183357..4d3d464 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -22,16 +22,23 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.h"
#include "iree/schemas/spirv_executable_def_generated.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
@@ -229,6 +236,18 @@
// NOTE: we could vary this based on the options such as 'vulkan-v1.1'.
std::string name() const override { return "vulkan*"; }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ // clang-format off
+ registry.insert<AffineDialect,
+ Vulkan::VulkanDialect,
+ gpu::GPUDialect,
+ linalg::LinalgDialect,
+ scf::SCFDialect,
+ spirv::SPIRVDialect,
+ vector::VectorDialect>();
+ // clang-format on
+ }
+
void declareTargetOps(IREE::Flow::ExecutableOp sourceOp,
IREE::HAL::ExecutableOp executableOp) override {
OpBuilder targetBuilder(&executableOp.getBlock().back());
diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD
index 689d530..ca98ca8 100644
--- a/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -43,7 +43,6 @@
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/Transforms",
- "//iree/compiler/Utils",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index a486db3..d74601a 100644
--- a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -45,7 +45,6 @@
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::Shape::Transforms
- iree::compiler::Utils
tensorflow::mlir_hlo
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 79680d8..441d55b 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -329,6 +329,11 @@
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<IREE::HAL::HALDialect>();
+
+ auto targetBackends = matchTargetBackends(targetOptions_.targets);
+ for (auto &targetBackend : targetBackends) {
+ targetBackend->getDependentDialects(registry);
+ }
}
void runOnOperation() override {
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 1d58ef6..2f0a026 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -17,6 +17,7 @@
#include <memory>
#include "iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
@@ -49,6 +50,11 @@
// directly call TargetBackend::buildTranslationPassPipeline function. For now
// we need to run each backend translation in isolation and we do that within
// this pass.
+ // The createTranslateExecutablesPass operates on hal.executable ops, so
+ // requires that the dialect already be loaded before it can be added to the
+ // pass pipeline.
+ // TODO(#2958): This shouldn't be necessary.
+ passManager.getContext()->loadDialect<HALDialect>();
passManager.addPass(createTranslateExecutablesPass(targetOptions));
// After all executables are translated we allow the backends to link them
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h
index 4ba6cfc..a0e936b 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -45,7 +45,7 @@
// buildHALTransformPassPipeline & run
// <run conversion from HAL to vm/etc>
void buildHALTransformPassPipeline(OpPassManager &passManager,
- TargetOptions executableOptions);
+ TargetOptions targetOptions);
void registerHALTransformPassPipeline();
diff --git a/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp b/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
index 8dd74e4..954dbb7 100644
--- a/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
@@ -14,6 +14,7 @@
#include <utility>
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
@@ -38,6 +39,15 @@
explicit TranslateExecutablesPass(TargetOptions executableOptions)
: executableOptions_(executableOptions) {}
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<HALDialect>();
+
+ auto targetBackends = matchTargetBackends(executableOptions_.targets);
+ for (auto &targetBackend : targetBackends) {
+ targetBackend->getDependentDialects(registry);
+ }
+ }
+
void runOnOperation() override {
auto executableOp = getOperation();
auto targetOps = llvm::to_vector<4>(
diff --git a/iree/compiler/Dialect/VM/Transforms/BUILD b/iree/compiler/Dialect/VM/Transforms/BUILD
index f48923e..1edfd36 100644
--- a/iree/compiler/Dialect/VM/Transforms/BUILD
+++ b/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -32,6 +32,7 @@
],
deps = [
"//iree/compiler/Dialect/IREE/Conversion:PreserveCompilerHints",
+ "//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/VM/Conversion",
"//iree/compiler/Dialect/VM/Conversion/IREEToVM",
diff --git a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index 5472647..413d33c 100644
--- a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -33,6 +33,7 @@
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::IREE::Conversion::PreserveCompilerHints
+ iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::VM::Conversion
iree::compiler::Dialect::VM::Conversion::IREEToVM
diff --git a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
index 3f158b6..46081ed 100644
--- a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -15,6 +15,7 @@
#include <tuple>
#include "iree/compiler/Dialect/IREE/Conversion/PreserveCompilerHints.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
@@ -83,7 +84,7 @@
: targetOptions_(targetOptions) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<StandardOpsDialect, IREE::VM::VMDialect>();
+ registry.insert<IREEDialect, IREE::VM::VMDialect, StandardOpsDialect>();
}
void runOnOperation() override {
diff --git a/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
index 5d16de6..bb87505 100644
--- a/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -180,7 +180,8 @@
}
};
-std::unique_ptr<OperationPass<ModuleOp>> createGlobalInitializationPass() {
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
+createGlobalInitializationPass() {
return std::make_unique<GlobalInitializationPass>();
}
diff --git a/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index 1f84c7e..9f087a2 100644
--- a/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -30,6 +30,11 @@
TargetOptions targetOptions) {
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createConversionPass(targetOptions));
+ // The createGlobalInitializationPass operates on vm.module ops, so requires
+ // that the dialect already be loaded before it can be added to the pass
+ // pipeline.
+ // TODO(#2958): This shouldn't be necessary.
+ passManager.getContext()->loadDialect<VM::VMDialect>();
passManager.addPass(createGlobalInitializationPass());
passManager.addPass(createInlinerPass());
passManager.addPass(createCSEPass());
diff --git a/iree/samples/custom_modules/dialect/custom_opt.cc b/iree/samples/custom_modules/dialect/custom_opt.cc
index 77969ad..70cc975 100644
--- a/iree/samples/custom_modules/dialect/custom_opt.cc
+++ b/iree/samples/custom_modules/dialect/custom_opt.cc
@@ -74,7 +74,6 @@
llvm::cl::init(false));
int main(int argc, char **argv) {
- mlir::enableGlobalDialectRegistry(true);
mlir::DialectRegistry registry;
mlir::registerMlirDialects(registry);
mlir::registerMlirPasses();
@@ -131,7 +130,7 @@
if (failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
registry, splitInputFile, verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- /*preloadDialectsInContext=*/true))) {
+ /*preloadDialectsInContext=*/false))) {
return 1;
}
}
diff --git a/iree/samples/custom_modules/dialect/custom_translate.cc b/iree/samples/custom_modules/dialect/custom_translate.cc
index d913857..2361a0e 100644
--- a/iree/samples/custom_modules/dialect/custom_translate.cc
+++ b/iree/samples/custom_modules/dialect/custom_translate.cc
@@ -56,14 +56,15 @@
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
- mlir::enableGlobalDialectRegistry(true);
- mlir::registerMlirDialects();
- mlir::registerXLADialects();
- mlir::iree_compiler::registerIreeDialects();
+ mlir::DialectRegistry registry;
+
+ mlir::registerMlirDialects(registry);
+ mlir::registerXLADialects(registry);
+ mlir::iree_compiler::registerIreeDialects(registry);
// Register the custom dialect
- mlir::iree_compiler::registerCustomDialect();
- mlir::iree_compiler::registerIreeCompilerModuleDialects();
+ mlir::iree_compiler::registerCustomDialect(registry);
+ mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
mlir::iree_compiler::registerHALTargetBackends();
mlir::iree_compiler::registerVMTargets();
mlir::registerMlirTranslations();
@@ -102,8 +103,8 @@
/// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
llvm::raw_ostream &os) {
- mlir::MLIRContext context;
- context.loadAllGloballyRegisteredDialects();
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ registry.appendTo(context.getDialectRegistry());
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticHandler diagHandler(sourceMgr, &context);
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index c27bdf6..14e6c17 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -240,6 +240,7 @@
":init_iree_passes_and_dialects",
":init_mlir_passes_and_dialects",
":init_targets",
+ ":init_xla_dialects",
":vm_util",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
@@ -258,7 +259,6 @@
"//iree/vm:bytecode_module",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
- "@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index f3a33a5..015fe37 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -322,10 +322,10 @@
::init_iree_passes_and_dialects
::init_mlir_passes_and_dialects
::init_targets
+ ::init_xla_dialects
::vm_util
LLVMSupport
MLIRIR
- MLIRSCFTransforms
MLIRParser
MLIRPass
MLIRSupport
diff --git a/iree/tools/opt_main.cc b/iree/tools/opt_main.cc
index 5e6af4b..8e8a75b 100644
--- a/iree/tools/opt_main.cc
+++ b/iree/tools/opt_main.cc
@@ -76,7 +76,8 @@
llvm::cl::init(false));
int main(int argc, char **argv) {
- mlir::enableGlobalDialectRegistry(true);
+ llvm::InitLLVM y(argc, argv);
+
mlir::DialectRegistry registry;
mlir::registerMlirDialects(registry);
mlir::registerMlirPasses();
@@ -92,7 +93,6 @@
mlir::iree_compiler::registerLinalgToSPIRVPasses();
mlir::iree_compiler::registerHLOToLinalgPasses();
mlir::iree_compiler::registerLinalgToLLVMPasses();
- llvm::InitLLVM y(argc, argv);
// Register MLIRContext command-line options like
// -mlir-print-op-on-diagnostic.
@@ -131,10 +131,12 @@
exit(1);
}
+ // TODO(#2958): There's a simpler version of MlirOptMain we should be able to
+ // use.
if (failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
registry, splitInputFile, verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- /*preloadDialectsInContext=*/true))) {
+ /*preloadDialectsInContext=*/false))) {
return 1;
}
}
diff --git a/iree/tools/run_mlir_main.cc b/iree/tools/run_mlir_main.cc
index bf2b5c5..98c4329 100644
--- a/iree/tools/run_mlir_main.cc
+++ b/iree/tools/run_mlir_main.cc
@@ -61,6 +61,7 @@
#include "iree/tools/init_iree_dialects.h"
#include "iree/tools/init_mlir_dialects.h"
#include "iree/tools/init_targets.h"
+#include "iree/tools/init_xla_dialects.h"
#include "iree/tools/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
@@ -70,6 +71,7 @@
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
@@ -173,12 +175,12 @@
// Prepares a module for evaluation by running MLIR import and IREE translation.
// Returns the serialized flatbuffer data.
StatusOr<std::string> PrepareModule(
- std::string target_backend,
- std::unique_ptr<llvm::MemoryBuffer> file_buffer) {
+ std::string target_backend, std::unique_ptr<llvm::MemoryBuffer> file_buffer,
+ mlir::DialectRegistry& registry) {
IREE_TRACE_SCOPE0("PrepareModule");
- mlir::MLIRContext context;
- context.loadAllGloballyRegisteredDialects();
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ registry.appendTo(context.getDialectRegistry());
// Parse input MLIR module.
llvm::SourceMgr source_mgr;
@@ -387,7 +389,8 @@
}
// Translates and runs a single LLVM file buffer.
-Status EvaluateFile(std::unique_ptr<llvm::MemoryBuffer> file_buffer) {
+Status EvaluateFile(std::unique_ptr<llvm::MemoryBuffer> file_buffer,
+ mlir::DialectRegistry& registry) {
IREE_TRACE_SCOPE0("EvaluateFile");
// TODO(benvanik): move to instance-based registration.
@@ -407,7 +410,8 @@
file_buffer->getBuffer(), file_buffer->getBufferIdentifier());
IREE_ASSIGN_OR_RETURN(
auto flatbuffer_data,
- PrepareModule(target_backend + '*', std::move(cloned_file_buffer)),
+ PrepareModule(target_backend + '*', std::move(cloned_file_buffer),
+ registry),
_ << "Translating module");
IREE_TRACE_FRAME_MARK();
IREE_RETURN_IF_ERROR(EvaluateFunctions(
@@ -420,7 +424,8 @@
}
// Runs the given .mlir file based on the current flags.
-Status RunFile(const std::string& mlir_filename) {
+Status RunFile(const std::string& mlir_filename,
+ mlir::DialectRegistry& registry) {
IREE_TRACE_SCOPE0("RunFile");
// Load input file/from stdin.
@@ -434,7 +439,7 @@
if (!split_input_file_flag) {
// Use entire buffer as a single module.
- return EvaluateFile(std::move(file));
+ return EvaluateFile(std::move(file), registry);
}
// Split the buffer into separate modules and evaluate independently.
@@ -457,7 +462,7 @@
sub_source_buffer, full_buffer->getBufferIdentifier() +
llvm::Twine(" split at line #") +
llvm::Twine(split_line));
- auto sub_failure = EvaluateFile(std::move(sub_buffer));
+ auto sub_failure = EvaluateFile(std::move(sub_buffer), registry);
if (!sub_failure.ok()) {
LOG(ERROR) << "Failure for split at line #" << split_line << ": "
<< sub_failure;
@@ -490,10 +495,11 @@
}
}
- mlir::enableGlobalDialectRegistry(true);
- mlir::registerMlirDialects();
- mlir::iree_compiler::registerIreeDialects();
- mlir::iree_compiler::registerIreeCompilerModuleDialects();
+ mlir::DialectRegistry registry;
+ mlir::registerMlirDialects(registry);
+ mlir::iree_compiler::registerIreeDialects(registry);
+ mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
+ mlir::registerXLADialects(registry);
mlir::iree_compiler::registerHALTargetBackends();
mlir::iree_compiler::registerVMTargets();
@@ -516,7 +522,7 @@
char** argv_absl_ptr = argv_absl.data();
iree::InitializeEnvironment(&argc_absl, &argv_absl_ptr);
- auto status = RunFile(input_file_flag);
+ auto status = RunFile(input_file_flag, registry);
if (!status.ok()) {
std::cerr << "ERROR running file (" << input_file_flag << "): " << status
<< "\n";
diff --git a/iree/tools/translate_main.cc b/iree/tools/translate_main.cc
index 6f251ed..d5d48c4 100644
--- a/iree/tools/translate_main.cc
+++ b/iree/tools/translate_main.cc
@@ -57,17 +57,22 @@
"process each chunk independently"),
llvm::cl::init(false));
+// TODO(#2958): We shouldn't need to register dialects here if translations
+// correctly declare the dialects they support.
+// TODO(#2958): Investigate whether we can use mlir-translate.cpp as an entry
+// point.
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
- mlir::enableGlobalDialectRegistry(true);
-
- mlir::registerMlirDialects();
+ // TODO(#2958): We shouldn't need to register dialects here if translations
+ // correctly declare the dialects they support.
+ mlir::DialectRegistry registry;
+ mlir::registerMlirDialects(registry);
#ifdef IREE_HAVE_EMITC_DIALECT
- mlir::registerEmitCDialect();
+ mlir::registerEmitCDialect(registry);
#endif // IREE_HAVE_EMITC_DIALECT
- mlir::registerXLADialects();
- mlir::iree_compiler::registerIreeDialects();
- mlir::iree_compiler::registerIreeCompilerModuleDialects();
+ mlir::registerXLADialects(registry);
+ mlir::iree_compiler::registerIreeDialects(registry);
+ mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
mlir::iree_compiler::registerHALTargetBackends();
mlir::iree_compiler::registerVMTargets();
mlir::registerMlirTranslations();
@@ -109,8 +114,8 @@
/// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
llvm::raw_ostream &os) {
- mlir::MLIRContext context;
- context.loadAllGloballyRegisteredDialects();
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ registry.appendTo(context.getDialectRegistry());
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticHandler diagHandler(sourceMgr, &context);