Adding a dialect interface to allow decoupling of dialect->VM conversion. This allows any dialect (such as custom ones) to opt into conversion during the big VM conversion pass. PiperOrigin-RevId: 285014909
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD index 8c23b95..7e8db9a 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD
@@ -33,6 +33,7 @@ "//iree/compiler/Dialect", "//iree/compiler/Dialect/Flow/IR", "//iree/compiler/Dialect/HAL/IR", + "//iree/compiler/Dialect/HAL/IR:HALDialect", "//iree/compiler/Dialect/HAL/Target:ExecutableTarget", "//iree/compiler/Dialect/HAL/Utils", "//iree/compiler/Dialect/VM/IR",
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp index aef21a0..da8dca1 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
@@ -35,12 +35,6 @@ namespace mlir { namespace iree_compiler { -LogicalResult appendHALImportModule(mlir::ModuleOp moduleOp) { - return appendImportModule( - StringRef(hal_imports_create()->data, hal_imports_create()->size), - moduleOp); -} - extern void populateHALAllocatorToVMPatterns( MLIRContext *context, SymbolTable &importSymbols, TypeConverter &typeConverter, OwningRewritePatternList &patterns); @@ -100,7 +94,9 @@ std::tie(outerModuleOp, innerModuleOp) = VMConversionTarget::nestModuleForConversion(getModule()); - appendHALImportModule(innerModuleOp); + appendImportModule( + StringRef(hal_imports_create()->data, hal_imports_create()->size), + innerModuleOp); OwningRewritePatternList conversionPatterns; populateStandardToVMPatterns(context, conversionPatterns);
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h index e4ff7ed..9d19120 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h
@@ -22,10 +22,6 @@ namespace mlir { namespace iree_compiler { -// Appends the HAL import module containing the vm.import ops for all HAL -// methods. -LogicalResult appendHALImportModule(mlir::ModuleOp moduleOp); - // Populates conversion patterns from the HAL dialect to the VM dialect. void populateHALToVMPatterns(MLIRContext *context, SymbolTable &importSymbols, OwningRewritePatternList &patterns,
diff --git a/iree/compiler/Dialect/HAL/IR/BUILD b/iree/compiler/Dialect/HAL/IR/BUILD index 9009164..0cb74f5 100644 --- a/iree/compiler/Dialect/HAL/IR/BUILD +++ b/iree/compiler/Dialect/HAL/IR/BUILD
@@ -29,22 +29,22 @@ cc_library( name = "IR", srcs = [ - "HALDialect.cpp", "HALEnums.cpp.inc", "HALOpFolders.cpp", "HALOpInterface.cpp.inc", "HALOps.cpp", - "HALOps.cpp.inc", "HALTypes.cpp", ], hdrs = [ - "HALDialect.h", "HALEnums.h.inc", "HALOpInterface.h.inc", "HALOps.h", "HALOps.h.inc", "HALTypes.h", ], + textual_hdrs = [ + "HALOps.cpp.inc", + ], deps = [ ":HALEnumsGen", ":HALOpInterfaceGen", @@ -52,6 +52,26 @@ "//iree/compiler/Dialect", "@llvm//:support", "@local_config_mlir//:IR", + "@local_config_mlir//:Parser", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + ], +) + +cc_library( + name = "HALDialect", + srcs = ["HALDialect.cpp"], + hdrs = ["HALDialect.h"], + deps = [ + ":IR", + "//iree/compiler/Dialect", + "//iree/compiler/Dialect/HAL:hal_imports", + "//iree/compiler/Dialect/HAL/Conversion/HALToVM", + "//iree/compiler/Dialect/VM/Conversion", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Parser", "@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", "@local_config_mlir//:TransformUtils",
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp index 7408d54..f534e7a 100644 --- a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -14,23 +14,49 @@ #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/HAL/hal.imports.h" +#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" #include "llvm/Support/SourceMgr.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Parser.h" namespace mlir { namespace iree_compiler { namespace IREE { namespace HAL { -#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.cpp.inc" +namespace { static DialectRegistration<HALDialect> hal_dialect; +class HALToVMConversionInterface : public VMConversionDialectInterface { + public: + using VMConversionDialectInterface::VMConversionDialectInterface; + + OwningModuleRef getVMImportModule() const override { + return mlir::parseSourceString( + StringRef(hal_imports_create()->data, hal_imports_create()->size), + getDialect()->getContext()); + } + + void populateVMConversionPatterns( + SymbolTable &importSymbols, OwningRewritePatternList &patterns, + TypeConverter &typeConverter) const override { + populateHALToVMPatterns(getDialect()->getContext(), importSymbols, patterns, + typeConverter); + } +}; + +} // namespace + HALDialect::HALDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { + addInterfaces<HALToVMConversionInterface>(); + addTypes<AllocatorType, BufferType, CommandBufferType, DescriptorSetType, DescriptorSetLayoutType, DeviceType, EventType, ExecutableType, ExecutableCacheType, FenceType, RingBufferType, SemaphoreType>();
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.h b/iree/compiler/Dialect/HAL/IR/HALDialect.h index dc30a49..3f67f42 100644 --- a/iree/compiler/Dialect/HAL/IR/HALDialect.h +++ b/iree/compiler/Dialect/HAL/IR/HALDialect.h
@@ -23,8 +23,6 @@ namespace IREE { namespace HAL { -#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.h.inc" - class HALDialect : public Dialect { public: explicit HALDialect(MLIRContext *context);
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index b899740..c9f36e5 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/StandardOps/Ops.h"
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.h b/iree/compiler/Dialect/HAL/IR/HALOps.h index a669c5b..afca9b0 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.h +++ b/iree/compiler/Dialect/HAL/IR/HALOps.h
@@ -17,7 +17,6 @@ #include <cstdint> -#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Dialect/Traits.h" #include "mlir/IR/Attributes.h"
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index def9da6..70aed7f 100644 --- a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -24,7 +24,7 @@ namespace IREE { namespace HAL { -// TODO(benvanik): struct types. +#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.cpp.inc" } // namespace HAL } // namespace IREE
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h index a6139b8..3fec3d3 100644 --- a/iree/compiler/Dialect/HAL/IR/HALTypes.h +++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -23,6 +23,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSwitch.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" @@ -36,6 +37,8 @@ namespace IREE { namespace HAL { +#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.h.inc" + //===----------------------------------------------------------------------===// // RefObject types //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VM/Conversion/BUILD b/iree/compiler/Dialect/VM/Conversion/BUILD index 1abc463..12762a0 100644 --- a/iree/compiler/Dialect/VM/Conversion/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/BUILD
@@ -25,6 +25,7 @@ "TypeConverter.cpp", ], hdrs = [ + "ConversionDialectInterface.h", "ConversionTarget.h", "ImportUtils.h", "TypeConverter.h",
diff --git a/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h b/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h new file mode 100644 index 0000000..537b228 --- /dev/null +++ b/iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h
@@ -0,0 +1,53 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_CONVERSIONDIALECTINTERFACE_H_ +#define IREE_COMPILER_DIALECT_VM_CONVERSION_CONVERSIONDIALECTINTERFACE_H_ + +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/Module.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +// An interface for dialects to expose VM conversion functionality. +// The VM conversion pass will query used dialects via this interface to find +// import definitions and conversion patterns that map from the source dialect +// to the VM dialect. +class VMConversionDialectInterface + : public DialectInterface::Base<VMConversionDialectInterface> { + public: + VMConversionDialectInterface(Dialect *dialect) : Base(dialect) {} + + // Returns a module containing one or more vm.modules with vm.import ops. + // These modules will be merged into the module being compiled to provide + // import definitions to the conversion and lowering process. + virtual OwningModuleRef getVMImportModule() const = 0; + + // Populates |patterns| with rewrites that convert from the implementation + // dialect to the VM dialect. Many of these can just be default conversions + // via the VMImportOpConversion class. + // + // |importSymbols| contains all vm.imports that have been queried from all + // used dialects, not just this dialect. + virtual void populateVMConversionPatterns( + SymbolTable &importSymbols, OwningRewritePatternList &patterns, + TypeConverter &typeConverter) const = 0; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_VM_CONVERSION_CONVERSIONDIALECTINTERFACE_H_
diff --git a/iree/compiler/Dialect/VM/Transforms/BUILD b/iree/compiler/Dialect/VM/Transforms/BUILD index c635126..e76d67c 100644 --- a/iree/compiler/Dialect/VM/Transforms/BUILD +++ b/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -28,7 +28,6 @@ "Passes.h", ], deps = [ - "//iree/compiler/Dialect/HAL/Conversion/HALToVM", "//iree/compiler/Dialect/VM/Conversion", "//iree/compiler/Dialect/VM/Conversion/StandardToVM", "//iree/compiler/Dialect/VM/IR",
diff --git a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp index 0ed0b71..14cc7bd 100644 --- a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp +++ b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -14,10 +14,12 @@ #include <tuple> -#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h" +#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" #include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h" +#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" #include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h" #include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h" +#include "llvm/ADT/STLExtras.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -29,7 +31,29 @@ namespace IREE { namespace VM { -// TODO(benvanik): import dialect registration. +// Returns a stably sorted list of dialect interfaces of T for all dialects used +// within the given module. +template <typename T> +SmallVector<const T *, 4> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) { + SmallPtrSet<const T *, 4> resultSet; + moduleOp.walk([&](Operation *op) { + auto *dialect = op->getDialect(); + if (!dialect) return; + auto *dialectInterface = dialect->getRegisteredInterface<T>(); + if (!dialectInterface) return; + resultSet.insert(dialectInterface); + }); + + // NOTE: to ensure deterministic output we sort the result so that imports are + // always added in a consistent order. + SmallVector<const T *, 4> results = {resultSet.begin(), resultSet.end()}; + llvm::sort( + results, +[](const T *a, const T *b) { + return a->getDialect()->getNamespace().compare( + b->getDialect()->getNamespace()) < 0; + }); + return results; +} // Runs conversion with registered input dialects. class ConversionPass : public OperationPass<ConversionPass, mlir::ModuleOp> { @@ -43,16 +67,32 @@ std::tie(outerModuleOp, innerModuleOp) = VMConversionTarget::nestModuleForConversion(getOperation()); - // TODO(benvanik): registration system for custom dialects. - appendHALImportModule(innerModuleOp); + // Append all vm.import ops from used dialects so that we can look them up + // during conversion. + auto usedDialects = + gatherUsedDialectInterfaces<VMConversionDialectInterface>( + innerModuleOp); + for (auto *dialectInterface : usedDialects) { + auto outerImportModuleOp = dialectInterface->getVMImportModule(); + for (auto importModuleOp : + outerImportModuleOp->getOps<IREE::VM::ModuleOp>()) { + if (failed(appendImportModule(importModuleOp, innerModuleOp))) { + importModuleOp.emitError() << "failed to import module"; + return signalPassFailure(); + } + } + } OwningRewritePatternList conversionPatterns; populateStandardToVMPatterns(context, conversionPatterns); - // TODO(benvanik): registration system for custom dialects. + // Populate patterns from all used dialects, providing the imports they + // registered earlier. SymbolTable importSymbols(innerModuleOp); - populateHALToVMPatterns(context, importSymbols, conversionPatterns, - typeConverter); + for (auto *dialectInterface : usedDialects) { + dialectInterface->populateVMConversionPatterns( + importSymbols, conversionPatterns, typeConverter); + } if (failed(applyFullConversion(outerModuleOp, conversionTarget, conversionPatterns, &typeConverter))) {