Adding a dialect interface to allow decoupling of dialect->VM conversion.
This allows any dialect (such as custom ones) to opt into conversion during the big VM conversion pass.

PiperOrigin-RevId: 285014909
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD
index 8c23b95..7e8db9a 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD
@@ -33,6 +33,7 @@
         "//iree/compiler/Dialect",
         "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/HAL/IR",
+        "//iree/compiler/Dialect/HAL/IR:HALDialect",
         "//iree/compiler/Dialect/HAL/Target:ExecutableTarget",
         "//iree/compiler/Dialect/HAL/Utils",
         "//iree/compiler/Dialect/VM/IR",
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
index aef21a0..da8dca1 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
@@ -35,12 +35,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-LogicalResult appendHALImportModule(mlir::ModuleOp moduleOp) {
-  return appendImportModule(
-      StringRef(hal_imports_create()->data, hal_imports_create()->size),
-      moduleOp);
-}
-
 extern void populateHALAllocatorToVMPatterns(
     MLIRContext *context, SymbolTable &importSymbols,
     TypeConverter &typeConverter, OwningRewritePatternList &patterns);
@@ -100,7 +94,9 @@
     std::tie(outerModuleOp, innerModuleOp) =
         VMConversionTarget::nestModuleForConversion(getModule());
 
