Caching VM import modules so they don't need to be reloaded.
MLIR's parser code has high contention with a lock in the diagnostics
engine. Reloading the import modules from all threads running conversion
was taking 19sec in BERT; now it's ~0.
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
index cd5002f..a29d4f1 100644
--- a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -54,7 +54,7 @@
public:
using VMConversionDialectInterface::VMConversionDialectInterface;
- OwningModuleRef getVMImportModule() const override {
+ OwningModuleRef parseVMImportModule() const override {
return mlir::parseSourceString(
StringRef(hal_imports_create()->data, hal_imports_create()->size),
getDialect()->getContext());
diff --git a/iree/compiler/Dialect/Modules/Check/IR/CheckDialect.cpp b/iree/compiler/Dialect/Modules/Check/IR/CheckDialect.cpp
index 0261e16..6b0cca6 100644
--- a/iree/compiler/Dialect/Modules/Check/IR/CheckDialect.cpp
+++ b/iree/compiler/Dialect/Modules/Check/IR/CheckDialect.cpp
@@ -33,7 +33,7 @@
public:
using VMConversionDialectInterface::VMConversionDialectInterface;
- OwningModuleRef getVMImportModule() const override {
+ OwningModuleRef parseVMImportModule() const override {
return mlir::parseSourceString(
StringRef(check_imports_create()->data, check_imports_create()->size),
getDialect()->getContext());
diff --git a/iree/compiler/Dialect/Modules/Strings/IR/Dialect.cc b/iree/compiler/Dialect/Modules/Strings/IR/Dialect.cc
index 3fb5a5f..cddea50 100644
--- a/iree/compiler/Dialect/Modules/Strings/IR/Dialect.cc
+++ b/iree/compiler/Dialect/Modules/Strings/IR/Dialect.cc
@@ -36,7 +36,7 @@
public:
using VMConversionDialectInterface::VMConversionDialectInterface;
- OwningModuleRef getVMImportModule() const override {
+ OwningModuleRef parseVMImportModule() const override {
return mlir::parseSourceString(StringRef(strings_imports_create()->data,
strings_imports_create()->size),
getDialect()->getContext());
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListDialect.cpp b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListDialect.cpp
index 75bf2e2..2a8fe02 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListDialect.cpp
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListDialect.cpp
@@ -37,7 +37,7 @@
public:
using VMConversionDialectInterface::VMConversionDialectInterface;
- OwningModuleRef getVMImportModule() const override {
+ OwningModuleRef parseVMImportModule() const override {
return mlir::parseSourceString(StringRef(tensorlist_imports_create()->data,
tensorlist_imports_create()->size),
getDialect()->getContext());
diff --git a/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h b/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h
index 7d5563f..15abb8e 100644
--- a/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h
+++ b/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h
@@ -15,6 +15,8 @@
#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_CONVERSIONDIALECTINTERFACE_H_
#define IREE_COMPILER_DIALECT_VM_CONVERSION_CONVERSIONDIALECTINTERFACE_H_
+#include <mutex>
+
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -34,7 +36,11 @@
// Returns a module containing one or more vm.modules with vm.import ops.
// These modules will be merged into the module being compiled to provide
// import definitions to the conversion and lowering process.
- virtual OwningModuleRef getVMImportModule() const = 0;
+ mlir::ModuleOp getVMImportModule() const {
+ std::call_once(importParseFlag,
+ [&]() { importModuleRef = parseVMImportModule(); });
+ return importModuleRef.get();
+ }
// Populates |patterns| with rewrites that convert from the implementation
// dialect to the VM dialect. Many of these can just be default conversions
@@ -50,6 +56,14 @@
virtual void walkAttributeStorage(
Attribute attr,
const function_ref<void(Attribute elementAttr)> &fn) const {}
+
+ protected:
+ // Parses the vm.import module to be cached by the caller.
+ virtual OwningModuleRef parseVMImportModule() const = 0;
+
+ private:
+ mutable std::once_flag importParseFlag;
+ mutable OwningModuleRef importModuleRef;
};
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
index 46081ed..1837326 100644
--- a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -103,8 +103,16 @@
innerModuleOp);
for (auto *dialectInterface : usedDialects) {
auto outerImportModuleOp = dialectInterface->getVMImportModule();
+ if (!outerImportModuleOp) {
+ innerModuleOp.emitError()
+ << "unable load the VM import module for dialect '"
+ << dialectInterface->getDialect()->getNamespace()
+ << "'; possibly a bad file structure or malformed vm.import";
+ signalPassFailure();
+ return;
+ }
for (auto importModuleOp :
- outerImportModuleOp->getOps<IREE::VM::ModuleOp>()) {
+ outerImportModuleOp.getOps<IREE::VM::ModuleOp>()) {
if (failed(appendImportModule(importModuleOp, innerModuleOp))) {
importModuleOp.emitError() << "failed to import module";
return signalPassFailure();
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLADialect.cpp b/iree/compiler/Dialect/VMLA/IR/VMLADialect.cpp
index f79ba02..529cc5b 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLADialect.cpp
+++ b/iree/compiler/Dialect/VMLA/IR/VMLADialect.cpp
@@ -35,7 +35,7 @@
public:
using VMConversionDialectInterface::VMConversionDialectInterface;
- OwningModuleRef getVMImportModule() const override {
+ OwningModuleRef parseVMImportModule() const override {
return mlir::parseSourceString(
StringRef(vmla_imports_create()->data, vmla_imports_create()->size),
getDialect()->getContext());
diff --git a/iree/samples/custom_modules/dialect/custom_dialect.cc b/iree/samples/custom_modules/dialect/custom_dialect.cc
index 17f487f..0b48e8e 100644
--- a/iree/samples/custom_modules/dialect/custom_dialect.cc
+++ b/iree/samples/custom_modules/dialect/custom_dialect.cc
@@ -52,7 +52,7 @@
public:
using VMConversionDialectInterface::VMConversionDialectInterface;
- OwningModuleRef getVMImportModule() const override {
+ OwningModuleRef parseVMImportModule() const override {
return mlir::parseSourceString(
StringRef(custom_imports_create()->data, custom_imports_create()->size),
getDialect()->getContext());