-    appendHALImportModule(innerModuleOp);
+    appendImportModule(
+        StringRef(hal_imports_create()->data, hal_imports_create()->size),
+        innerModuleOp);
 
     OwningRewritePatternList conversionPatterns;
     populateStandardToVMPatterns(context, conversionPatterns);
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h
index e4ff7ed..9d19120 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h
@@ -22,10 +22,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-// Appends the HAL import module containing the vm.import ops for all HAL
-// methods.
-LogicalResult appendHALImportModule(mlir::ModuleOp moduleOp);
-
 // Populates conversion patterns from the HAL dialect to the VM dialect.
 void populateHALToVMPatterns(MLIRContext *context, SymbolTable &importSymbols,
                              OwningRewritePatternList &patterns,
diff --git a/iree/compiler/Dialect/HAL/IR/BUILD b/iree/compiler/Dialect/HAL/IR/BUILD
index 9009164..0cb74f5 100644
--- a/iree/compiler/Dialect/HAL/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/IR/BUILD
@@ -29,22 +29,22 @@
 cc_library(
     name = "IR",
     srcs = [
-        "HALDialect.cpp",
         "HALEnums.cpp.inc",
         "HALOpFolders.cpp",
         "HALOpInterface.cpp.inc",
         "HALOps.cpp",
-        "HALOps.cpp.inc",
         "HALTypes.cpp",
     ],
     hdrs = [
-        "HALDialect.h",
         "HALEnums.h.inc",
         "HALOpInterface.h.inc",
         "HALOps.h",
         "HALOps.h.inc",
         "HALTypes.h",
     ],
+    textual_hdrs = [
+        "HALOps.cpp.inc",
+    ],
     deps = [
         ":HALEnumsGen",
         ":HALOpInterfaceGen",
@@ -52,6 +52,26 @@
         "//iree/compiler/Dialect",
         "@llvm//:support",
         "@local_config_mlir//:IR",
+        "@local_config_mlir//:Parser",
+        "@local_config_mlir//:StandardOps",
+        "@local_config_mlir//:Support",
+        "@local_config_mlir//:TransformUtils",
+    ],
+)
+
+cc_library(
+    name = "HALDialect",
+    srcs = ["HALDialect.cpp"],
+    hdrs = ["HALDialect.h"],
+    deps = [
+        ":IR",
+        "//iree/compiler/Dialect",
+        "//iree/compiler/Dialect/HAL:hal_imports",
+        "//iree/compiler/Dialect/HAL/Conversion/HALToVM",
+        "//iree/compiler/Dialect/VM/Conversion",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Parser",
         "@local_config_mlir//:StandardOps",
         "@local_config_mlir//:Support",
         "@local_config_mlir//:TransformUtils",
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
index 7408d54..f534e7a 100644
--- a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -14,23 +14,49 @@
 
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 
+#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/HAL/hal.imports.h"
+#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
 #include "llvm/Support/SourceMgr.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser.h"
 
 namespace mlir {
 namespace iree_compiler {
 namespace IREE {
 namespace HAL {
 
-#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.cpp.inc"
+namespace {
 
 static DialectRegistration<HALDialect> hal_dialect;
 
+class HALToVMConversionInterface : public VMConversionDialectInterface {
+ public:
+  using VMConversionDialectInterface::VMConversionDialectInterface;
+
+  OwningModuleRef getVMImportModule() const override {
+    return mlir::parseSourceString(
+        StringRef(hal_imports_create()->data, hal_imports_create()->size),
+        getDialect()->getContext());
+  }
+
+  void populateVMConversionPatterns(
+      SymbolTable &importSymbols, OwningRewritePatternList &patterns,
+      TypeConverter &typeConverter) const override {
+    populateHALToVMPatterns(getDialect()->getContext(), importSymbols, patterns,
+                            typeConverter);
+  }
+};
+
+}  // namespace
+
 HALDialect::HALDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
+  addInterfaces<HALToVMConversionInterface>();
+
   addTypes<AllocatorType, BufferType, CommandBufferType, DescriptorSetType,
            DescriptorSetLayoutType, DeviceType, EventType, ExecutableType,
            ExecutableCacheType, FenceType, RingBufferType, SemaphoreType>();
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.h b/iree/compiler/Dialect/HAL/IR/HALDialect.h
index dc30a49..3f67f42 100644
--- a/iree/compiler/Dialect/HAL/IR/HALDialect.h
+++ b/iree/compiler/Dialect/HAL/IR/HALDialect.h
@@ -23,8 +23,6 @@
 namespace IREE {
 namespace HAL {
 
-#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.h.inc"
-
 class HALDialect : public Dialect {
  public:
   explicit HALDialect(MLIRContext *context);
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index b899740..c9f36e5 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "llvm/ADT/StringExtras.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.h b/iree/compiler/Dialect/HAL/IR/HALOps.h
index a669c5b..afca9b0 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.h
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.h
@@ -17,7 +17,6 @@
 
 #include <cstdint>
 
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/Traits.h"
 #include "mlir/IR/Attributes.h"
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index def9da6..70aed7f 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -24,7 +24,7 @@
 namespace IREE {
 namespace HAL {
 
-// TODO(benvanik): struct types.
+#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.cpp.inc"
 
 }  // namespace HAL
 }  // namespace IREE
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h
index a6139b8..3fec3d3 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -23,6 +23,7 @@
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"
@@ -36,6 +37,8 @@
 namespace IREE {
 namespace HAL {
 
+#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.h.inc"
+
 //===----------------------------------------------------------------------===//
 // RefObject types
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VM/Conversion/BUILD b/iree/compiler/Dialect/VM/Conversion/BUILD
index 1abc463..12762a0 100644
--- a/iree/compiler/Dialect/VM/Conversion/BUILD
+++ b/iree/compiler/Dialect/VM/Conversion/BUILD
@@ -25,6 +25,7 @@
         "TypeConverter.cpp",
     ],
     hdrs = [
+        "ConversionDialectInterface.h",
         "ConversionTarget.h",
         "ImportUtils.h",
         "TypeConverter.h",
diff --git a/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h b/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h
new file mode 100644
index 0000000..537b228
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h
@@ -0,0 +1,53 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_CONVERSIONDIALECTINTERFACE_H_
+#define IREE_COMPILER_DIALECT_VM_CONVERSION_CONVERSIONDIALECTINTERFACE_H_
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// An interface for dialects to expose VM conversion functionality.
+// The VM conversion pass will query used dialects via this interface to find
+// import definitions and conversion patterns that map from the source dialect
+// to the VM dialect.
+class VMConversionDialectInterface
+    : public DialectInterface::Base<VMConversionDialectInterface> {
+ public:
+  VMConversionDialectInterface(Dialect *dialect) : Base(dialect) {}
+
+  // Returns a module containing one or more vm.modules with vm.import ops.
+  // These modules will be merged into the module being compiled to provide
+  // import definitions to the conversion and lowering process.
+  virtual OwningModuleRef getVMImportModule() const = 0;
+
+  // Populates |patterns| with rewrites that convert from the implementation
+  // dialect to the VM dialect. Many of these can just be default conversions
+  // via the VMImportOpConversion class.
+  //
+  // |importSymbols| contains all vm.imports that have been queried from all
+  // used dialects, not just this dialect.
+  virtual void populateVMConversionPatterns(
+      SymbolTable &importSymbols, OwningRewritePatternList &patterns,
+      TypeConverter &typeConverter) const = 0;
+};
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_VM_CONVERSION_CONVERSIONDIALECTINTERFACE_H_
diff --git a/iree/compiler/Dialect/VM/Transforms/BUILD b/iree/compiler/Dialect/VM/Transforms/BUILD
index c635126..e76d67c 100644
--- a/iree/compiler/Dialect/VM/Transforms/BUILD
+++ b/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -28,7 +28,6 @@
         "Passes.h",
     ],
     deps = [
-        "//iree/compiler/Dialect/HAL/Conversion/HALToVM",
         "//iree/compiler/Dialect/VM/Conversion",
         "//iree/compiler/Dialect/VM/Conversion/StandardToVM",
         "//iree/compiler/Dialect/VM/IR",
diff --git a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
index 0ed0b71..14cc7bd 100644
--- a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -14,10 +14,12 @@
 
 #include <tuple>
 
-#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h"
+#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
 #include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
 #include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h"
 #include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
+#include "llvm/ADT/STLExtras.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -29,7 +31,29 @@
 namespace IREE {
 namespace VM {
 
-// TODO(benvanik): import dialect registration.
+// Returns a stably sorted list of dialect interfaces of T for all dialects used
+// within the given module.
+template <typename T>
+SmallVector<const T *, 4> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
+  SmallPtrSet<const T *, 4> resultSet;
+  moduleOp.walk([&](Operation *op) {
+    auto *dialect = op->getDialect();
+    if (!dialect) return;
+    auto *dialectInterface = dialect->getRegisteredInterface<T>();
+    if (!dialectInterface) return;
+    resultSet.insert(dialectInterface);
+  });
+
+  // NOTE: to ensure deterministic output we sort the result so that imports are
+  // always added in a consistent order.
+  SmallVector<const T *, 4> results = {resultSet.begin(), resultSet.end()};
+  llvm::sort(
+      results, +[](const T *a, const T *b) {
+        return a->getDialect()->getNamespace().compare(
+                   b->getDialect()->getNamespace()) < 0;
+      });
+  return results;
+}
 
 // Runs conversion with registered input dialects.
 class ConversionPass : public OperationPass<ConversionPass, mlir::ModuleOp> {
@@ -43,16 +67,32 @@
     std::tie(outerModuleOp, innerModuleOp) =
         VMConversionTarget::nestModuleForConversion(getOperation());
 
-    // TODO(benvanik): registration system for custom dialects.
-    appendHALImportModule(innerModuleOp);
+    // Append all vm.import ops from used dialects so that we can look them up
+    // during conversion.
+    auto usedDialects =
+        gatherUsedDialectInterfaces<VMConversionDialectInterface>(
+            innerModuleOp);
+    for (auto *dialectInterface : usedDialects) {
+      auto outerImportModuleOp = dialectInterface->getVMImportModule();
+      for (auto importModuleOp :
+           outerImportModuleOp->getOps<IREE::VM::ModuleOp>()) {
+        if (failed(appendImportModule(importModuleOp, innerModuleOp))) {
+          importModuleOp.emitError() << "failed to import module";
+          return signalPassFailure();
+        }
+      }
+    }
 
     OwningRewritePatternList conversionPatterns;
     populateStandardToVMPatterns(context, conversionPatterns);
 
-    // TODO(benvanik): registration system for custom dialects.
+    // Populate patterns from all used dialects, providing the imports they
+    // registered earlier.
     SymbolTable importSymbols(innerModuleOp);
-    populateHALToVMPatterns(context, importSymbols, conversionPatterns,
-                            typeConverter);
+    for (auto *dialectInterface : usedDialects) {
+      dialectInterface->populateVMConversionPatterns(
+          importSymbols, conversionPatterns, typeConverter);
+    }
 
     if (failed(applyFullConversion(outerModuleOp, conversionTarget,
                                    conversionPatterns, &typeConverter))) {