Adding hal_inline dialect and runtime module. This lowers from the stream dialect into a much reduced form of the HAL dialect that uses a compatible type system with the HAL dialect but a restricted synchronous/local execution model. Executables translated to `vmvx-inline` are inlined directly into the host module and the only thing remaining is `!hal.buffer`/`!hal.buffer_view` management for ABI compatibility with the full HAL dialect. The tradeoff here with the full HAL dialect is that this only runs in-process and synchronously on the VM (bytecode or emitc) and is not relevant to CUDA/multithreaded CPU/etc. For a single-core embedded device with VMVX kernels, though, it should be more than enough to run all models.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index d06a3b5..4230e75 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -362,6 +362,10 @@ OpPassManager &nestedModulePM = passManager.nest<ModuleOp>(); addBufferizePasses(nestedModulePM); + // Cleanup the IR that may now have unused loops. + nestedModulePM.addNestedPass<func::FuncOp>( + createRemoveSingleIterationLoopPass()); + // Convert buffer-level microkernels. if (clEnableMicrokernels) { nestedModulePM.addNestedPass<func::FuncOp>(
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/BUILD new file mode 100644 index 0000000..cadf491 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/BUILD
@@ -0,0 +1,11 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/CMakeLists.txt new file mode 100644 index 0000000..00e2756 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/CMakeLists.txt
@@ -0,0 +1,13 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/BUILD new file mode 100644 index 0000000..63eef08 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/BUILD
@@ -0,0 +1,22 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/embed_data:build_defs.bzl", "c_embed_data") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +c_embed_data( + name = "hal_inline_imports", + srcs = ["hal_inline.imports.mlir"], + c_file_output = "hal_inline.imports.c", + flatten = True, + h_file_output = "hal_inline.imports.h", + identifier = "iree_hal_inline_imports", +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/CMakeLists.txt new file mode 100644 index 0000000..adaedc2 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/CMakeLists.txt
@@ -0,0 +1,28 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_c_embed_data( + NAME + hal_inline_imports + SRCS + "hal_inline.imports.mlir" + C_FILE_OUTPUT + "hal_inline.imports.c" + H_FILE_OUTPUT + "hal_inline.imports.h" + IDENTIFIER + "iree_hal_inline_imports" + FLATTEN + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/BUILD new file mode 100644 index 0000000..cadf491 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/BUILD
@@ -0,0 +1,11 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/CMakeLists.txt new file mode 100644 index 0000000..9f68086 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/CMakeLists.txt
@@ -0,0 +1,13 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/BUILD new file mode 100644 index 0000000..9cfed7d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/BUILD
@@ -0,0 +1,35 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "HALInlineToVM", + srcs = [ + "ConvertHALInlineToVM.cpp", + ], + hdrs = [ + "ConvertHALInlineToVM.h", + ], + deps = [ + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR", + "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Dialect/VM/Conversion", + "//compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM", + "//compiler/src/iree/compiler/Dialect/VM/IR", + "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/CMakeLists.txt new file mode 100644 index 0000000..681218c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/CMakeLists.txt
@@ -0,0 +1,34 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/BUILD# +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + HALInlineToVM + HDRS + "ConvertHALInlineToVM.h" + SRCS + "ConvertHALInlineToVM.cpp" + DEPS + MLIRArithmeticDialect + MLIRFuncDialect + MLIRIR + MLIRPass + MLIRTransforms + iree::compiler::Dialect::Modules::HAL::Inline::IR + iree::compiler::Dialect::Util::IR + iree::compiler::Dialect::VM::Conversion + iree::compiler::Dialect::VM::Conversion::StandardToVM + iree::compiler::Dialect::VM::IR + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/ConvertHALInlineToVM.cpp b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/ConvertHALInlineToVM.cpp new file mode 100644 index 0000000..095398e --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/ConvertHALInlineToVM.cpp
@@ -0,0 +1,71 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/ConvertHALInlineToVM.h" + +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h" +#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" +#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h" +#include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +void populateHALInlineToVMPatterns(MLIRContext *context, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + SymbolTable &importSymbols, + RewritePatternSet &patterns) { + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferAllocateOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer.allocate"); + patterns.insert< + VMImportOpConversion<IREE::HAL::Inline::BufferAllocateInitializedOp>>( + context, importSymbols, typeConverter, + "hal_inline.buffer.allocate.initialized"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferWrapOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer.wrap"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferSubspanOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer.subspan"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferLengthOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer.length"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferStorageOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer.storage"); + + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferViewCreateOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer_view.create"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferViewAssertOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer_view.assert"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferViewBufferOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer_view.buffer"); + patterns + .insert<VMImportOpConversion<IREE::HAL::Inline::BufferViewElementTypeOp>>( + context, importSymbols, typeConverter, + "hal_inline.buffer_view.element_type"); + patterns.insert< + VMImportOpConversion<IREE::HAL::Inline::BufferViewEncodingTypeOp>>( + context, importSymbols, typeConverter, + "hal_inline.buffer_view.encoding_type"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferViewRankOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer_view.rank"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferViewDimOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer_view.dim"); + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::BufferViewTraceOp>>( + context, importSymbols, typeConverter, "hal_inline.buffer_view.trace"); + + patterns.insert<VMImportOpConversion<IREE::HAL::Inline::DeviceQueryOp>>( + context, importSymbols, typeConverter, "hal_inline.device.query.i64"); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/ConvertHALInlineToVM.h b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/ConvertHALInlineToVM.h new file mode 100644 index 0000000..5c155c4 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/ConvertHALInlineToVM.h
@@ -0,0 +1,27 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_HALINLINE_CONVERTHALINLINETOVM_H_ +#define IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_HALINLINE_CONVERTHALINLINETOVM_H_ + +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +// Populates conversion patterns from the hal_inline dialect to the VM dialect. +void populateHALInlineToVMPatterns(MLIRContext *context, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + SymbolTable &importSymbols, + RewritePatternSet &patterns); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_HALINLINE_CONVERTHALINLINETOVM_H_
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD new file mode 100644 index 0000000..b027b7c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD
@@ -0,0 +1,28 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/test/CMakeLists.txt new file mode 100644 index 0000000..53c3829 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/test/CMakeLists.txt
@@ -0,0 +1,21 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD# +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD new file mode 100644 index 0000000..6791d99 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD
@@ -0,0 +1,39 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "HALToHALInline", + srcs = [ + "ConvertHALToHALInline.cpp", + ], + hdrs = [ + "ConvertHALToHALInline.h", + ], + deps = [ + "//compiler/src/iree/compiler/Dialect/HAL/Conversion", + "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", + "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/HAL/Utils", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR:HALInlineDialect", + "//compiler/src/iree/compiler/Dialect/Util/IR", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt new file mode 100644 index 0000000..f8d8546 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt
@@ -0,0 +1,38 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD# +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + HALToHALInline + HDRS + "ConvertHALToHALInline.h" + SRCS + "ConvertHALToHALInline.cpp" + DEPS + LLVMSupport + MLIRArithmeticDialect + MLIRFuncDialect + MLIRIR + MLIRPass + MLIRTransforms + iree::compiler::Dialect::HAL::Conversion + iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::HAL::IR::HALDialect + iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::HAL::Utils + iree::compiler::Dialect::Modules::HAL::Inline::IR + iree::compiler::Dialect::Modules::HAL::Inline::IR::HALInlineDialect + iree::compiler::Dialect::Util::IR + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.cpp b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.cpp new file mode 100644 index 0000000..4c53ae9 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.cpp
@@ -0,0 +1,227 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.h" + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct BufferSubspanOpPattern + : public OpConversionPattern<IREE::HAL::BufferSubspanOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferSubspanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto bufferType = getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferSubspanOp>( + op, bufferType, adaptor.getSourceBuffer(), adaptor.getSourceOffset(), + adaptor.getLength()); + return success(); + } +}; + +struct BufferLengthOpPattern + : public OpConversionPattern<IREE::HAL::BufferLengthOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferLengthOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sizeType = getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferLengthOp>( + op, sizeType, adaptor.getBuffer()); + return success(); + } +}; + +struct BufferLoadOpPattern + : public OpConversionPattern<IREE::HAL::BufferLoadOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value storageBuffer = + rewriter.createOrFold<IREE::HAL::Inline::BufferStorageOp>( + op.getLoc(), adaptor.getSourceBuffer()); + Value storageSize = rewriter.create<IREE::HAL::Inline::BufferLengthOp>( + op.getLoc(), adaptor.getSourceBuffer()); + auto loadType = getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>( + op, loadType, storageBuffer, storageSize, adaptor.getSourceOffset()); + return success(); + } +}; + +struct BufferStoreOpPattern + : public OpConversionPattern<IREE::HAL::BufferStoreOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value storageBuffer = + rewriter.createOrFold<IREE::HAL::Inline::BufferStorageOp>( + op.getLoc(), adaptor.getTargetBuffer()); + Value storageSize = rewriter.create<IREE::HAL::Inline::BufferLengthOp>( + op.getLoc(), adaptor.getTargetBuffer()); + rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>( + op, adaptor.getValue(), storageBuffer, storageSize, + adaptor.getTargetOffset()); + return success(); + } +}; + +struct BufferViewCreateOpPattern + : public OpConversionPattern<IREE::HAL::BufferViewCreateOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferViewCreateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewCreateOp>( + op, adaptor.getBuffer(), adaptor.getElementType(), + adaptor.getEncodingType(), adaptor.getShape()); + return success(); + } +}; + +struct BufferViewBufferOpPattern + : public OpConversionPattern<IREE::HAL::BufferViewBufferOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferViewBufferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewBufferOp>( + op, rewriter.getType<IREE::HAL::BufferType>(), adaptor.getBufferView()); + return success(); + } +}; + +struct BufferViewAssertOpPattern + : public OpConversionPattern<IREE::HAL::BufferViewAssertOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferViewAssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewAssertOp>( + op, adaptor.getBufferView(), adaptor.getMessage(), + adaptor.getElementType(), adaptor.getEncodingType(), + adaptor.getShape()); + return success(); + } +}; + +struct BufferViewElementTypeOpPattern + : public OpConversionPattern<IREE::HAL::BufferViewElementTypeOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferViewElementTypeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewElementTypeOp>( + op, op.getResult().getType(), adaptor.getBufferView()); + return success(); + } +}; + +struct BufferViewEncodingTypeOpPattern + : public OpConversionPattern<IREE::HAL::BufferViewEncodingTypeOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferViewEncodingTypeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewEncodingTypeOp>( + op, op.getResult().getType(), adaptor.getBufferView()); + return success(); + } +}; + +struct BufferViewRankOpPattern + : public OpConversionPattern<IREE::HAL::BufferViewRankOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferViewRankOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewRankOp>( + op, op.getResult().getType(), adaptor.getBufferView()); + return success(); + } +}; + +struct BufferViewDimOpPattern + : public OpConversionPattern<IREE::HAL::BufferViewDimOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferViewDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewDimOp>( + op, op.getResult().getType(), adaptor.getBufferView(), + adaptor.getIndexAttr()); + return success(); + } +}; + +struct BufferViewTraceOpPattern + : public OpConversionPattern<IREE::HAL::BufferViewTraceOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::HAL::BufferViewTraceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewTraceOp>( + op, adaptor.getKeyAttr(), adaptor.getOperands()); + return success(); + } +}; + +} // namespace + +void populateHALToHALInlinePatterns(MLIRContext *context, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + RewritePatternSet &patterns) { + typeConverter.addConversion([](IREE::HAL::BufferType type) { return type; }); + typeConverter.addConversion( + [](IREE::HAL::BufferViewType type) { return type; }); + + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IREE::Util::BufferType type, ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1); + if (inputs[0].getType().isa<IREE::HAL::BufferType>()) { + return builder.createOrFold<IREE::HAL::Inline::BufferStorageOp>( + loc, inputs[0]); + } else { + emitError(loc) << "unsupported HAL inline target materialization: " + << inputs[0].getType(); + return nullptr; + } + }); + + patterns.insert<BufferSubspanOpPattern>(typeConverter, context); + patterns.insert<BufferLengthOpPattern>(typeConverter, context); + patterns.insert<BufferLoadOpPattern>(typeConverter, context); + patterns.insert<BufferStoreOpPattern>(typeConverter, context); + + patterns.insert<BufferViewCreateOpPattern>(typeConverter, context); + patterns.insert<BufferViewAssertOpPattern>(typeConverter, context); + patterns.insert<BufferViewBufferOpPattern>(typeConverter, context); + patterns.insert<BufferViewElementTypeOpPattern>(typeConverter, context); + patterns.insert<BufferViewEncodingTypeOpPattern>(typeConverter, context); + patterns.insert<BufferViewRankOpPattern>(typeConverter, context); + patterns.insert<BufferViewDimOpPattern>(typeConverter, context); + patterns.insert<BufferViewTraceOpPattern>(typeConverter, context); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.h b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.h new file mode 100644 index 0000000..19c402f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.h
@@ -0,0 +1,25 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_HALTOHALINLINE_CONVERTHALTOHALINLINE_H_ +#define IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_HALTOHALINLINE_CONVERTHALTOHALINLINE_H_ + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +// Populates conversion patterns for full HAL -> inline HAL. +void populateHALToHALInlinePatterns(MLIRContext *context, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_HALTOHALINLINE_CONVERTHALTOHALINLINE_H_
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD new file mode 100644 index 0000000..c3b3094 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD
@@ -0,0 +1,29 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "buffer_ops.mlir", + "buffer_view_ops.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/CMakeLists.txt new file mode 100644 index 0000000..fa35f20 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/CMakeLists.txt
@@ -0,0 +1,24 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD# +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "buffer_ops.mlir" + "buffer_view_ops.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_ops.mlir new file mode 100644 index 0000000..b6ce93a --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_ops.mlir
@@ -0,0 +1,53 @@ +// RUN: iree-opt --split-input-file --iree-hal-inline-conversion %s | FileCheck %s + +// CHECK-LABEL: @buffer_subspan +// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer) +func.func @buffer_subspan(%buffer: !hal.buffer) -> !hal.buffer { + // CHECK-DAG: %[[OFFSET:.+]] = arith.constant 100 + %offset = arith.constant 100 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 200 + %length = arith.constant 200 : index + // CHECK: %[[SUBSPAN:.+]] = hal_inline.buffer.subspan<%[[BUFFER]] : !hal.buffer>[%[[OFFSET]], %[[LENGTH]]] : !hal.buffer + %subspan = hal.buffer.subspan<%buffer : !hal.buffer>[%offset, %length] : !hal.buffer + // CHECK: return %[[SUBSPAN]] + return %subspan : !hal.buffer +} + +// ----- + +// CHECK-LABEL: @buffer_length +// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer) +func.func @buffer_length(%buffer: !hal.buffer) -> index { + // CHECK: hal_inline.buffer.length<%[[BUFFER]] : !hal.buffer> : index + %length = hal.buffer.length<%buffer : !hal.buffer> : index + return %length : index +} + +// ----- + +// CHECK-LABEL: @buffer_load +// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer) +func.func @buffer_load(%buffer: !hal.buffer) -> i32 { + // CHECK-DAG: %[[REL_OFFSET:.+]] = arith.constant 100 + %rel_offset = arith.constant 100 : index + // CHECK-DAG: %[[STORAGE:.+]] = hal_inline.buffer.storage<%[[BUFFER:.+]] : !hal.buffer> : !util.buffer + // CHECK-DAG: %[[LENGTH:.+]] = hal_inline.buffer.length<%[[BUFFER]] : !hal.buffer> : index + // CHECK: %[[VALUE:.+]] = util.buffer.load %[[STORAGE]][%[[REL_OFFSET]]] : !util.buffer{%[[LENGTH]]} -> i32 + %value = hal.buffer.load<%buffer : !hal.buffer>[%rel_offset] : i32 + // CHECK-NEXT: return %[[VALUE]] + return %value : i32 +} + +// ----- + +// CHECK-LABEL: @buffer_store +// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer, %[[VALUE:.+]]: i32) +func.func @buffer_store(%buffer: !hal.buffer, %value: i32) { + // CHECK-DAG: %[[REL_OFFSET:.+]] = arith.constant 100 + %rel_offset = arith.constant 100 : index + // CHECK-DAG: %[[STORAGE:.+]] = hal_inline.buffer.storage<%[[BUFFER:.+]] : !hal.buffer> : !util.buffer + // CHECK-DAG: %[[LENGTH:.+]] = hal_inline.buffer.length<%[[BUFFER]] : !hal.buffer> : index + // CHECK: util.buffer.store %[[VALUE]], %[[STORAGE]][%[[REL_OFFSET]]] : i32 -> !util.buffer{%[[LENGTH]]} + hal.buffer.store<%buffer : !hal.buffer>[%rel_offset] value(%value : i32) + return +}
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_view_ops.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_view_ops.mlir new file mode 100644 index 0000000..c65b1c0 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_view_ops.mlir
@@ -0,0 +1,37 @@ +// RUN: iree-opt --split-input-file --iree-hal-inline-conversion %s | FileCheck %s + +// CHECK-LABEL: @buffer_view_create +func.func @buffer_view_create(%arg0: !hal.buffer, %arg1: index, %arg2: index) -> !hal.buffer_view { + %c1 = arith.constant 1 : i32 + %c32 = arith.constant 32 : i32 + // CHECK: %view = hal_inline.buffer_view.create + // CHECK-SAME: buffer(%arg0 : !hal.buffer) + // CHECK-SAME: shape([%arg1, %arg2]) + // CHECK-SAME: type(%c32_i32) + // CHECK-SAME: encoding(%c1_i32) : !hal.buffer_view + %view = hal.buffer_view.create buffer(%arg0 : !hal.buffer) + shape([%arg1, %arg2]) + type(%c32) + encoding(%c1) : !hal.buffer_view + return %view : !hal.buffer_view +} + +// ----- + +// CHECK-LABEL: @buffer_view_buffer +func.func @buffer_view_buffer(%arg0: !hal.buffer_view) -> !hal.buffer { + // CHECK: %buffer = hal_inline.buffer_view.buffer<%arg0 : !hal.buffer_view> : !hal.buffer + %buffer = hal.buffer_view.buffer<%arg0 : !hal.buffer_view> : !hal.buffer + return %buffer : !hal.buffer +} + +// ----- + +// CHECK-LABEL: @buffer_view_shape_queries +func.func @buffer_view_shape_queries(%arg0: !hal.buffer_view) -> (index, index) { + // CHECK: %{{.+}} = hal_inline.buffer_view.rank<%arg0 : !hal.buffer_view> : index + %0 = hal.buffer_view.rank<%arg0 : !hal.buffer_view> : index + // CHECK: %{{.+}} = hal_inline.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index + %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index + return %0, %1 : index, index +}
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD new file mode 100644 index 0000000..80498a2 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD
@@ -0,0 +1,40 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "StreamToHALInline", + srcs = [ + "ConvertStreamToHALInline.cpp", + ], + hdrs = [ + "ConvertStreamToHALInline.h", + ], + deps = [ + "//compiler/src/iree/compiler/Dialect/HAL/Conversion", + "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", + "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/HAL/Utils", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR:HALInlineDialect", + "//compiler/src/iree/compiler/Dialect/Stream/IR", + "//compiler/src/iree/compiler/Dialect/Util/IR", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt new file mode 100644 index 0000000..b688989 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt
@@ -0,0 +1,39 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD# +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + StreamToHALInline + HDRS + "ConvertStreamToHALInline.h" + SRCS + "ConvertStreamToHALInline.cpp" + DEPS + LLVMSupport + MLIRArithmeticDialect + MLIRFuncDialect + MLIRIR + MLIRPass + MLIRTransforms + iree::compiler::Dialect::HAL::Conversion + iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::HAL::IR::HALDialect + iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::HAL::Utils + iree::compiler::Dialect::Modules::HAL::Inline::IR + iree::compiler::Dialect::Modules::HAL::Inline::IR::HALInlineDialect + iree::compiler::Dialect::Stream::IR + iree::compiler::Dialect::Util::IR + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.cpp b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.cpp new file mode 100644 index 0000000..a073c7d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.cpp
@@ -0,0 +1,628 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.h" + +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +static Value getResourceSize(Location loc, Value resource, OpBuilder &builder) { + if (resource.getType().isa<IREE::HAL::BufferType>()) { + return builder.createOrFold<IREE::HAL::Inline::BufferLengthOp>( + loc, builder.getIndexType(), resource); + } + return builder.createOrFold<IREE::Util::BufferSizeOp>( + loc, builder.getIndexType(), resource); +} + +struct Storage { + // Underlying storage buffer. + Value buffer; + // Total size of the storage buffer in bytes. + Value bufferSize; +}; + +static Storage getResourceStorage(Location loc, Value resource, + Value resourceSize, OpBuilder &builder) { + if (resource.getType().isa<IREE::HAL::BufferType>()) { + // Get the storage of the buffer; the returned buffer is already a subspan. + auto storageBuffer = + builder.createOrFold<IREE::HAL::Inline::BufferStorageOp>(loc, resource); + auto storageSize = getResourceSize(loc, resource, builder); + return { + storageBuffer, + storageSize, + }; + } + return { + resource, + resourceSize, + }; +} + +struct ResourceAllocOpPattern + : public OpConversionPattern<IREE::Stream::ResourceAllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceAllocOp allocOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto deviceBufferType = rewriter.getType<IREE::HAL::BufferType>(); + auto hostBufferType = rewriter.getType<IREE::Util::BufferType>(); + + // For now we don't have this information and assume something conservative. + Value minAlignment = + rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(), 64); + + SmallVector<Value> results; + for (auto [resourceResult, storageSize] : + llvm::zip(allocOp.getResults(), allocOp.getStorageSizes())) { + auto allocateOp = rewriter.create<IREE::HAL::Inline::BufferAllocateOp>( + allocOp.getLoc(), deviceBufferType, hostBufferType, minAlignment, + storageSize); + results.push_back(allocateOp.getResult()); + } + + rewriter.replaceOp(allocOp, results); + return success(); + } +}; + +struct ResourceAllocaOpPattern + : public OpConversionPattern<IREE::Stream::ResourceAllocaOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceAllocaOp allocaOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto deviceBufferType = rewriter.getType<IREE::HAL::BufferType>(); + auto hostBufferType = rewriter.getType<IREE::Util::BufferType>(); + + // For now we don't have this information and assume something conservative. + Value minAlignment = + rewriter.create<arith::ConstantIndexOp>(allocaOp.getLoc(), 64); + auto allocateOp = rewriter.create<IREE::HAL::Inline::BufferAllocateOp>( + allocaOp.getLoc(), deviceBufferType, hostBufferType, minAlignment, + adaptor.getStorageSize()); + + auto resolvedTimepoint = + rewriter.create<arith::ConstantIntOp>(allocaOp.getLoc(), 0, 64) + .getResult(); + + rewriter.replaceOp(allocaOp, {allocateOp.getResult(), resolvedTimepoint}); + return success(); + } +}; + +struct ResourceDeallocaOpPattern + : public OpConversionPattern<IREE::Stream::ResourceDeallocaOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceDeallocaOp deallocaOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO(benvanik): discard op? + auto resolvedTimepoint = + rewriter.create<arith::ConstantIntOp>(deallocaOp.getLoc(), 0, 64) + .getResult(); + rewriter.replaceOp(deallocaOp, {resolvedTimepoint}); + return success(); + } +}; + +struct ResourceSizeOpPattern + : public OpConversionPattern<IREE::Stream::ResourceSizeOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceSizeOp sizeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(sizeOp, getResourceSize(sizeOp.getLoc(), + adaptor.getOperand(), rewriter)); + return success(); + } +}; + +// The staging buffer returned from this is always a !util.buffer. +// We can thus directly pass along the input buffer that's being mapped +// (after taking a subspan for the defined range). +struct ResourceMapOpPattern + : public OpConversionPattern<IREE::Stream::ResourceMapOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceMapOp mapOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::Util::BufferSubspanOp>( + mapOp, adaptor.getSource(), + getResourceSize(mapOp.getLoc(), adaptor.getSource(), rewriter), + adaptor.getSourceOffset(), adaptor.getResultSize()); + return success(); + } +}; + +// The constant buffer returned from this is always a !util.buffer. +// We can thus directly pass along the input buffer that's being mapped +// (after taking a subspan for the defined range). +struct ResourceTryMapOpPattern + : public OpConversionPattern<IREE::Stream::ResourceTryMapOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceTryMapOp tryMapOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value subspan = rewriter.create<IREE::Util::BufferSubspanOp>( + tryMapOp.getLoc(), adaptor.getSource(), + getResourceSize(tryMapOp.getLoc(), adaptor.getSource(), rewriter), + adaptor.getSourceOffset(), adaptor.getResultSize()); + Value didMap = + rewriter.create<arith::ConstantIntOp>(tryMapOp.getLoc(), 1, 1); + rewriter.replaceOp(tryMapOp, {didMap, subspan}); + return success(); + } +}; + +struct ResourceLoadOpPattern + : public OpConversionPattern<IREE::Stream::ResourceLoadOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceLoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = loadOp.getLoc(); + auto storage = getResourceStorage(loc, adaptor.getSource(), + adaptor.getSourceSize(), rewriter); + auto loadType = + getTypeConverter()->convertType(loadOp.getResult().getType()); + rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>( + loadOp, loadType, storage.buffer, storage.bufferSize, + adaptor.getSourceOffset()); + return success(); + } +}; + +struct ResourceStoreOpPattern + : public OpConversionPattern<IREE::Stream::ResourceStoreOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceStoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + auto storage = getResourceStorage(loc, adaptor.getTarget(), + adaptor.getTargetSize(), rewriter); + rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>( + storeOp, adaptor.getValue(), storage.buffer, storage.bufferSize, + adaptor.getTargetOffset()); + return success(); + } +}; + +struct ResourceSubviewOpPattern + : public OpConversionPattern<IREE::Stream::ResourceSubviewOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::ResourceSubviewOp subviewOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (adaptor.getSource().getType().isa<IREE::HAL::BufferType>()) { + auto bufferType = rewriter.getType<IREE::HAL::BufferType>(); + // NOTE: this aliases! We assume at this point all useful alias analysis + // has been performed and it's fine to lose the tie information here. + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferSubspanOp>( + subviewOp, bufferType, adaptor.getSource(), adaptor.getSourceOffset(), + adaptor.getResultSize()); + } else { + rewriter.replaceOpWithNewOp<IREE::Util::BufferSubspanOp>( + subviewOp, adaptor.getSource(), adaptor.getSourceSize(), + adaptor.getSourceOffset(), adaptor.getResultSize()); + } + return success(); + } +}; + +struct TensorImportBufferOpPattern + : public OpConversionPattern<IREE::Stream::TensorImportOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TensorImportOp importOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!importOp.getSource().getType().isa<IREE::HAL::BufferType>()) { + return failure(); + } + + // Directly use the buffer. + auto buffer = adaptor.getSource(); + rewriter.replaceOp(importOp, buffer); + return success(); + } +}; + +struct TensorImportBufferViewOpPattern + : public OpConversionPattern<IREE::Stream::TensorImportOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TensorImportOp importOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = importOp.getSource().getType(); + if (!sourceType.isa<IREE::HAL::BufferViewType>() && + !sourceType.isa<TensorType>()) { + return failure(); + } + + auto bufferView = adaptor.getSource(); + auto bufferType = rewriter.getType<IREE::HAL::BufferType>(); + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewBufferOp>( + importOp, bufferType, bufferView); + return success(); + } +}; + +struct TensorExportBufferOpPattern + : public OpConversionPattern<IREE::Stream::TensorExportOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TensorExportOp exportOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!exportOp.getResult().getType().isa<IREE::HAL::BufferType>()) { + return failure(); + } + rewriter.replaceOp(exportOp, adaptor.getSource()); + return success(); + } +}; + +struct TensorExportBufferViewOpPattern + : public OpConversionPattern<IREE::Stream::TensorExportOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TensorExportOp exportOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto targetType = exportOp.getResult().getType(); + if (!targetType.isa<IREE::HAL::BufferViewType>() && + !targetType.isa<TensorType>()) { + return failure(); + } + + auto loc = exportOp.getLoc(); + auto tensorType = adaptor.getSourceEncoding().cast<RankedTensorType>(); + auto dynamicDims = adaptor.getSourceEncodingDims(); + + // NOTE: we should have verified supported encodings/types at entry into the + // HAL pipeline. + auto encodingType = + IREE::HAL::getEncodingTypeValue(tensorType.getEncoding()); + assert(encodingType.has_value() && "invalid tensor encoding"); + auto elementType = + IREE::HAL::getElementTypeValue(tensorType.getElementType()); + assert(elementType.has_value() && "invalid tensor element type"); + + // Flatten static + dynamic shape dimensions. + SmallVector<Value> dims; + unsigned dynamicIdx = 0; + for (int64_t idx = 0; idx < tensorType.getRank(); ++idx) { + if (tensorType.isDynamicDim(idx)) { + dims.push_back(dynamicDims[dynamicIdx++]); + } else { + dims.push_back(rewriter.create<arith::ConstantIndexOp>( + loc, tensorType.getDimSize(idx))); + } + } + + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewCreateOp>( + exportOp, adaptor.getSource(), elementType.value(), + encodingType.value(), dims); + return success(); + } +}; + +struct TensorTraceOpPattern + : public OpConversionPattern<IREE::Stream::TensorTraceOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TensorTraceOp traceOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewTraceOp>( + traceOp, traceOp.getKeyAttr(), adaptor.getOperands()); + return success(); + } +}; + +struct CmdFlushOpPattern + : public OpConversionPattern<IREE::Stream::CmdFlushOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdFlushOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct CmdInvalidateOpPattern + : public OpConversionPattern<IREE::Stream::CmdInvalidateOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdInvalidateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct CmdDiscardOpPattern + : public OpConversionPattern<IREE::Stream::CmdDiscardOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdDiscardOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct CmdFillOpPattern : public OpConversionPattern<IREE::Stream::CmdFillOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdFillOp fillOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = fillOp.getLoc(); + auto storage = getResourceStorage(loc, adaptor.getTarget(), + adaptor.getTargetSize(), rewriter); + rewriter.replaceOpWithNewOp<IREE::Util::BufferFillOp>( + fillOp, adaptor.getValue(), storage.buffer, storage.bufferSize, + adaptor.getTargetOffset(), adaptor.getTargetLength()); + return success(); + } +}; + +struct CmdCopyOpPattern : public OpConversionPattern<IREE::Stream::CmdCopyOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdCopyOp copyOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = copyOp.getLoc(); + auto sourceStorage = getResourceStorage(loc, adaptor.getSource(), + adaptor.getSourceSize(), rewriter); + auto targetStorage = getResourceStorage(loc, adaptor.getTarget(), + adaptor.getTargetSize(), rewriter); + rewriter.replaceOpWithNewOp<IREE::Util::BufferCopyOp>( + copyOp, sourceStorage.buffer, sourceStorage.bufferSize, + adaptor.getSourceOffset(), targetStorage.buffer, + targetStorage.bufferSize, adaptor.getTargetOffset(), + adaptor.getLength()); + return success(); + } +}; + +struct CmdDispatchOpPattern + : public OpConversionPattern<IREE::Stream::CmdDispatchOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdDispatchOp dispatchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = dispatchOp.getLoc(); + + auto callee = dispatchOp->getAttrOfType<SymbolRefAttr>("hal_inline.target"); + if (!callee) { + return rewriter.notifyMatchFailure( + dispatchOp, + "missing hal_inline.target annotation from the " + "--iree-hal-inline-executables pass"); + } + + // The InlineExecutables pass has already done the hard work here; we just + // need to make a function call to the annotated target function with all + // operands/bindings. + SmallVector<Value> callArgs; + llvm::append_range(callArgs, adaptor.getWorkload()); + llvm::append_range(callArgs, adaptor.getUniformOperands()); + SmallVector<Value> bindingBuffers; + SmallVector<Value> bindingOffsets; + for (auto [resource, resourceSize, resourceOffset] : + llvm::zip(adaptor.getResources(), adaptor.getResourceSizes(), + adaptor.getResourceOffsets())) { + auto storage = getResourceStorage(loc, resource, resourceSize, rewriter); + bindingBuffers.push_back(storage.buffer); + bindingOffsets.push_back(resourceOffset); + } + llvm::append_range(callArgs, bindingBuffers); + llvm::append_range(callArgs, bindingOffsets); + llvm::append_range(callArgs, adaptor.getResourceLengths()); + rewriter.replaceOpWithNewOp<func::CallOp>(dispatchOp, callee, TypeRange{}, + callArgs); + return success(); + } +}; + +struct CmdExecuteOpPattern + : public OpConversionPattern<IREE::Stream::CmdExecuteOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdExecuteOp executeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Inline the serial execution region. + rewriter.mergeBlockBefore(&executeOp.getBody().front(), executeOp, + adaptor.getResourceOperands()); + // Immediately resolve the timepoint. + auto resolvedTimepoint = + rewriter.create<arith::ConstantIntOp>(executeOp.getLoc(), 0, 64) + .getResult(); + rewriter.replaceOp(executeOp, resolvedTimepoint); + return success(); + } +}; + +struct CmdSerialOpPattern + : public OpConversionPattern<IREE::Stream::CmdSerialOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdSerialOp serialOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Inline the serial execution region. + rewriter.mergeBlockBefore(&serialOp.getBody().front(), serialOp); + rewriter.eraseOp(serialOp); + return success(); + } +}; + +struct CmdConcurrentOpPattern + : public OpConversionPattern<IREE::Stream::CmdConcurrentOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::CmdConcurrentOp concurrentOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Inline the concurrent execution region. + rewriter.mergeBlockBefore(&concurrentOp.getBody().front(), concurrentOp); + rewriter.eraseOp(concurrentOp); + return success(); + } +}; + +// Annoying we have to have this here, but there's no attribute converter +// equivalent we have access to so that we could do it in a generic way. +struct GlobalTimepointConversionPattern + : public OpConversionPattern<IREE::Util::GlobalOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Util::GlobalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto initialValue = op.getInitialValue(); + if (!initialValue.has_value()) return failure(); + if (!initialValue->isa<IREE::Stream::TimepointAttr>()) return failure(); + rewriter.updateRootInPlace( + op, [&]() { op.setInitialValueAttr(rewriter.getI64IntegerAttr(0)); }); + return success(); + } +}; + +struct TimepointImmediateOpPattern + : public OpConversionPattern<IREE::Stream::TimepointImmediateOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TimepointImmediateOp immediateOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(immediateOp, 0, 64); + return success(); + } +}; + +struct TimepointImportOpPattern + : public OpConversionPattern<IREE::Stream::TimepointImportOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TimepointImportOp importOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriter.notifyMatchFailure( + importOp, + "timepoints are not supported across the ABI with inline execution"); + } +}; + +struct TimepointExportOpPattern + : public OpConversionPattern<IREE::Stream::TimepointExportOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TimepointExportOp exportOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriter.notifyMatchFailure( + exportOp, + "timepoints are not supported across the ABI with inline execution"); + } +}; + +struct TimepointJoinOpPattern + : public OpConversionPattern<IREE::Stream::TimepointJoinOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TimepointJoinOp joinOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(joinOp, 0, 64); + return success(); + } +}; + +struct TimepointAwaitOpPattern + : public OpConversionPattern<IREE::Stream::TimepointAwaitOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::TimepointAwaitOp awaitOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(awaitOp, adaptor.getResourceOperands()); + return success(); + } +}; + +struct ElideYieldOpPattern : public OpConversionPattern<IREE::Stream::YieldOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Stream::YieldOp yieldOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(yieldOp); + return success(); + } +}; + +} // namespace + +void populateStreamToHALInlinePatterns(MLIRContext *context, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + RewritePatternSet &patterns) { + typeConverter.addConversion( + [=](IREE::Stream::ResourceType type, SmallVectorImpl<Type> &results) { + // Resources are just buffers (no shape/encoding/etc). + // We use !hal.buffer when going across the external ABI boundary but + // otherwise use memrefs. + if (type.getLifetime() == IREE::Stream::Lifetime::External) { + results.push_back(IREE::HAL::BufferType::get(context)); + } else { + results.push_back(IREE::Util::BufferType::get(context)); + } + return success(); + }); + + typeConverter.addConversion( + [=](IREE::Stream::TimepointType type, SmallVectorImpl<Type> &results) { + // TODO(benvanik): model timepoints as semaphores. + // This may become a !hal.semaphore + index, or some !hal.timepoint that + // we then do more analysis on once we know what devices are in use + // where. + results.push_back(IntegerType::get(context, 64)); + return success(); + }); + + patterns.insert<ResourceAllocOpPattern, ResourceAllocaOpPattern, + ResourceDeallocaOpPattern, ResourceSizeOpPattern, + ResourceMapOpPattern, ResourceTryMapOpPattern, + ResourceLoadOpPattern, ResourceStoreOpPattern, + ResourceSubviewOpPattern>(typeConverter, context); + + patterns.insert<TensorImportBufferOpPattern, TensorImportBufferViewOpPattern, + TensorExportBufferOpPattern, TensorExportBufferViewOpPattern, + TensorTraceOpPattern>(typeConverter, context); + + patterns + .insert<CmdFlushOpPattern, CmdInvalidateOpPattern, CmdDiscardOpPattern, + CmdFillOpPattern, CmdCopyOpPattern, CmdDispatchOpPattern, + CmdExecuteOpPattern, CmdSerialOpPattern, CmdConcurrentOpPattern>( + typeConverter, context); + + patterns.insert<GlobalTimepointConversionPattern>(typeConverter, context); + patterns.insert<TimepointImmediateOpPattern, TimepointImportOpPattern, + TimepointExportOpPattern, TimepointJoinOpPattern, + TimepointAwaitOpPattern>(typeConverter, context); + + patterns.insert<ElideYieldOpPattern>(typeConverter, context); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.h b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.h new file mode 100644 index 0000000..5a6e529 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.h
@@ -0,0 +1,25 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_STREAMTOHALINLINE_CONVERTSTREAMTOHALINLINE_H_ +#define IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_STREAMTOHALINLINE_CONVERTSTREAMTOHALINLINE_H_ + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +// Populates conversion patterns for stream->HAL (inline). +void populateStreamToHALInlinePatterns(MLIRContext *context, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_CONVERSION_STREAMTOHALINLINE_CONVERTSTREAMTOHALINLINE_H_
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD new file mode 100644 index 0000000..f8f5bec --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD
@@ -0,0 +1,31 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "cmd_ops.mlir", + "resource_ops.mlir", + "timepoint_ops.mlir", + "transfer_ops.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/CMakeLists.txt new file mode 100644 index 0000000..2b2b6c5 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/CMakeLists.txt
@@ -0,0 +1,26 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD# +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "cmd_ops.mlir" + "resource_ops.mlir" + "timepoint_ops.mlir" + "transfer_ops.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/cmd_ops.mlir new file mode 100644 index 0000000..3058613 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/cmd_ops.mlir
@@ -0,0 +1,131 @@ +// RUN: iree-opt --split-input-file --iree-hal-inline-conversion %s | FileCheck %s + +// NOTE: memory control ops are currently ignored as we're executing inline and +// assume coherent memory. + +// CHECK-LABEL: @cmdMemoryControl +func.func @cmdMemoryControl(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %fence = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) { + stream.cmd.flush %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1} + stream.cmd.invalidate %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1} + stream.cmd.discard %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1} + } => !stream.timepoint + // CHECK: return %c0 + return %fence : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @cmdFill +// CHECK-SAME: (%[[TARGET:.+]]: !util.buffer, %[[TARGET_SIZE:.+]]: index) +func.func @cmdFill(%target: !stream.resource<transient>, %target_size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 128 + %length = arith.constant 128 : index + // CHECK-DAG: %[[VALUE:.+]] = arith.constant 255 + %value = arith.constant 255 : i32 + %fence = stream.cmd.execute with(%target as %target_inner: !stream.resource<transient>{%target_size}) { + // CHECK: util.buffer.fill %[[VALUE]], %[[TARGET]][%c0 for %[[LENGTH]]] : i32 -> !util.buffer{%[[TARGET_SIZE]]} + stream.cmd.fill %value, %target_inner[%c0 for %length] : i32 -> !stream.resource<transient>{%target_size} + } => !stream.timepoint + // CHECK: return %c0 + return %fence : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @cmdCopy +// CHECK-SAME: (%[[SRC:.+]]: !util.buffer, %[[SRC_SIZE:.+]]: index, %[[DST:.+]]: !util.buffer, %[[DST_SIZE:.+]]: index) +func.func @cmdCopy(%src: !stream.resource<transient>, %src_size: index, + %dst: !stream.resource<staging>, %dst_size: index) -> !stream.timepoint { + // CHECK-DAG: %[[SRC_OFFSET:.+]] = arith.constant 100 + %src_offset = arith.constant 100 : index + // CHECK-DAG: %[[DST_OFFSET:.+]] = arith.constant 200 + %dst_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 128 + %length = arith.constant 128 : index + %fence = stream.cmd.execute with(%src as %src_inner: !stream.resource<transient>{%src_size}, + %dst as %dst_inner: !stream.resource<staging>{%dst_size}) { + // CHECK: util.buffer.copy %[[SRC]][%[[SRC_OFFSET]]], %[[DST]][%[[DST_OFFSET]]], %[[LENGTH]] : !util.buffer{%[[SRC_SIZE]]} -> !util.buffer{%[[DST_SIZE]]} + stream.cmd.copy %src_inner[%src_offset], %dst_inner[%dst_offset], %length : !stream.resource<transient>{%src_size} -> !stream.resource<staging>{%dst_size} + } => !stream.timepoint + // CHECK: return %c0 + return %fence : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @cmdExecute +func.func @cmdExecute(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %fence = stream.cmd.execute await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) { + stream.cmd.concurrent { + // CHECK: util.buffer.copy + stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3} + // CHECK: util.buffer.copy + stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3} + stream.cmd.serial { + // CHECK: util.buffer.copy + stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3} + // CHECK: util.buffer.copy + stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3} + } + // CHECK: util.buffer.copy + stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3} + } + } => !stream.timepoint + // CHECK: return %c0 + return %fence : !stream.timepoint +} + +// ----- + +// Provided by the iree-hal-inline-executables pass: +func.func private @__dispatch_ex_dispatch( + index, index, // workload[2] + i32, i32, // push_constants[2] + !util.buffer, !util.buffer, // bindingBuffers[2] + index, index, // bindingOffsets[2] + index, index) // bindingLengths[2] + +// NOTE: %buffer0 is transient and will map to a !util.buffer, while +// %buffer1 is external and will map to a !hal.buffer. + +// CHECK-LABEL: @cmdDispatch +// CHECK-SAME: (%[[BUFFER0:.+]]: !util.buffer, %[[BUFFER0_SIZE:.+]]: index, +// CHECK-SAME: %[[BUFFER1:.+]]: !hal.buffer, %[[BUFFER1_SIZE:.+]]: index) +func.func @cmdDispatch(%buffer0: !stream.resource<transient>, %buffer0_size: index, + %buffer1: !stream.resource<external>, %buffer1_size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4_i32 = arith.constant 4 : i32 + %c5_i32 = arith.constant 5 : i32 + %c128 = arith.constant 128 : index + // CHECK: %[[BUFFER0_REL_OFFSET:.+]] = arith.constant 200 + %buffer0_offset = arith.constant 200 : index + // CHECK: %[[BUFFER1_REL_OFFSET:.+]] = arith.constant 300 + %buffer1_offset = arith.constant 300 : index + %fence = stream.cmd.execute with(%buffer0 as %buffer0_inner: !stream.resource<transient>{%buffer0_size}, + %buffer1 as %buffer1_inner: !stream.resource<external>{%buffer1_size}) { + // CHECK: %[[BUFFER1_STORAGE:.+]] = hal_inline.buffer.storage<%[[BUFFER1]] + // CHECK: call @__dispatch_ex_dispatch( + // CHECK-SAME: %c1, %c2, + // CHECK-SAME: %c4_i32, %c5_i32, + // CHECK-SAME: %[[BUFFER0]], %[[BUFFER1_STORAGE]], + // CHECK-SAME: %[[BUFFER0_REL_OFFSET]], %[[BUFFER1_REL_OFFSET]], + // CHECK-SAME: %c128, %c128) + stream.cmd.dispatch @ex::@dispatch[%c1, %c2](%c4_i32, %c5_i32 : i32, i32) { + ro %buffer0_inner[%buffer0_offset for %c128] : !stream.resource<transient>{%buffer0_size}, + wo %buffer1_inner[%buffer1_offset for %c128] : !stream.resource<external>{%buffer1_size} + } attributes { + // From the iree-hal-inline-executables pass: + hal_inline.target = @__dispatch_ex_dispatch + } + } => !stream.timepoint + // CHECK: return %c0 + return %fence : !stream.timepoint +}
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/resource_ops.mlir new file mode 100644 index 0000000..1d18138 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/resource_ops.mlir
@@ -0,0 +1,137 @@ +// RUN: iree-opt --split-input-file --iree-hal-inline-conversion %s | FileCheck %s + +// CHECK-LABEL: @resourceAlloc +// CHECK-SAME: (%[[LENGTH:.+]]: index) +func.func @resourceAlloc(%length: index) -> !stream.resource<transient> { + // CHECK: %[[BUFFER:.+]], %[[STORAGE:.+]] = hal_inline.buffer.allocate alignment(%c64) : !hal.buffer{%[[LENGTH]]} + %result = stream.resource.alloc uninitialized : !stream.resource<transient>{%length} + // CHECK: return %[[STORAGE]] + return %result : !stream.resource<transient> +} + +// ----- + +// CHECK-LABEL: @resourceAlloca +// CHECK-SAME: (%[[LENGTH:.+]]: index) +func.func @resourceAlloca(%length: index) -> (!stream.resource<staging>, !stream.timepoint) { + // CHECK: %[[BUFFER:.+]], %[[STORAGE:.+]] = hal_inline.buffer.allocate alignment(%c64) : !hal.buffer{%[[LENGTH]]} + %0:2 = stream.resource.alloca uninitialized : !stream.resource<staging>{%length} => !stream.timepoint + // CHECK: %[[IMMEDIATE:.+]] = arith.constant 0 : i64 + // CHECK: return %[[STORAGE]], %[[IMMEDIATE]] + return %0#0, %0#1 : !stream.resource<staging>, !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @resourceAllocaAwait +// CHECK-SAME: (%[[LENGTH:.+]]: index, %[[TIMEPOINT:.+]]: i64) +func.func @resourceAllocaAwait(%length: index, %await_timepoint: !stream.timepoint) -> (!stream.resource<staging>, !stream.timepoint) { + // CHECK: %[[BUFFER:.+]], %[[STORAGE:.+]] = hal_inline.buffer.allocate alignment(%c64) : !hal.buffer{%[[LENGTH]]} + %0:2 = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource<staging>{%length} => !stream.timepoint + // CHECK: %[[IMMEDIATE:.+]] = arith.constant 0 : i64 + // CHECK: return %[[STORAGE]], %[[IMMEDIATE]] + return %0#0, %0#1 : !stream.resource<staging>, !stream.timepoint +} + +// ----- + +// NOTE: we don't do anything with deallocs today but could add a discard op. + +// CHECK-LABEL: @resourceDealloca +func.func @resourceDealloca(%arg0: index, %arg1: !stream.resource<staging>, %arg2: !stream.timepoint) -> !stream.timepoint { + %0 = stream.resource.dealloca %arg1 : !stream.resource<staging>{%arg0} => !stream.timepoint + // CHECK: %[[IMMEDIATE:.+]] = arith.constant 0 : i64 + // CHECK: return %[[IMMEDIATE]] + return %0 : !stream.timepoint +} + +// ----- + +// NOTE: we don't do anything with deallocs today but could add a discard op. + +// CHECK-LABEL: @resourceDeallocaAwait +func.func @resourceDeallocaAwait(%arg0: index, %arg1: !stream.resource<staging>, %arg2: !stream.timepoint) -> !stream.timepoint { + %0 = stream.resource.dealloca await(%arg2) => %arg1 : !stream.resource<staging>{%arg0} => !stream.timepoint + // CHECK: %[[IMMEDIATE:.+]] = arith.constant 0 : i64 + // CHECK: return %[[IMMEDIATE]] + return %0 : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @resourceSize +func.func @resourceSize(%arg0: !stream.resource<transient>) -> index { + // CHECK: %[[SIZE:.+]] = util.buffer.size %arg0 + %0 = stream.resource.size %arg0 : !stream.resource<transient> + // CHECK: return %[[SIZE]] + return %0 : index +} + +// ----- + +// CHECK-LABEL: @resourceMap +// CHECK-SAME: (%[[SOURCE:.+]]: !util.buffer) +func.func @resourceMap(%source: !util.buffer) -> !stream.resource<staging> { + // CHECK-DAG: %[[OFFSET:.+]] = arith.constant 100 + %offset = arith.constant 100 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 128 + %length = arith.constant 128 : index + // CHECK: %[[SOURCE_SIZE:.+]] = util.buffer.size %[[SOURCE]] : !util.buffer + // CHECK: %[[MAPPING:.+]] = util.buffer.subspan %[[SOURCE]][%[[OFFSET]]] : !util.buffer{%[[SOURCE_SIZE]]} -> !util.buffer{%[[LENGTH]]} + %mapping = stream.resource.map %source[%offset] : !util.buffer -> !stream.resource<staging>{%length} + // CHECK: return %[[MAPPING]] + return %mapping : !stream.resource<staging> +} + +// ----- + +// CHECK-LABEL: @resourceTryMap +// CHECK-SAME: (%[[SOURCE:.+]]: !util.buffer) +func.func @resourceTryMap(%source: !util.buffer) -> (i1, !stream.resource<constant>) { + // CHECK-DAG: %[[OFFSET:.+]] = arith.constant 100 + %offset = arith.constant 100 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 128 + %length = arith.constant 128 : index + // CHECK: %[[SOURCE_SIZE:.+]] = util.buffer.size %[[SOURCE]] : !util.buffer + // CHECK: %[[MAPPING:.+]] = util.buffer.subspan %[[SOURCE]][%[[OFFSET]]] : !util.buffer{%[[SOURCE_SIZE]]} -> !util.buffer{%[[LENGTH]]} + // CHECK-DAG: %[[DID_MAP:.+]] = arith.constant true + %did_map, %mapping = stream.resource.try_map %source[%offset] : !util.buffer -> i1, !stream.resource<constant>{%length} + // CHECK: return %[[DID_MAP]], %[[MAPPING]] + return %did_map, %mapping : i1, !stream.resource<constant> +} + +// ----- + +// CHECK-LABEL: @resourceLoad +// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[BUFFER_SIZE:.+]]: index, %[[OFFSET:.+]]: index) +func.func @resourceLoad(%resource: !stream.resource<staging>, %resource_size: index, %offset: index) -> i32 { + // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[OFFSET]]] : !util.buffer{%[[BUFFER_SIZE]]} -> i32 + %0 = stream.resource.load %resource[%offset] : !stream.resource<staging>{%resource_size} -> i32 + // CHECK: return %[[VALUE]] + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: @resourceStore +// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[BUFFER_SIZE:.+]]: index, %[[OFFSET:.+]]: index) +func.func @resourceStore(%resource: !stream.resource<staging>, %resource_size: index, %offset: index) { + // CHECK-DAG: %[[VALUE:.+]] = arith.constant 123 + %value = arith.constant 123 : i32 + // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[OFFSET]]] : i32 -> !util.buffer{%[[BUFFER_SIZE]]} + stream.resource.store %value, %resource[%offset] : i32 -> !stream.resource<staging>{%resource_size} + return +} + +// ----- + +// CHECK-LABEL: @resourceSubview +// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[BUFFER_SIZE:.+]]: index) +func.func @resourceSubview(%resource: !stream.resource<transient>, %resource_size: index) -> !stream.resource<transient> { + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + // CHECK: %[[SUBSPAN:.+]] = util.buffer.subspan %[[BUFFER]][%c128] : !util.buffer{%[[BUFFER_SIZE]]} -> !util.buffer{%c256} + %0 = stream.resource.subview %resource[%c128] : !stream.resource<transient>{%resource_size} -> !stream.resource<transient>{%c256} + // CHECK: return %[[SUBSPAN]] + return %0 : !stream.resource<transient> +}
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/timepoint_ops.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/timepoint_ops.mlir new file mode 100644 index 0000000..aef6a09 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/timepoint_ops.mlir
@@ -0,0 +1,48 @@ +// RUN: iree-opt --split-input-file --iree-hal-inline-conversion %s | FileCheck %s + +// NOTE: the inline HAL doesn't model timepoints and we just turn them into ints +// that'll eventually get DCE'd. + +// CHECK-LABEL: @rwTimepoint +// CHECK-SAME: = 0 : i64 +util.global private mutable @rwTimepoint = #stream.timepoint<immediate> +// CHECK: func.func @globalTimepoint(%arg0: i64) -> i64 +func.func @globalTimepoint(%arg0: !stream.timepoint) -> !stream.timepoint { + // CHECK: util.global.store %arg0, @rwTimepoint + util.global.store %arg0, @rwTimepoint : !stream.timepoint + // CHECK: %[[VALUE:.+]] = util.global.load @rwTimepoint + %value = util.global.load @rwTimepoint : !stream.timepoint + // CHECK: return %[[VALUE]] + return %value : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @timepointImmediate +func.func @timepointImmediate() -> !stream.timepoint { + // CHECK: %[[TIMEPOINT:.+]] = arith.constant 0 + %0 = stream.timepoint.immediate => !stream.timepoint + // CHECK: return %[[TIMEPOINT]] + return %0 : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @timepointJoin +func.func @timepointJoin(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[TIMEPOINT:.+]] = arith.constant 0 + %0 = stream.timepoint.join max(%arg0, %arg1) => !stream.timepoint + // CHECK: return %[[TIMEPOINT]] + return %0 : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @timepointAwait +func.func @timepointAwait(%arg0: !stream.timepoint, %arg1: !stream.resource<staging>, %arg2: !stream.resource<*>) -> (!stream.resource<staging>, !stream.resource<*>) { + %c100 = arith.constant 100 : index + %c200 = arith.constant 200 : index + %0:2 = stream.timepoint.await %arg0 => %arg1, %arg2 : !stream.resource<staging>{%c100}, !stream.resource<*>{%c200} + // CHECK: return %arg1, %arg2 + return %0#0, %0#1 : !stream.resource<staging>, !stream.resource<*> +}
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/transfer_ops.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/transfer_ops.mlir new file mode 100644 index 0000000..2bbd473 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/test/transfer_ops.mlir
@@ -0,0 +1,49 @@ +// RUN: iree-opt --split-input-file --iree-hal-inline-conversion %s | FileCheck %s + +// CHECK-LABEL: @tensorImportBuffer +// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer, %[[RESOURCE_SIZE:.+]]: index, %[[DIM:.+]]: index) -> !hal.buffer +func.func @tensorImportBuffer(%buffer: !hal.buffer, %resource_size: index, %dim: index) -> !stream.resource<external> { + %0 = stream.tensor.import %buffer : !hal.buffer -> tensor<?x5xf32>{%dim} in !stream.resource<external>{%resource_size} + // CHECK: return %[[BUFFER]] + return %0 : !stream.resource<external> +} + +// ----- + +// NOTE: buffer view metadata assertions via hal.buffer_view.assert are added +// when lowering into the stream dialect; here we only care about the storage +// buffer itself. + +// CHECK-LABEL: @tensorImportBufferView +// CHECK-SAME: (%[[BUFFER_VIEW:.+]]: !hal.buffer_view, %[[RESOURCE_SIZE:.+]]: index, %[[DIM:.+]]: index) -> !hal.buffer +func.func @tensorImportBufferView(%buffer_view: !hal.buffer_view, %resource_size: index, %dim: index) -> !stream.resource<external> { + // CHECK: %[[BUFFER:.+]] = hal_inline.buffer_view.buffer<%[[BUFFER_VIEW]] : !hal.buffer_view> : !hal.buffer + %0 = stream.tensor.import %buffer_view : !hal.buffer_view -> tensor<?x5xf32>{%dim} in !stream.resource<external>{%resource_size} + // CHECK: return %[[BUFFER]] + return %0 : !stream.resource<external> +} + +// ----- + +// CHECK-LABEL: @tensorExportBuffer +// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer, %[[RESOURCE_SIZE:.+]]: index, %[[DIM:.+]]: index) -> !hal.buffer +func.func @tensorExportBuffer(%resource: !stream.resource<external>, %resource_size: index, %dim: index) -> !hal.buffer { + %0 = stream.tensor.export %resource : tensor<?x1x10xf32>{%dim} in !stream.resource<external>{%resource_size} -> !hal.buffer + // CHECK: return %[[BUFFER]] + return %0 : !hal.buffer +} + +// ----- + +// CHECK-LABEL: @tensorExportBufferView +// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer, %[[RESOURCE_SIZE:.+]]: index, %[[DIM:.+]]: index) -> !hal.buffer +func.func @tensorExportBufferView(%resource: !stream.resource<external>, %resource_size: index, %dim: index) -> !hal.buffer_view { + // CHECK: %[[BUFFER_VIEW:.+]] = hal_inline.buffer_view.create + // CHECK-SAME: buffer(%[[BUFFER]] : !hal.buffer) + // CHECK-SAME: shape([%[[DIM]], %c1, %c10]) + // CHECK-SAME: type(%c553648160_i32) + // CHECK-SAME: encoding(%c1_i32) + %0 = stream.tensor.export %resource : tensor<?x1x10xf32>{%dim} in !stream.resource<external>{%resource_size} -> !hal.buffer_view + // CHECK: return %[[BUFFER_VIEW]] + return %0 : !hal.buffer_view +}
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/BUILD new file mode 100644 index 0000000..9cedf21 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/BUILD
@@ -0,0 +1,114 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("@llvm-project//mlir:tblgen.bzl", "td_library") +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") +load("//build_tools/bazel:iree_tablegen.bzl", "iree_gentbl_cc_library", "iree_tablegen_doc") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["HALInlineOps.td"]) + +td_library( + name = "td_files", + srcs = enforce_glob( + [ + "HALInlineBase.td", + "HALInlineOps.td", + ], + include = ["*.td"], + ), + deps = [ + "//compiler/src/iree/compiler/Dialect/HAL/IR:td_files", + "//compiler/src/iree/compiler/Dialect/Util/IR:td_files", + "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +iree_compiler_cc_library( + name = "IR", + srcs = [ + "HALInlineOps.cpp", + ], + hdrs = [ + "HALInlineOps.h", + "HALInlineOps.h.inc", + ], + textual_hdrs = [ + "HALInlineOps.cpp.inc", + ], + deps = [ + ":HALInlineOpsGen", + "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Dialect/VM/IR", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:TranslateLib", + ], +) + +iree_compiler_cc_library( + name = "HALInlineDialect", + srcs = ["HALInlineDialect.cpp"], + hdrs = ["HALInlineDialect.h"], + deps = [ + ":IR", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline:hal_inline_imports", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM", + "//compiler/src/iree/compiler/Dialect/VM/Conversion", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +iree_gentbl_cc_library( + name = "HALInlineOpsGen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "HALInlineOps.h.inc", + ), + ( + ["--gen-op-defs"], + "HALInlineOps.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HALInlineOps.td", + deps = [":td_files"], +) + +iree_tablegen_doc( + name = "HALInlineDialecDocGen", + tbl_outs = [ + ( + [ + "--dialect=hal_inline", + "--gen-dialect-doc", + ], + "HALInlineDialect.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HALInlineOps.td", + deps = [":td_files"], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/CMakeLists.txt new file mode 100644 index 0000000..04ba81e --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/CMakeLists.txt
@@ -0,0 +1,79 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + IR + HDRS + "HALInlineOps.h" + "HALInlineOps.h.inc" + TEXTUAL_HDRS + "HALInlineOps.cpp.inc" + SRCS + "HALInlineOps.cpp" + DEPS + ::HALInlineOpsGen + LLVMSupport + MLIRArithmeticDialect + MLIRFuncDialect + MLIRIR + MLIRSideEffectInterfaces + MLIRSupport + MLIRTransformUtils + MLIRTranslateLib + iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::Util::IR + iree::compiler::Dialect::VM::IR + PUBLIC +) + +iree_cc_library( + NAME + HALInlineDialect + HDRS + "HALInlineDialect.h" + SRCS + "HALInlineDialect.cpp" + DEPS + ::IR + LLVMSupport + MLIRFuncDialect + MLIRIR + MLIRParser + MLIRSupport + MLIRTransformUtils + iree::compiler::Dialect::Modules::HAL::Inline::Conversion::HALInlineToVM + iree::compiler::Dialect::Modules::HAL::Inline::hal_inline_imports + iree::compiler::Dialect::VM::Conversion + PUBLIC +) + +iree_tablegen_library( + NAME + HALInlineOpsGen + TD_FILE + "HALInlineOps.td" + OUTS + --gen-op-decls HALInlineOps.h.inc + --gen-op-defs HALInlineOps.cpp.inc +) + +iree_tablegen_doc( + NAME + HALInlineDialecDocGen + TD_FILE + "HALInlineOps.td" + OUTS + --dialect=hal_inline --gen-dialect-doc HALInlineDialect.md +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineBase.td b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineBase.td new file mode 100644 index 0000000..c8bb807 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineBase.td
@@ -0,0 +1,44 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECT_MODULES_HAL_INLINE_BASE +#define IREE_DIALECT_MODULES_HAL_INLINE_BASE + +include "iree/compiler/Dialect/Util/IR/UtilBase.td" + +//===----------------------------------------------------------------------===// +// IREE HAL inline dialect +//===----------------------------------------------------------------------===// + +def HALInline_Dialect : Dialect { + let name = "hal_inline"; + let cppNamespace = "::mlir::iree_compiler::IREE::HAL::Inline"; + let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + + let summary = [{ + IREE inline HAL interop runtime module dialect. + }]; + let description = [{ + Low-level dialect for limited in-process ABI interop with the full HAL. + Only operates synchronously, single-threaded, and on host-local buffers. Use + the full HAL for all other cases. + + This dialect can be used alongside the full HAL but is intended for use in + standalone configurations or paired with the `hal_loader` dialect which also + carries the same usage restrictions. + + See `hal_inline.imports.mlir` for the full list of exported functions. + }]; +} + +//===----------------------------------------------------------------------===// +// Base HALInline op classes +//===----------------------------------------------------------------------===// + +class HALInline_Op<string mnemonic, list<Trait> traits = []> : + Op<HALInline_Dialect, mnemonic, !listconcat(traits, [])> {} + +#endif // IREE_DIALECT_MODULES_HAL_INLINE_BASE
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.cpp b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.cpp new file mode 100644 index 0000000..85fec80 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.cpp
@@ -0,0 +1,63 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h" + +#include "iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALInlineToVM/ConvertHALInlineToVM.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/hal_inline.imports.h" +#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Parser/Parser.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace Inline { + +namespace { + +class HALInlineToVMConversionInterface : public VMConversionDialectInterface { + public: + using VMConversionDialectInterface::VMConversionDialectInterface; + + OwningOpRef<mlir::ModuleOp> parseVMImportModule() const override { + return mlir::parseSourceString<mlir::ModuleOp>( + StringRef(iree_hal_inline_imports_create()->data, + iree_hal_inline_imports_create()->size), + getDialect()->getContext()); + } + + void populateVMConversionPatterns( + SymbolTable &importSymbols, RewritePatternSet &patterns, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter) const override { + conversionTarget.addIllegalDialect<IREE::HAL::Inline::HALInlineDialect>(); + populateHALInlineToVMPatterns(getDialect()->getContext(), conversionTarget, + typeConverter, importSymbols, patterns); + } +}; + +} // namespace + +HALInlineDialect::HALInlineDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get<HALInlineDialect>()) { + addInterfaces<HALInlineToVMConversionInterface>(); + +#define GET_OP_LIST + addOperations< +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.cpp.inc" + >(); +} + +} // namespace Inline +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h new file mode 100644 index 0000000..b53b70d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h
@@ -0,0 +1,31 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_IR_HALINLINEDIALECT_H_ +#define IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_IR_HALINLINEDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace Inline { + +class HALInlineDialect : public Dialect { + public: + explicit HALInlineDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "hal_inline"; } +}; + +} // namespace Inline +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_IR_HALINLINEDIALECT_H_
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.cpp b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.cpp new file mode 100644 index 0000000..e8aa7e4 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.cpp
@@ -0,0 +1,211 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h" + +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/SMLoc.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeUtilities.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace Inline { + +//===----------------------------------------------------------------------===// +// hal_inline.buffer.allocate +//===----------------------------------------------------------------------===// + +void BufferAllocateOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + setNameFn(getResult(), "buffer"); + setNameFn(getStorage(), "storage"); +} + +Value BufferAllocateOp::getOperandSize(unsigned idx) { return {}; } + +Value BufferAllocateOp::getResultSize(unsigned idx) { + return getAllocationSize(); +} + +//===----------------------------------------------------------------------===// +// hal_inline.buffer.allocate.initialized +//===----------------------------------------------------------------------===// + +void BufferAllocateInitializedOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + setNameFn(getResult(), "buffer"); + setNameFn(getStorage(), "storage"); +} + +Value BufferAllocateInitializedOp::getOperandSize(unsigned idx) { return {}; } + +Value BufferAllocateInitializedOp::getResultSize(unsigned idx) { + return getLength(); +} + +//===----------------------------------------------------------------------===// +// hal_inline.buffer.wrap +//===----------------------------------------------------------------------===// + +void BufferWrapOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + setNameFn(getResult(), "mapped"); +} + +Value BufferWrapOp::getOperandSize(unsigned idx) { return {}; } + +Value BufferWrapOp::getResultSize(unsigned idx) { return getLength(); } + +//===----------------------------------------------------------------------===// +// hal_inline.buffer.subspan +//===----------------------------------------------------------------------===// + +void BufferSubspanOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + setNameFn(getResult(), "buffer"); +} + +Value BufferSubspanOp::getOperandSize(unsigned idx) { return getLength(); } + +Value BufferSubspanOp::getResultSize(unsigned idx) { return getLength(); } + +//===----------------------------------------------------------------------===// +// hal_inline.buffer.byte_length +//===----------------------------------------------------------------------===// + +void BufferLengthOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + setNameFn(getResult(), "length"); +} + +OpFoldResult BufferLengthOp::fold(ArrayRef<Attribute> operands) { + Operation *op = this->getOperation(); + return IREE::Util::SizeAwareTypeInterface::findSizeValue( + getBuffer(), op->getBlock(), Block::iterator(op)); +} + +//===----------------------------------------------------------------------===// +// hal_inline.buffer.storage +//===----------------------------------------------------------------------===// + +void BufferStorageOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + setNameFn(getResult(), "storage"); +} + +OpFoldResult BufferStorageOp::fold(ArrayRef<Attribute> operands) { + auto *definingOp = getBuffer().getDefiningOp(); + if (!definingOp) return {}; + if (auto sourceOp = + dyn_cast_or_null<IREE::HAL::Inline::BufferAllocateOp>(definingOp)) { + return sourceOp.getStorage(); + } else if (auto sourceOp = dyn_cast_or_null< + IREE::HAL::Inline::BufferAllocateInitializedOp>(definingOp)) { + return sourceOp.getStorage(); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// hal_inline.buffer_view.create +//===----------------------------------------------------------------------===// + +void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state, + Value buffer, int32_t elementType, + int32_t encodingType, ValueRange shape) { + build(builder, state, buffer, + builder.createOrFold<arith::ConstantIntOp>(state.location, elementType, + 32), + builder.createOrFold<arith::ConstantIntOp>(state.location, encodingType, + 32), + shape); +} + +void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state, + Value buffer, Value elementType, + Value encodingType, ValueRange shape) { + state.addOperands({buffer, elementType, encodingType}); + state.addOperands(shape); + state.addTypes({BufferViewType::get(builder.getContext())}); +} + +void BufferViewCreateOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + setNameFn(getResult(), "view"); +} + +//===----------------------------------------------------------------------===// +// hal_inline.buffer_view.buffer +//===----------------------------------------------------------------------===// + +void BufferViewBufferOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + setNameFn(getResult(), "buffer"); +} + +namespace { + +/// Skips a hal.buffer_view.buffer accessor when the buffer view was created in +/// the same scope and we know the origin buffer. +struct SkipBufferViewBufferOp : public OpRewritePattern<BufferViewBufferOp> { + using OpRewritePattern<BufferViewBufferOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(BufferViewBufferOp op, + PatternRewriter &rewriter) const override { + if (auto createOp = dyn_cast_or_null<BufferViewCreateOp>( + op.getBufferView().getDefiningOp())) { + rewriter.replaceOp(op, createOp.getBuffer()); + return success(); + } + return failure(); + } +}; + +} // namespace + +void BufferViewBufferOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert<SkipBufferViewBufferOp>(context); +} + +//===----------------------------------------------------------------------===// +// hal_inline.device.query +//===----------------------------------------------------------------------===// + +LogicalResult DeviceQueryOp::verify() { + DeviceQueryOp op = *this; + if (op.getDefaultValue().has_value()) { + if (auto typedDefaultValue = op.getDefaultValue()->dyn_cast<TypedAttr>()) { + if (typedDefaultValue.getType() != op.getValue().getType()) { + return op.emitOpError() + << "type mismatch between result and default value"; + } + } + } + return success(); +} + +} // namespace Inline +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +//===----------------------------------------------------------------------===// +// TableGen definitions (intentionally last) +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h new file mode 100644 index 0000000..97af7e8 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h
@@ -0,0 +1,26 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_IR_HALINLINEOPS_H_ +#define IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_IR_HALINLINEOPS_H_ + +#include <cstdint> + +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Util/IR/UtilTraits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_OP_CLASSES +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h.inc" // IWYU pragma: keep + +#endif // IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_IR_HALINLINEOPS_H_
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.td b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.td new file mode 100644 index 0000000..134575f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.td
@@ -0,0 +1,458 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECT_MODULES_HAL_INLINE_OPS +#define IREE_DIALECT_MODULES_HAL_INLINE_OPS + +include "iree/compiler/Dialect/HAL/IR/HALBase.td" +include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineBase.td" +include "iree/compiler/Dialect/Util/IR/UtilAttrs.td" +include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class HALInline_PureOp<string mnemonic, list<Trait> traits = []> : + HALInline_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; + +//===----------------------------------------------------------------------===// +// !hal.buffer / iree_hal_buffer_t +//===----------------------------------------------------------------------===// + +def HALInline_BufferAllocateOp : HALInline_Op<"buffer.allocate", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, + DeclareOpInterfaceMethods<Util_SizeAwareOp>, +]> { + let summary = [{empty buffer allocation operation}]; + let description = [{ + Allocates a buffer of the given size. + The size of the buffer returned may be larger than the requested size if the + allocator has specific alignment requirements or minimum allocation sizes. + }]; + + let arguments = (ins + HAL_DeviceSize:$minimum_alignment, + HAL_DeviceSize:$allocation_size + ); + let results = (outs + HAL_Buffer:$result, + Util_BufferType:$storage + ); + + let assemblyFormat = [{ + `alignment` `(` $minimum_alignment `)` + `:` custom<SizeAwareType>(type($result), $allocation_size) `in` type($storage) + attr-dict-with-keyword + }]; +} + +def HALInline_BufferAllocateInitializedOp : HALInline_Op<"buffer.allocate.initialized", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, + DeclareOpInterfaceMethods<Util_SizeAwareOp>, +]> { + let summary = [{buffer allocation with cloning}]; + let description = [{ + Allocates a buffer with a copy of the provided contents. + }]; + + let arguments = (ins + HAL_DeviceSize:$minimum_alignment, + Util_BufferType:$source, + HAL_DeviceSize:$offset, + HAL_DeviceSize:$length + ); + let results = (outs + HAL_Buffer:$result, + Util_BufferType:$storage + ); + + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` `` `[` $offset `,` $length `]` + `alignment` `(` $minimum_alignment `)` + `:` custom<SizeAwareType>(type($result), ref($length)) `in` type($storage) + attr-dict-with-keyword + }]; +} + +def HALInline_BufferWrapOp : HALInline_Op<"buffer.wrap", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, + DeclareOpInterfaceMethods<Util_SizeAwareOp>, +]> { + let summary = [{host buffer wrapping operation}]; + let description = [{ + Tries wrapping a !hal.buffer around host memory backed by the given byte + buffer. + }]; + + let arguments = (ins + Util_BufferType:$source, + HAL_DeviceSize:$offset, + HAL_DeviceSize:$length + ); + let results = (outs + HAL_Buffer:$result + ); + + // TODO(benvanik): change type/usage to ref params. + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` `` `[` $offset `,` $length `]` + `:` type($result) + attr-dict-with-keyword + }]; +} + +def HALInline_BufferSubspanOp : HALInline_PureOp<"buffer.subspan", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, + DeclareOpInterfaceMethods<Util_SizeAwareOp>, +]> { + let summary = [{buffer subspan operation}]; + let description = [{ + Returns a reference to a subspan of the buffer. + }]; + + let arguments = (ins + HAL_BufferType:$source_buffer, + HAL_DeviceSize:$source_offset, + HAL_DeviceSize:$length + ); + let results = (outs + HAL_BufferType:$result + ); + + let assemblyFormat = [{ + `<` $source_buffer `:` type($source_buffer) `>` + `` `[` $source_offset `,` $length `]` + `:` type($result) + attr-dict-with-keyword + }]; + + // TODO(benvanik): folder to elide when offset is 0 and length is all. +} + +def HALInline_BufferLengthOp : HALInline_PureOp<"buffer.length", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, +]> { + let summary = [{buffer byte length accessor}]; + let description = [{ + Returns the allocated size of a buffer in bytes. + May be less than the underlying buffer allocation if this is a subspan or + view into another buffer. + }]; + + let arguments = (ins + HAL_BufferType:$buffer + ); + let results = (outs + HAL_DeviceSize:$result + ); + + let assemblyFormat = [{ + `<` $buffer `:` type($buffer) `>` + `:` type($result) + attr-dict-with-keyword + }]; + + let builders = [ + OpBuilder<(ins "Value":$buffer), + [{ + build($_builder, $_state, $_builder.getIndexType(), buffer); + }]>, + ]; + + let hasFolder = 1; +} + +def HALInline_BufferStorageOp : HALInline_PureOp<"buffer.storage", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, +]> { + let summary = [{buffer backing storage accessor}]; + let description = [{ + Returns the host backing storage of the HAL buffer as a subspan limited to + to the buffer's logical range (meaning that byte 0 of the returned buffer is + byte 0 of the HAL buffer). + }]; + + let arguments = (ins + HAL_BufferType:$buffer + ); + let results = (outs + Util_BufferType:$storage + ); + + let assemblyFormat = [{ + `<` $buffer `:` type($buffer) `>` + `:` type($storage) + attr-dict-with-keyword + }]; + + let builders = [ + OpBuilder<(ins "Value":$buffer), + [{ + build($_builder, $_state, $_builder.getType<IREE::Util::BufferType>(), buffer); + }]>, + ]; + + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// !hal.buffer_view / iree_hal_buffer_view_t +//===----------------------------------------------------------------------===// + +def HALInline_BufferViewCreateOp : HALInline_PureOp<"buffer_view.create", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, +]> { + let summary = [{buffer view reference initializer}]; + let description = [{ + Creates a reference to a buffer with a particular shape and element type. + The buffer is not copied and both the original and view references must be + synchronized. This makes it easier to associate commonly-carried metadata + along with the contents. + }]; + + let arguments = (ins + HAL_BufferType:$buffer, + HAL_ElementType:$element_type, + HAL_EncodingType:$encoding_type, + HAL_Shape:$shape + ); + let results = (outs + HAL_BufferView:$result + ); + + let assemblyFormat = [{ + `buffer` `(` $buffer `:` type($buffer) `)` + `shape` `(` `[` $shape `]` `)` + `type` `(` $element_type `)` + `encoding` `(` $encoding_type `)` + `:` type($result) + attr-dict-with-keyword + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins + "Value":$buffer, + "int32_t":$elementType, + "int32_t":$encodingType, + "ValueRange":$shape + )>, + OpBuilder<(ins + "Value":$buffer, + "Value":$elementType, + "Value":$encodingType, + "ValueRange":$shape + )>, + ]; +} + +def HALInline_BufferViewAssertOp : HALInline_Op<"buffer_view.assert"> { + let summary = [{buffer view contents assertion}]; + let description = [{ + Asserts that the buffer view contains a data compatible tensor with the + given encoding. Program execution will abort as if `std.assert` had been + used. + }]; + + let arguments = (ins + HAL_BufferView:$buffer_view, + StrAttr:$message, + HAL_ElementType:$element_type, + HAL_EncodingType:$encoding_type, + HAL_Shape:$shape + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $buffer_view `:` type($buffer_view) `>` + `message` `(` $message `)` + `shape` `(` `[` $shape `]` `)` + `type` `(` $element_type `)` + `encoding` `(` $encoding_type `)` + attr-dict-with-keyword + }]; + + // TODO(benvanik): fold away when we know some properties of the buffer view + // (such as when we create it ourselves earlier on) or we've already asserted. +} + +def HALInline_BufferViewBufferOp : HALInline_PureOp<"buffer_view.buffer", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, +]> { + let summary = [{buffer view buffer accessor}]; + let description = [{ + Returns the buffer backing this view's contents. + }]; + + let arguments = (ins + HAL_BufferView:$buffer_view + ); + let results = (outs + HAL_BufferType:$result + ); + + let assemblyFormat = [{ + `<` $buffer_view `:` type($buffer_view) `>` + `:` type($result) + attr-dict-with-keyword + }]; + + let hasCanonicalizer = 1; +} + +def HALInline_BufferViewElementTypeOp : HALInline_PureOp<"buffer_view.element_type"> { + let summary = [{buffer view element type query}]; + let description = [{ + Returns the element type of the buffer view. + }]; + + let arguments = (ins + HAL_BufferView:$buffer_view + ); + let results = (outs + HAL_ElementType:$result + ); + + let assemblyFormat = [{ + `<` $buffer_view `:` type($buffer_view) `>` + `:` type($result) + attr-dict-with-keyword + }]; +} + +def HALInline_BufferViewEncodingTypeOp : HALInline_PureOp<"buffer_view.encoding_type"> { + let summary = [{buffer view encoding type query}]; + let description = [{ + Returns the encoding type of the buffer view. + }]; + + let arguments = (ins + HAL_BufferView:$buffer_view + ); + let results = (outs + HAL_EncodingType:$result + ); + + let assemblyFormat = [{ + `<` $buffer_view `:` type($buffer_view) `>` + `:` type($result) + attr-dict-with-keyword + }]; +} + +def HALInline_BufferViewRankOp : HALInline_PureOp<"buffer_view.rank"> { + let summary = [{buffer view rank query}]; + let description = [{ + Returns the rank of the buffer view. + }]; + + let arguments = (ins + HAL_BufferView:$buffer_view + ); + let results = (outs + HAL_Dim:$result + ); + + let assemblyFormat = [{ + `<` $buffer_view `:` type($buffer_view) `>` + `:` type($result) + attr-dict-with-keyword + }]; +} + +def HALInline_BufferViewDimOp : HALInline_PureOp<"buffer_view.dim"> { + let summary = [{buffer view dimension value query}]; + let description = [{ + Returns the value of the given dimension. + }]; + + let arguments = (ins + HAL_BufferView:$buffer_view, + IndexAttr:$index + ); + let results = (outs + HAL_Dim:$result + ); + + let assemblyFormat = [{ + `<` $buffer_view `:` type($buffer_view) `>` + `` `[` $index `]` + `:` type($result) + attr-dict-with-keyword + }]; +} + +def HALInline_BufferViewTraceOp : HALInline_Op<"buffer_view.trace", []> { + let summary = [{trace value(s) operation}]; + let description = [{ + Traces out to a runtime trace sink (console, log file, etc) the given buffer + views and titles them with the given key. The key is informational only and + useful for titling/marking specific sets of buffers for easier searching. + }]; + + let arguments = (ins + StrAttr:$key, + Variadic<HAL_BufferView>:$operands + ); + + let assemblyFormat = [{ + $operands `:` type($operands) + attr-dict-with-keyword + }]; +} + +//===----------------------------------------------------------------------===// +// !hal.device / iree_hal_device_t +//===----------------------------------------------------------------------===// + +def HALInline_DeviceQueryOp : + HALInline_PureOp<"device.query"> { + let summary = [{returns a runtime configuration parameter from the device}]; + let description = [{ + Queries a device configuration parameter with the given key. + Returns a status indicating whether the pair was recognized/available and if + it was the value converted to the specified type. Queries must return the + same value for the lifetime of the module though may vary from run to run. + + This is roughly equivalent to the `sysconf` linux syscall + (https://man7.org/linux/man-pages/man3/sysconf.3.html) in that the exact + set of keys available and their interpretation is target-dependent. If there + is a HAL match attribute (`#hal.device.match.*`) or op + (`hal.device.match.*`) prefer to use that in order to get compile-time + propagation when the target is specified and elide the runtime query and + get compile-time verification when a runtime query is required. + + Users of the op must check the `ok` result before using the value as what + set of keys is available may change over time. If in doubt: don't use this. + Each key used adds additional versioning and testing complexity as runtime + code path changes will explode combinatorially and should be treated with as + much care as a binary file format change. Keys should be prefixed with `ex.` + when experimental indicating that they are not expected to be present + forever; all non-experimental keys should be vetted. + + Well-known keys: (none yet) + }]; + + let arguments = (ins + StrAttr:$category, + StrAttr:$key, + OptionalAttr<AnyAttr>:$default_value + ); + let results = (outs + I1:$ok, + AnyType:$value + ); + + let assemblyFormat = [{ + `key` `(` $category `:` `` `:` $key `)` + `:` type($ok) `,` type($value) + (`=` $default_value^)? + attr-dict-with-keyword + }]; + + let hasVerifier = 1; +} + +#endif // IREE_DIALECT_MODULES_HAL_INLINE_OPS
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/BUILD new file mode 100644 index 0000000..98a6f8f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/BUILD
@@ -0,0 +1,28 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "buffer_folding.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/CMakeLists.txt new file mode 100644 index 0000000..e265ec7 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/CMakeLists.txt
@@ -0,0 +1,23 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "buffer_folding.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/buffer_folding.mlir new file mode 100644 index 0000000..eead77f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR/test/buffer_folding.mlir
@@ -0,0 +1,43 @@ +// RUN: iree-opt --split-input-file --canonicalize -cse %s | iree-opt --allow-unregistered-dialect --split-input-file | FileCheck %s + +// CHECK-LABEL: func @fold_buffer_length +// CHECK-SAME: (%[[LENGTH:.+]]: index) +func.func @fold_buffer_length(%length: index) -> index { + %c64 = arith.constant 64 : index + %buffer, %storage = hal_inline.buffer.allocate alignment(%c64) : !hal.buffer{%length} in !util.buffer + // CHECK-NOT: hal_inline.buffer.length + %queried_length = hal_inline.buffer.length<%buffer : !hal.buffer> : index + // CHECK: return %[[LENGTH]] + return %queried_length : index +} + +// ----- + +// CHECK-LABEL: func @fold_buffer_storage +func.func @fold_buffer_storage(%length: index) -> !util.buffer { + %c64 = arith.constant 64 : index + // CHECK: %[[BUFFER:.+]], %[[STORAGE:.+]] = hal_inline.buffer.allocate + %buffer, %storage = hal_inline.buffer.allocate alignment(%c64) : !hal.buffer{%length} in !util.buffer + // CHECK-NOT: hal_inline.buffer.storage + %queried_storage = hal_inline.buffer.storage<%buffer : !hal.buffer> : !util.buffer + // CHECK: return %[[STORAGE]] + return %queried_storage : !util.buffer +} + +// ----- + +// CHECK-LABEL: func @skip_buffer_view_buffer +// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer) +func.func @skip_buffer_view_buffer(%buffer: !hal.buffer) -> !hal.buffer { + %c1 = arith.constant 1 : i32 + %c10 = arith.constant 10 : index + %c11 = arith.constant 11 : index + %c32 = arith.constant 32 : i32 + %view = hal_inline.buffer_view.create buffer(%buffer : !hal.buffer) + shape([%c10, %c11]) + type(%c32) + encoding(%c1) : !hal.buffer_view + %view_buffer = hal_inline.buffer_view.buffer<%view : !hal.buffer_view> : !hal.buffer + // CHECK: return %[[BUFFER]] + return %view_buffer : !hal.buffer +}
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/BUILD new file mode 100644 index 0000000..6551edb --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/BUILD
@@ -0,0 +1,83 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") +load("//build_tools/bazel:iree_tablegen.bzl", "iree_gentbl_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "Transforms", + srcs = [ + "Conversion.cpp", + "InlineExecutables.cpp", + "Passes.cpp", + ], + hdrs = ["Passes.h"], + deps = [ + ":PassHeaders", + "//compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL", + "//compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL", + "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", + "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/HAL/Transforms", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR:HALInlineDialect", + "//compiler/src/iree/compiler/Dialect/Stream/IR", + "//compiler/src/iree/compiler/Dialect/Util/Conversion", + "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Dialect/Util/Transforms", + "//compiler/src/iree/compiler/Utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:ArithmeticTransforms", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathTransforms", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +iree_compiler_cc_library( + name = "PassHeaders", + hdrs = [ + "PassDetail.h", + "Passes.h", + "Passes.h.inc", + ], + deps = [ + ":PassesIncGen", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +iree_gentbl_cc_library( + name = "PassesIncGen", + tbl_outs = [ + ( + ["--gen-pass-decls"], + "Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/CMakeLists.txt new file mode 100644 index 0000000..c33d9fe --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/CMakeLists.txt
@@ -0,0 +1,79 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + Transforms + HDRS + "Passes.h" + SRCS + "Conversion.cpp" + "InlineExecutables.cpp" + "Passes.cpp" + DEPS + ::PassHeaders + LLVMSupport + MLIRAffineDialect + MLIRArithmeticDialect + MLIRArithmeticTransforms + MLIRFuncDialect + MLIRFuncTransforms + MLIRIR + MLIRMathDialect + MLIRMathTransforms + MLIRMemRefDialect + MLIRPass + MLIRSCFDialect + MLIRSCFToControlFlow + MLIRSupport + MLIRTransforms + iree::compiler::Dialect::HAL::Conversion::StandardToHAL + iree::compiler::Dialect::HAL::Conversion::UtilToHAL + iree::compiler::Dialect::HAL::IR::HALDialect + iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::HAL::Transforms + iree::compiler::Dialect::Modules::HAL::Inline::Conversion::HALToHALInline + iree::compiler::Dialect::Modules::HAL::Inline::Conversion::StreamToHALInline + iree::compiler::Dialect::Modules::HAL::Inline::IR + iree::compiler::Dialect::Modules::HAL::Inline::IR::HALInlineDialect + iree::compiler::Dialect::Stream::IR + iree::compiler::Dialect::Util::Conversion + iree::compiler::Dialect::Util::IR + iree::compiler::Dialect::Util::Transforms + iree::compiler::Utils + PUBLIC +) + +iree_cc_library( + NAME + PassHeaders + HDRS + "PassDetail.h" + "Passes.h" + "Passes.h.inc" + DEPS + ::PassesIncGen + MLIRPass + MLIRTransforms + PUBLIC +) + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Conversion.cpp new file mode 100644 index 0000000..a0613f9 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Conversion.cpp
@@ -0,0 +1,103 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h" +#include "iree/compiler/Dialect/HAL/Conversion/UtilToHAL/ConvertUtilToHAL.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace Inline { + +// Runs conversion with registered input dialects. +class ConversionPass : public ConversionBase<ConversionPass> { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<IREE::Util::UtilDialect, IREE::HAL::HALDialect, + IREE::HAL::Inline::HALInlineDialect, + mlir::arith::ArithmeticDialect>(); + } + + void runOnOperation() override { + auto *context = &getContext(); + + // Ensure all input dialects go away. + ConversionTarget conversionTarget(*context); + conversionTarget + .addLegalDialect<mlir::func::FuncDialect, mlir::scf::SCFDialect, + mlir::arith::ArithmeticDialect>(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + + // Pass-through. + typeConverter.addConversion([](IndexType type) { return type; }); + typeConverter.addConversion([](IntegerType type) { return type; }); + typeConverter.addConversion([](FloatType type) { return type; }); + typeConverter.addConversion( + [](IREE::Util::BufferType type) { return type; }); + + // Convert stream into `hal_inline`, which mostly entails ignoring ops. + conversionTarget.addLegalDialect<IREE::HAL::Inline::HALInlineDialect>(); + populateStreamToHALInlinePatterns(context, conversionTarget, typeConverter, + patterns); + + // Convert some common things into HAL, reusing those conversions. + populateStandardToHALPatterns(context, conversionTarget, typeConverter, + patterns); + populateUtilToHALPatterns(context, conversionTarget, typeConverter, + patterns); + + // Convert any full `hal` ops into `hal_inline` ops. + conversionTarget.addIllegalDialect<IREE::HAL::HALDialect>(); + populateHALToHALInlinePatterns(context, conversionTarget, typeConverter, + patterns); + + // Generic conversion. + conversionTarget.addLegalDialect<IREE::Util::UtilDialect>(); + populateUtilConversionPatterns(context, conversionTarget, typeConverter, + patterns); + populateGenericStructuralConversionPatterns(context, conversionTarget, + typeConverter, patterns); + + if (failed(applyPartialConversion(getOperation(), conversionTarget, + std::move(patterns)))) { + getOperation().emitError() + << "conversion to the hal_inline dialect failed"; + return signalPassFailure(); + } + } +}; + +std::unique_ptr<OperationPass<mlir::ModuleOp>> createConversionPass() { + return std::make_unique<ConversionPass>(); +} + +} // namespace Inline +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/InlineExecutables.cpp b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/InlineExecutables.cpp new file mode 100644 index 0000000..cfa6af2 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/InlineExecutables.cpp
@@ -0,0 +1,418 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Utils/IndexSet.h" +#include "iree/compiler/Utils/ModuleUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace Inline { + +class InlineExecutablesPass + : public InlineExecutablesBase<InlineExecutablesPass> { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert<IREE::Util::UtilDialect, IREE::HAL::HALDialect, + IREE::HAL::Inline::HALInlineDialect, arith::ArithmeticDialect, + func::FuncDialect, scf::SCFDialect>(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + // Inline variants and produce a function map. + DenseMap<Attribute, Attribute> exportToFuncMap; + SymbolTableCollection symbolTables; + for (auto executableOp : llvm::make_early_inc_range( + moduleOp.getOps<IREE::HAL::ExecutableOp>())) { + // Inline each variant. + for (auto variantOp : + executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) { + if (failed(inlineVariant(executableOp, variantOp, moduleOp, + exportToFuncMap, symbolTables))) { + return signalPassFailure(); + } + } + + // Drop executable after information has been extracted and the workgroup + // code has been inlined. + executableOp.erase(); + } + + // Annotate all dispatches with the target function. + for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) { + funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) { + // Specify new target function that conversion can use to make the call. + auto targetFuncName = + exportToFuncMap[dispatchOp.getEntryPoint()].cast<StringAttr>(); + assert(targetFuncName && "missing mapping"); + dispatchOp->setAttr("hal_inline.target", + FlatSymbolRefAttr::get(targetFuncName)); + }); + } + } + + LogicalResult inlineVariant(IREE::HAL::ExecutableOp executableOp, + IREE::HAL::ExecutableVariantOp variantOp, + mlir::ModuleOp targetModuleOp, + DenseMap<Attribute, Attribute> &exportToFuncMap, + SymbolTableCollection &symbolTables) { + auto innerModuleOp = variantOp.getInnerModule(); + auto innerSymbolTable = symbolTables.getSymbolTable(innerModuleOp); + auto innerModuleBuilder = OpBuilder::atBlockEnd(innerModuleOp.getBody()); + + // We want to merge the module ahead of the exported functions to ensure + // initializer order is preserved. + OpBuilder targetModuleBuilder(executableOp); + + // Build each dispatch function wrapper. + auto indexType = innerModuleBuilder.getIndexType(); + auto i32Type = innerModuleBuilder.getI32Type(); + auto bufferType = innerModuleBuilder.getType<IREE::Util::BufferType>(); + for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) { + // Build dispatch function signature that the stream.cmd.dispatch ops will + // map to. + auto layoutAttr = exportOp.getLayout(); + size_t totalBindingCount = 0; + for (auto setLayout : layoutAttr.getSetLayouts()) { + totalBindingCount += setLayout.getBindings().size(); + } + SmallVector<Type> inputTypes; + inputTypes.append(exportOp.getWorkgroupCountBody()->getNumArguments() - 1, + indexType); // workload + inputTypes.append(layoutAttr.getPushConstants(), i32Type); + inputTypes.append(totalBindingCount, bufferType); // buffers + inputTypes.append(totalBindingCount, indexType); // offsets + inputTypes.append(totalBindingCount, indexType); // lengths + auto dispatchFuncType = + innerModuleBuilder.getFunctionType(inputTypes, {}); + + // Create the function and insert into the module. + auto dispatchFuncOp = func::FuncOp::create( + exportOp.getLoc(), + ("__dispatch_" + executableOp.getName() + "_" + exportOp.getName()) + .str(), + dispatchFuncType); + dispatchFuncOp.setPrivate(); + innerSymbolTable.insert(dispatchFuncOp, + innerModuleBuilder.getInsertionPoint()); + innerModuleBuilder.setInsertionPointAfter(dispatchFuncOp); + + // Build the dispatch function by calling the target function in a loop. + auto bodyFuncOp = + innerSymbolTable.lookup<func::FuncOp>(exportOp.getName()); + if (bodyFuncOp.isPublic()) { + if (failed(rewriteWorkgroupSignature(layoutAttr, totalBindingCount, + bodyFuncOp))) { + return failure(); + } + bodyFuncOp.setPrivate(); // so we only do it once + } + buildDispatchFunc(exportOp, layoutAttr, totalBindingCount, bodyFuncOp, + dispatchFuncOp); + + // Map from what the stream.cmd.dispatch ops is using to the new function. + auto exportTargetAttr = + SymbolRefAttr::get(executableOp.getNameAttr(), + {SymbolRefAttr::get(exportOp.getNameAttr())}); + exportToFuncMap[exportTargetAttr] = dispatchFuncOp.getNameAttr(); + } + + // Merge the source executable module into the target host module. + if (failed(mergeModuleInto(innerModuleOp, targetModuleOp, + targetModuleBuilder))) { + return failure(); + } + + return success(); + } + + // Rewrites a workgroup body function signature to a flattened list. + // + // Body (as translated): + // (local_memory, [constants], [bindings], + // workgroup_x, workgroup_y, workgroup_z, + // workgroup_size_x, workgroup_size_y, workgroup_size_z, + // workgroup_count_x, workgroup_count_y, workgroup_count_z) + // + // Body after rewrite: + // (local_memory, constants..., bindings..., + // workgroup_x, workgroup_y, workgroup_z, + // workgroup_size_x, workgroup_size_y, workgroup_size_z, + // workgroup_count_x, workgroup_count_y, workgroup_count_z) + // + // To make this process easier and lighten the load on the downstream passes + // we muck with the ABI to pass a flattened list of constants and bindings. + // Whenever better IPO and util.list optimizations are added we could back + // this out to keep things vanilla and have fewer places making assumptions + // about the function signatures. + LogicalResult rewriteWorkgroupSignature( + IREE::HAL::ExecutableLayoutAttr layoutAttr, size_t totalBindingCount, + func::FuncOp bodyFuncOp) { + auto *entryBlock = &bodyFuncOp.front(); + auto builder = OpBuilder::atBlockBegin(entryBlock); + auto indexType = builder.getIndexType(); + auto i32Type = builder.getI32Type(); + auto bufferType = builder.getType<IREE::Util::BufferType>(); + + // There may be nicer ways of doing this but I can't find them. + // We build a new list of argument types and insert them as we go. This lets + // us map the arguments over and replace usage such that by the end we can + // slice off the original arguments as they'll have no more uses. + unsigned originalArgCount = entryBlock->getNumArguments(); + SmallVector<Type> newArgTypes; + unsigned argOffset = 0; + + // Local memory is carried across as-is. + auto localMemoryArg = entryBlock->getArgument(argOffset++); + newArgTypes.push_back(bufferType); + localMemoryArg.replaceAllUsesWith( + entryBlock->addArgument(bufferType, localMemoryArg.getLoc())); + + // Expand push constants by replacing buffer accesses with the flattened + // args. + newArgTypes.append(layoutAttr.getPushConstants(), i32Type); + auto constantBuffer = entryBlock->getArgument(argOffset++); + SmallVector<Value> constantArgs; + for (unsigned i = 0; i < layoutAttr.getPushConstants(); ++i) { + constantArgs.push_back( + entryBlock->addArgument(i32Type, constantBuffer.getLoc())); + } + if (failed(replaceBufferAccesses(constantBuffer, constantArgs))) { + return failure(); + } + + // Expand buffer list by replacing list accesses with the flattened args. + newArgTypes.append(totalBindingCount, bufferType); + auto bindingList = entryBlock->getArgument(argOffset++); + SmallVector<Value> bindingArgs; + for (unsigned i = 0; i < totalBindingCount; ++i) { + bindingArgs.push_back( + entryBlock->addArgument(bufferType, bindingList.getLoc())); + } + if (failed(replaceListAccesses(bindingList, bindingArgs))) { + return failure(); + } + + // Take care of the workgroup id/size/count tuples. + for (unsigned i = 0; i < 3 * /*xyz=*/3; ++i) { + newArgTypes.push_back(indexType); + auto oldArg = entryBlock->getArgument(argOffset++); + oldArg.replaceAllUsesWith( + entryBlock->addArgument(indexType, oldArg.getLoc())); + } + + // Erase the original args. + for (unsigned i = 0; i < originalArgCount; ++i) { + entryBlock->eraseArgument(0); + } + + // Update function signature to reflect the entry block args. + bodyFuncOp.setType( + builder.getFunctionType(newArgTypes, bodyFuncOp.getResultTypes())); + + return success(); + } + + // Replaces trivial constant index accesses to a buffer with their values. + // This is an extremely poor optimization that we should remove if buffer + // ever gets store-load forwarding - we could just create the buffer, store + // the elements, and let that take care of the rest. Today it doesn't do that. + LogicalResult replaceBufferAccesses(Value buffer, ValueRange elements) { + for (auto user : llvm::make_early_inc_range(buffer.getUsers())) { + if (auto sizeOp = dyn_cast<IREE::Util::BufferSizeOp>(user)) { + // Ignored but we need to get rid of it. + // TODO(benvanik): see if we can allow this through; today it will pin + // the function argument (constants most likely) and cause us to fail to + // remove it later on. + Value dummySize = OpBuilder(sizeOp).create<arith::ConstantIndexOp>( + sizeOp.getLoc(), 0xCAFEF00D); + sizeOp.replaceAllUsesWith(dummySize); + sizeOp.erase(); + continue; + } else if (auto loadOp = dyn_cast<IREE::Util::BufferLoadOp>(user)) { + APInt index; + if (matchPattern(loadOp.getSourceOffset(), m_ConstantInt(&index))) { + loadOp.replaceAllUsesWith( + elements[index.getSExtValue() / sizeof(uint32_t)]); + loadOp.erase(); + continue; + } else { + return loadOp.emitOpError( + "unhandled dynamic buffer access; must be static"); + } + } else if (auto loadOp = dyn_cast<memref::LoadOp>(user)) { + if (loadOp.indices().size() != 1) { + return loadOp.emitOpError( + "expected memrefs to have been flattened before inlining " + "executables"); + } + APInt index; + if (matchPattern(loadOp.indices()[0], m_ConstantInt(&index))) { + loadOp.replaceAllUsesWith(elements[index.getSExtValue()]); + loadOp.erase(); + continue; + } else { + return loadOp.emitOpError( + "unhandled dynamic buffer access; must be static"); + } + } else { + return user->emitOpError( + "unhandled buffer access op; only loads are supported"); + } + } + return success(); + } + + // Replaces trivial constant index accesses to a list with their values. + // util.list store-load forwarding could do this instead. + LogicalResult replaceListAccesses(Value list, ValueRange elements) { + for (auto user : llvm::make_early_inc_range(list.getUsers())) { + if (auto getOp = dyn_cast<IREE::Util::ListGetOp>(user)) { + APInt index; + if (matchPattern(getOp.getIndex(), m_ConstantInt(&index))) { + getOp.replaceAllUsesWith(elements[index.getSExtValue()]); + getOp.erase(); + continue; + } else { + return getOp.emitOpError( + "unhandled dynamic list access; must be static"); + } + } else { + return user->emitOpError( + "unhandled list access op; only gets are supported"); + } + } + return success(); + } + + // Builds a function that calls a workgroup body and marshals arguments. + // + // Incoming: + // (workload..., push_constants..., + // binding_buffers..., binding_offsets..., binding_lengths...) + // Body (as translated): + // (local_memory, [constants], [bindings], + // workgroup_x, workgroup_y, workgroup_z, + // workgroup_size_x, workgroup_size_y, workgroup_size_z, + // workgroup_count_x, workgroup_count_y, workgroup_count_z) + void buildDispatchFunc(IREE::HAL::ExecutableExportOp exportOp, + IREE::HAL::ExecutableLayoutAttr layoutAttr, + size_t totalBindingCount, func::FuncOp bodyFuncOp, + func::FuncOp dispatchFuncOp) { + auto loc = exportOp.getLoc(); + auto builder = OpBuilder::atBlockBegin(dispatchFuncOp.addEntryBlock()); + IndexSet indexSet(loc, builder); + auto bufferType = builder.getType<IREE::Util::BufferType>(); + + SmallVector<Value> workgroupArgs; + + // Calculate the XYZ workgroup count from the export function. + // There may be multiple exports pointing at the same body with different + // workgroup count functions. + unsigned workloadArgCount = + exportOp.getWorkgroupCountBody()->getNumArguments() - 1; + unsigned argOffset = 0; + SmallVector<Value> workload; + workload.reserve(workloadArgCount); + for (unsigned i = 0; i < workloadArgCount; ++i) { + workload.push_back(dispatchFuncOp.getArgument(argOffset++)); + } + Value device = builder.create<IREE::Util::NullOp>( + loc, builder.getType<IREE::HAL::DeviceType>()); + auto workgroupCount = + exportOp.calculateWorkgroupCount(loc, device, workload, builder); + + // For now we don't handle local memory. + Value localMemory = builder.create<IREE::Util::NullOp>(loc, bufferType); + workgroupArgs.push_back(localMemory); + + // Pass all constants through. + for (int64_t i = 0; i < layoutAttr.getPushConstants(); ++i) { + workgroupArgs.push_back(dispatchFuncOp.getArgument(argOffset++)); + } + + // Pass all buffers through as subspans with the binding offset and length + // factored in. IPO can propagate the subspans (hopefully). + for (size_t i = 0; i < totalBindingCount; ++i) { + auto bindingBuffer = dispatchFuncOp.getArgument(argOffset + i); + auto bindingOffset = + dispatchFuncOp.getArgument(argOffset + totalBindingCount + i); + auto bindingLength = dispatchFuncOp.getArgument( + argOffset + totalBindingCount + totalBindingCount + i); + Value bufferSize = + builder.create<IREE::Util::BufferSizeOp>(loc, bindingBuffer); + Value bindingView = builder.create<IREE::Util::BufferSubspanOp>( + loc, bindingBuffer, bufferSize, bindingOffset, bindingLength); + workgroupArgs.push_back(bindingView); + } + + int workgroupXYZOffset = workgroupArgs.size(); + workgroupArgs.push_back(nullptr); // workgroup_x, set below + workgroupArgs.push_back(nullptr); // workgroup_y, set below + workgroupArgs.push_back(nullptr); // workgroup_z, set below + workgroupArgs.append(3, indexSet.get(1)); // workgroup_size_xyz + workgroupArgs.push_back(workgroupCount[0]); // workgroup_count_x + workgroupArgs.push_back(workgroupCount[1]); // workgroup_count_y + workgroupArgs.push_back(workgroupCount[2]); // workgroup_count_z + + // Z -> Y -> Z loop nest. + builder.create<scf::ForOp>( + loc, indexSet.get(0), workgroupCount[2], indexSet.get(1), ValueRange{}, + [&](OpBuilder &forZBuilder, Location loc, Value iz, ValueRange iters) { + workgroupArgs[workgroupXYZOffset + 2] = iz; + forZBuilder.create<scf::ForOp>( + loc, indexSet.get(0), workgroupCount[1], indexSet.get(1), + ValueRange{}, + [&](OpBuilder &forYBuilder, Location loc, Value iy, + ValueRange iters) { + workgroupArgs[workgroupXYZOffset + 1] = iy; + forYBuilder.create<scf::ForOp>( + loc, indexSet.get(0), workgroupCount[0], indexSet.get(1), + ValueRange{}, + [&](OpBuilder &forXBuilder, Location loc, Value ix, + ValueRange iters) { + workgroupArgs[workgroupXYZOffset + 0] = ix; + forXBuilder.create<func::CallOp>(loc, bodyFuncOp, + workgroupArgs); + forXBuilder.create<scf::YieldOp>(loc); + }); + forYBuilder.create<scf::YieldOp>(loc); + }); + forZBuilder.create<scf::YieldOp>(loc); + }); + + builder.create<func::ReturnOp>(loc); + } +}; + +std::unique_ptr<OperationPass<mlir::ModuleOp>> createInlineExecutablesPass() { + return std::make_unique<InlineExecutablesPass>(); +} + +} // namespace Inline +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/PassDetail.h b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/PassDetail.h new file mode 100644 index 0000000..55e647e --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/PassDetail.h
@@ -0,0 +1,29 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_TRANSFORMS_PASS_DETAIL_H_ +#define IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_TRANSFORMS_PASS_DETAIL_H_ + +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace Inline { + +#define GEN_PASS_CLASSES +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h.inc" + +} // namespace Inline +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_TRANSFORMS_PASS_DETAIL_H_
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.cpp new file mode 100644 index 0000000..ddd9f4c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.cpp
@@ -0,0 +1,115 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h" + +#include <memory> + +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/Transforms/Passes.h" +#include "iree/compiler/Utils/PassUtils.h" +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace Inline { + +using FunctionLikeNest = MultiOpNest<func::FuncOp, IREE::Util::InitializerOp>; + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +static void addCleanupPatterns(OpPassManager &passManager) { + // Standard MLIR cleanup. + passManager.addPass(mlir::createCanonicalizerPass()); + passManager.addPass(mlir::createCSEPass()); + + FunctionLikeNest(passManager) + // Simplify util.global accesses; this can help with data flow tracking as + // redundant store-loads are removed. + .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + + // Cleanup and canonicalization of util.global (and other util ops). + passManager.addPass(IREE::Util::createApplyPatternsPass()); + passManager.addPass(IREE::Util::createFoldGlobalsPass()); + passManager.addPass(IREE::Util::createFuseGlobalsPass()); +} + +//===----------------------------------------------------------------------===// +// -iree-hal-inline-static-transformation-pipeline +//===----------------------------------------------------------------------===// + +void buildHALInlineStaticTransformPassPipeline( + OpPassManager &passManager, const TargetOptions &targetOptions) { + //---------------------------------------------------------------------------- + // Device assignment and interface materialization + //---------------------------------------------------------------------------- + + IREE::HAL::buildHALConfigurationPassPipeline(passManager, targetOptions); + + //---------------------------------------------------------------------------- + // Executable translation + //---------------------------------------------------------------------------- + + // Translate each executable down to common MLIR dialects. + passManager.addNestedPass<IREE::HAL::ExecutableOp>( + IREE::HAL::createTranslateExecutablesPass()); + + // Inline the translated executable functions. + // We preserve the executables for their metadata used during conversion. + passManager.addPass(IREE::HAL::Inline::createInlineExecutablesPass()); + addCleanupPatterns(passManager); + + //---------------------------------------------------------------------------- + // Conversion + //---------------------------------------------------------------------------- + + // Convert from stream to hal_inline. + passManager.addPass(IREE::HAL::Inline::createConversionPass()); + + // Propagate buffer subranges across the program. + passManager.addPass(IREE::Util::createPropagateSubrangesPass()); + + //---------------------------------------------------------------------------- + // Cleanup and canonicalization + //---------------------------------------------------------------------------- + + addCleanupPatterns(passManager); +} + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h.inc" +} // namespace + +void registerHALInlinePasses() { + // Generated. + registerPasses(); + + static PassPipelineRegistration<> transformPassPipeline( + "iree-hal-inline-static-transformation-pipeline", + "Runs the inline HAL dialect transformation pipeline", + [](OpPassManager &passManager) { + buildHALInlineStaticTransformPassPipeline( + passManager, TargetOptions::FromFlags::get()); + }); +} + +} // namespace Inline +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h new file mode 100644 index 0000000..b1247fd --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h
@@ -0,0 +1,63 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_TRANSFORMS_PASSES_H_ +#define IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_TRANSFORMS_PASSES_H_ + +#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineOps.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace Inline { + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +// Adds a set of passes to the given pass manager that run the required +// HALInline transforms in the canonical order. +// +// Most translation code should prefer to use this instead of manually adding +// the passes themselves to ensure that expected pass ordering is observed. +// +// The expected usage is: +// <run conversion from TF/HLO/etc to flow> +// buildHALInlineTransformPassPipeline & run +// <serialize VM module> +void buildHALInlineStaticTransformPassPipeline( + OpPassManager &passManager, const TargetOptions &targetOptions); + +//===----------------------------------------------------------------------===// +// Passes +//===----------------------------------------------------------------------===// + +// Inlines translated executable functions into the host program. +std::unique_ptr<OperationPass<mlir::ModuleOp>> createInlineExecutablesPass(); + +// Converts from the stream dialect into the hal_inline dialect. +std::unique_ptr<OperationPass<mlir::ModuleOp>> createConversionPass(); + +//===----------------------------------------------------------------------===// +// Register all Passes +//===----------------------------------------------------------------------===// + +void registerHALInlinePasses(); + +} // namespace Inline +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_MODULES_HAL_INLINE_TRANSFORMS_PASSES_H_
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.td new file mode 100644 index 0000000..60d2cb3 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.td
@@ -0,0 +1,22 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_MODULES_HAL_INLINE_PASSES +#define IREE_MODULES_HAL_INLINE_PASSES + +include "mlir/Pass/PassBase.td" + +def Conversion : Pass<"iree-hal-inline-conversion", "mlir::ModuleOp"> { + let summary = "Converts from various dialects to the HAL inline dialect"; + let constructor = "mlir::iree_compiler::IREE::HAL::Inline::createConversionPass()"; +} + +def InlineExecutables : Pass<"iree-hal-inline-executables", "mlir::ModuleOp"> { + let summary = "Inlines translated executable functions into the host program"; + let constructor = "mlir::iree_compiler::IREE::HAL::Inline::createInlineExecutablesPass()"; +} + +#endif // IREE_MODULES_HAL_INLINE_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/BUILD new file mode 100644 index 0000000..a0e6f73 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/BUILD
@@ -0,0 +1,28 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "inline_executables.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/CMakeLists.txt new file mode 100644 index 0000000..f32aeb3 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/CMakeLists.txt
@@ -0,0 +1,23 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "inline_executables.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/inline_executables.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/inline_executables.mlir new file mode 100644 index 0000000..74dea66 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms/test/inline_executables.mlir
@@ -0,0 +1,184 @@ +// RUN: iree-opt --split-input-file --iree-hal-inline-executables %s | FileCheck %s + +// Tests that exported dispatch functions get placed into the module with +// wrapper functions that perform the dispatch and all dispatch sites are tagged +// with the wrapper function. + +// CHECK-NOT: hal.executable +hal.executable private @ex { + hal.executable.variant public @vmvx_ir, target = <"vmvx-inline", "vmvx-ir"> { + hal.executable.export public @dispatch_0 ordinal(0) layout( + #hal.executable.layout<push_constants = 2, + sets = [ + <0, bindings = [ + <0, storage_buffer>, + <1, storage_buffer>, + <2, storage_buffer> + ]> + ]>) { + ^bb0(%arg0: !hal.device, %workload_x: index, %workload_y: index): + %count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%workload_x] + %count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%workload_y] + %count_z = arith.constant 1 : index + hal.return %count_x, %count_y, %count_z : index, index, index + } + builtin.module { + util.global private @global_constant : !util.buffer + util.initializer { + %buffer_cst = util.buffer.constant : !util.buffer = dense<[1, 2, 3, 4, 5]> : tensor<5xi32> + util.global.store %buffer_cst, @global_constant : !util.buffer + util.initializer.return + } + func.func @dispatch_0( + %local_memory: !util.buffer, + %constants: !util.buffer, + %bindings: !util.list<!util.buffer>, + %workgroup_x: index, %workgroup_y: index, %workgroup_z: index, + %workgroup_size_x: index, %workgroup_size_y: index, %workgroup_size_z: index, + %workgroup_count_x: index, %workgroup_count_y: index, %workgroup_count_z: index) { + // Unpack push constants: + %constants_size = util.buffer.size %constants : !util.buffer + %constant1_offset = arith.constant 4 : index + %constant1_i32 = util.buffer.load %constants[%constant1_offset] : !util.buffer{%constants_size} -> i32 + %constant1_f32 = arith.sitofp %constant1_i32 : i32 to f32 + + // Unpack buffer bindings: + %c0 = arith.constant 0 : index + %buffer0 = util.list.get %bindings[%c0] : !util.list<!util.buffer> + %c1 = arith.constant 1 : index + %buffer1 = util.list.get %bindings[%c1] : !util.list<!util.buffer> + %c2 = arith.constant 2 : index + %buffer2 = util.list.get %bindings[%c2] : !util.list<!util.buffer> + %buffer0_size = util.buffer.size %buffer0 : !util.buffer + %buffer1_size = util.buffer.size %buffer1 : !util.buffer + %buffer2_size = util.buffer.size %buffer2 : !util.buffer + + // Test for global constants: + %global_constant = util.global.load @global_constant : !util.buffer + util.do_not_optimize(%global_constant) : !util.buffer + + %c4 = arith.constant 4 : index + scf.for %i = %c0 to %workgroup_x step %c1 { + %idx = arith.muli %i, %c4 : index + %lhs = util.buffer.load %buffer0[%idx] : !util.buffer{%buffer0_size} -> f32 + %rhs = util.buffer.load %buffer1[%idx] : !util.buffer{%buffer1_size} -> f32 + %mul = arith.mulf %lhs, %rhs : f32 + %scaled = arith.mulf %mul, %constant1_f32 : f32 + util.buffer.store %scaled, %buffer2[%idx] : f32 -> !util.buffer{%buffer2_size} + } + return + } + } + } +} + +// Ensures that we properly rename the globals we inline: +util.global private @global_constant : i32 + +// CHECK: util.global private @global_constant_0 : !util.buffer +// CHECK: util.initializer +// CHECK: %[[CONSTANT:.+]] = util.buffer.constant +// CHECK: util.global.store %[[CONSTANT]], @global_constant + +// Ensures that we properly rename the dispatch function we inline: +func.func private @dispatch_0() + +// CHECK-LABEL: func private @dispatch_0_0 +// CHECK-SAME: (%[[LOCAL_MEMORY:.+]]: !util.buffer, %[[CONSTANT0:.+]]: i32, %[[CONSTANT1:.+]]: i32, +// CHECK-SAME: %[[BINDING0:.+]]: !util.buffer, %[[BINDING1:.+]]: !util.buffer, %[[BINDING2:.+]]: !util.buffer, +// CHECK-SAME: %[[X:[a-z0-9]+]]: index, %[[Y:[a-z0-9]+]]: index, %[[Z:[a-z0-9]+]]: index, +// CHECK-SAME: %[[SIZE_XYZ:[a-z0-9]+]]: index, %[[SIZE_XYZ:[a-z0-9]+]]: index, %[[SIZE_XYZ:[a-z0-9]+]]: index, +// CHECK-SAME: %[[COUNT_X:[a-z0-9]+]]: index, %[[COUNT_Y:[a-z0-9]+]]: index, %[[COUNT_Z:[a-z0-9]+]]: index) + +// Push constant rewritten to use args: +// CHECK: %[[CONSTANT1_F32:.+]] = arith.sitofp %[[CONSTANT1]] : i32 to f32 + +// Bindings get changed to use args: +// CHECK: %[[BINDING0_SIZE:.+]] = util.buffer.size %[[BINDING0]] +// CHECK: %[[BINDING1_SIZE:.+]] = util.buffer.size %[[BINDING1]] +// CHECK: %[[BINDING2_SIZE:.+]] = util.buffer.size %[[BINDING2]] + +// Globals get carried across: +// CHECK: %[[GLOBAL_CONSTANT:.+]] = util.global.load @global_constant_0 : !util.buffer +// CHECK: util.do_not_optimize(%[[GLOBAL_CONSTANT]]) + +// CHECK: scf.for %[[ELEMENT_INDEX:.+]] = %c0 to %[[X]] +// CHECK: %[[ELEMENT_OFFSET:.+]] = arith.muli %[[ELEMENT_INDEX]] +// CHECK: %[[LHS:.+]] = util.buffer.load %[[BINDING0]][%[[ELEMENT_OFFSET]]] : !util.buffer{%[[BINDING0_SIZE]]} -> f32 +// CHECK: %[[RHS:.+]] = util.buffer.load %[[BINDING1]][%[[ELEMENT_OFFSET]]] : !util.buffer{%[[BINDING1_SIZE]]} -> f32 +// CHECK: %[[MUL:.+]] = arith.mulf %[[LHS]], %[[RHS]] : f32 +// CHECK: %[[SCALED:.+]] = arith.mulf %[[MUL]], %[[CONSTANT1_F32]] : f32 +// CHECK: util.buffer.store %[[SCALED]], %[[BINDING2]][%[[ELEMENT_OFFSET]]] : f32 -> !util.buffer{%[[BINDING2_SIZE]]} +// CHECK: return + +// CHECK-LABEL: func private @__dispatch_ex_dispatch_0 +// CHECK-SAME: (%[[WORKLOAD_X:.+]]: index, %[[WORKLOAD_Y:.+]]: index, %[[CONSTANT0:.+]]: i32, %[[CONSTANT1:.+]]: i32, +// CHECK-SAME: %[[BINDING0:.+]]: !util.buffer, %[[BINDING1:.+]]: !util.buffer, %[[BINDING2:.+]]: !util.buffer, +// CHECK-SAME: %[[OFFSET0:[a-z0-9]+]]: index, %[[OFFSET1:[a-z0-9]+]]: index, %[[OFFSET2:[a-z0-9]+]]: index, +// CHECK-SAME: %[[LENGTH0:.+]]: index, %[[LENGTH1:.+]]: index, %[[LENGTH2:.+]]: index) + +// Inlined workgroup count calculation from the export op: +// CHECK: %[[COUNT_X:.+]] = affine.apply {{.+}}[%[[WORKLOAD_X]]] +// CHECK: %[[COUNT_Y:.+]] = affine.apply {{.+}}[%[[WORKLOAD_Y]]] +// CHECK: %[[COUNT_Z:.+]] = arith.constant 1 + +// Local workgroup memory not currently used: +// CHECK: %[[LOCAL_MEMORY:.+]] = util.null : !util.buffer + +// Binding subspans as specified on the dispatch: +// CHECK: %[[BINDING0_SIZE:.+]] = util.buffer.size %[[BINDING0]] +// CHECK: %[[BINDING0_SPAN:.+]] = util.buffer.subspan %[[BINDING0]][%[[OFFSET0]]] : !util.buffer{%[[BINDING0_SIZE]]} -> !util.buffer{%[[LENGTH0]]} +// CHECK: %[[BINDING1_SIZE:.+]] = util.buffer.size %[[BINDING1]] +// CHECK: %[[BINDING1_SPAN:.+]] = util.buffer.subspan %[[BINDING1]][%[[OFFSET1]]] : !util.buffer{%[[BINDING1_SIZE]]} -> !util.buffer{%[[LENGTH1]]} +// CHECK: %[[BINDING2_SIZE:.+]] = util.buffer.size %[[BINDING2]] +// CHECK: %[[BINDING2_SPAN:.+]] = util.buffer.subspan %[[BINDING2]][%[[OFFSET2]]] : !util.buffer{%[[BINDING2_SIZE]]} -> !util.buffer{%[[LENGTH2]]} + +// Workgroup XYZ loop: +// CHECK: %[[SIZE_XYZ:.+]] = arith.constant 1 +// CHECK: scf.for %[[Z:.+]] = %c0 to %[[COUNT_Z]] +// CHECK: scf.for %[[Y:.+]] = %c0 to %[[COUNT_Y]] +// CHECK: scf.for %[[X:.+]] = %c0 to %[[COUNT_X]] +// CHECK: func.call @dispatch_0_0( +// CHECK-SAME: %[[LOCAL_MEMORY]], +// CHECK-SAME: %[[CONSTANT0]], %[[CONSTANT1]], +// CHECK-SAME: %[[BINDING0_SPAN]], %[[BINDING1_SPAN]], %[[BINDING2_SPAN]], +// CHECK-SAME: %[[X]], %[[Y]], %[[Z]], +// CHECK-SAME: %[[SIZE_XYZ]], %[[SIZE_XYZ]], %[[SIZE_XYZ]], +// CHECK-SAME: %[[COUNT_X]], %[[COUNT_Y]], %[[COUNT_Z]]) +// CHECK: return + +// CHECK-LABEL: @dispatch0 +// CHECK-SAME: (%[[RESOURCE0:.+]]: !stream.resource<constant>, +// CHECK-SAME: %[[RESOURCE1:.+]]: !stream.resource<transient>, +// CHECK-SAME: %[[RESOURCE2:.+]]: !stream.resource<external>) +func.func private @dispatch0(%resource0: !stream.resource<constant>, %resource1: !stream.resource<transient>, %resource2: !stream.resource<external>) { + %workload_x = arith.constant 1000 : index + %workload_y = arith.constant 1001 : index + %constant0 = arith.constant 4 : i32 + %constant1 = arith.constant 5 : i32 + %binding0_offset = arith.constant 200 : index + %binding0_length = arith.constant 128 : index + %binding1_offset = arith.constant 300 : index + %binding1_length = arith.constant 256 : index + %binding2_offset = arith.constant 400 : index + %binding2_length = arith.constant 512 : index + %resource_size = arith.constant 1024 : index + %0 = stream.cmd.execute with(%resource0 as %resource0_inner: !stream.resource<constant>{%resource_size}, + %resource1 as %resource1_inner: !stream.resource<transient>{%resource_size}, + %resource2 as %resource2_inner: !stream.resource<external>{%resource_size}) { + // CHECK: stream.cmd.dispatch + // CHECK: hal_inline.target = @__dispatch_ex_dispatch_0 + stream.cmd.dispatch @ex::@dispatch_0[%workload_x, %workload_y](%constant0, %constant1 : i32, i32) { + ro %resource0_inner[%binding0_offset for %binding0_length] : !stream.resource<constant>{%resource_size}, + ro %resource1_inner[%binding1_offset for %binding1_length] : !stream.resource<transient>{%resource_size}, + wo %resource2_inner[%binding2_offset for %binding2_length] : !stream.resource<external>{%resource_size} + } attributes { + hal.interface.bindings = [ + #hal.interface.binding<0, 0>, + #hal.interface.binding<0, 1>, + #hal.interface.binding<0, 2> + ] + } + } => !stream.timepoint + return +}
diff --git a/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/hal_inline.imports.mlir b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/hal_inline.imports.mlir new file mode 100644 index 0000000..e9ee06d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/hal_inline.imports.mlir
@@ -0,0 +1,131 @@ +// IREE Inline Hardware Abstraction Layer (HAL) runtime module imports. +// This is only used to provide ABI compatibility with the full HAL module and +// user programs that use !hal.buffer/!hal.buffer_view as IO. +// +// This is embedded in the compiler binary and inserted into any module +// containing inline HAL dialect ops (hal_inline.*) that is lowered to the VM +// dialect. +vm.module @hal_inline { + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_t +//===----------------------------------------------------------------------===// + +// Allocates an empty buffer. +vm.import @buffer.allocate( + %minimum_alignment : i32, + %allocation_size : i64 +) -> (!vm.ref<!hal.buffer>, !vm.buffer) +attributes {nosideeffects} + +// Allocates a buffer with an initial value provided by a VM byte buffer. +vm.import @buffer.allocate.initialized( + %minimum_alignment : i32, + %source : !vm.buffer, + %offset : i64, + %length : i64 +) -> (!vm.ref<!hal.buffer>, !vm.buffer) +attributes {nosideeffects} + +// Wraps a VM byte buffer in a HAL buffer. +vm.import @buffer.wrap( + %source : !vm.buffer, + %offset : i64, + %length : i64 +) -> !vm.ref<!hal.buffer> +attributes {nosideeffects} + +// Returns a reference to a subspan of the buffer. +vm.import @buffer.subspan( + %source_buffer : !vm.ref<!hal.buffer>, + %source_offset : i64, + %length : i64 +) -> !vm.ref<!hal.buffer> +attributes {nosideeffects} + +// TODO(benvanik): make storage return length and remove dedicated length? + +// Returns the byte length of the buffer (may be less than total allocation). +vm.import @buffer.length( + %buffer : !vm.ref<!hal.buffer> +) -> i64 +attributes {nosideeffects} + +// Returns a mapping to the underlying storage of the buffer sliced to the +// logical subspan of the HAL buffer. +vm.import @buffer.storage( + %buffer : !vm.ref<!hal.buffer> +) -> !vm.buffer +attributes {nosideeffects} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_t +//===----------------------------------------------------------------------===// + +// Creates a reference to a buffer with a particular shape and element type. +vm.import @buffer_view.create( + %buffer : !vm.ref<!hal.buffer>, + %element_type : i32, + %encoding_type : i32, + %shape : i64 ... +) -> !vm.ref<!hal.buffer_view> +attributes {nosideeffects} + +// Asserts a buffer view matches the given tensor encoding and shape. +vm.import @buffer_view.assert( + %buffer_view : !vm.ref<!hal.buffer_view>, + %message : !vm.buffer, + %element_type : i32, + %encoding_type : i32, + %shape : i64 ... +) + +// Returns the backing buffer of the buffer view. +vm.import @buffer_view.buffer( + %buffer_view : !vm.ref<!hal.buffer_view> +) -> !vm.ref<!hal.buffer> +attributes {nosideeffects} + +// Returns the element type of the buffer view. +vm.import @buffer_view.element_type( + %buffer_view : !vm.ref<!hal.buffer_view>, +) -> i32 +attributes {nosideeffects} + +// Returns the encoding type of the buffer view. +vm.import @buffer_view.encoding_type( + %buffer_view : !vm.ref<!hal.buffer_view>, +) -> i32 +attributes {nosideeffects} + +// Returns the rank of the buffer view. +vm.import @buffer_view.rank( + %buffer_view : !vm.ref<!hal.buffer_view>, +) -> i32 +attributes {nosideeffects} + +// Returns the value of the given dimension. +vm.import @buffer_view.dim( + %buffer_view : !vm.ref<!hal.buffer_view>, + %index : i32 +) -> i64 +attributes {nosideeffects} + +// Prints out the content of buffer views. +vm.import @buffer_view.trace( + %key : !vm.buffer, + %operands : !vm.ref<!hal.buffer_view> ... +) + +//===----------------------------------------------------------------------===// +// iree_hal_device_t +//===----------------------------------------------------------------------===// + +// Returns a tuple of (ok, value) for the given configuration key. +vm.import @device.query.i64( + %category : !vm.buffer, + %key : !vm.buffer +) -> (i32, i64) +attributes {nosideeffects} + +} // module
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp index 7ac2522..11ccd4b 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
@@ -176,10 +176,13 @@ auto constantsArg = op->getParentOfType<mlir::func::FuncOp>().getArgument( kEntryArgConstants); assert(constantsArg && "entry point not conforming to requirements"); + // HACK: we could find the total push constant count and avoid this size op + // but it'd require walking all the way up to the hal.executable export. auto constantsSize = rewriter.create<IREE::Util::BufferSizeOp>(op.getLoc(), constantsArg); auto resultType = getTypeConverter()->convertType(op.getResult().getType()); + // Index -> byte offset. auto constantIndex = rewriter.createOrFold<arith::ConstantIndexOp>( op.getLoc(), op.getIndex().getZExtValue()); auto elementSize = @@ -254,7 +257,7 @@ .getResult(); if (op.getByteOffset() && !matchPattern(op.getByteOffset(), m_Zero())) { - // Offsetted binding: replace with a BufferSpan. + // Offsetted binding: replace with a BufferSubspanOp. Value sourceSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>( op.getLoc(), sourceBuffer);
diff --git a/compiler/src/iree/compiler/Pipelines/BUILD b/compiler/src/iree/compiler/Pipelines/BUILD index da71ce4..48030f7 100644 --- a/compiler/src/iree/compiler/Pipelines/BUILD +++ b/compiler/src/iree/compiler/Pipelines/BUILD
@@ -39,6 +39,7 @@ "//compiler/src/iree/compiler/Dialect/Flow/Transforms", "//compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM", "//compiler/src/iree/compiler/Dialect/HAL/Transforms", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms", "//compiler/src/iree/compiler/Dialect/Stream/Transforms", "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Dialect/VM/Conversion",
diff --git a/compiler/src/iree/compiler/Pipelines/CMakeLists.txt b/compiler/src/iree/compiler/Pipelines/CMakeLists.txt index f91ccca..824566f 100644 --- a/compiler/src/iree/compiler/Pipelines/CMakeLists.txt +++ b/compiler/src/iree/compiler/Pipelines/CMakeLists.txt
@@ -48,6 +48,7 @@ iree::compiler::Dialect::Flow::Transforms iree::compiler::Dialect::HAL::Conversion::HALToVM iree::compiler::Dialect::HAL::Transforms + iree::compiler::Dialect::Modules::HAL::Inline::Transforms iree::compiler::Dialect::Stream::Transforms iree::compiler::Dialect::Util::Transforms iree::compiler::Dialect::VM::Conversion
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index a7825ff..1c68eb2 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -97,7 +97,10 @@ "internally but exporting functions as if synchronous."), clEnumValN(ExecutionModel::AsyncExternal, "async-external", "Full HAL using asynchronous host/device execution both " - "internally and externally.")), + "internally and externally."), + clEnumValN(ExecutionModel::InlineStatic, "inline-static", + "Inline host-local in-process execution with executable " + "code statically linked into the host program.")), llvm::cl::cat(category)); binder.opt<DumpOutputFormat>(
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index e048c51..405ee84 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -88,6 +88,10 @@ // Full HAL using asynchronous host/device execution both internally and // externally. AsyncExternal = 2, + // Inline host-local in-process execution with executable code statically + // linked into the host program. + // (Currently) only supports the `vmvx-inline` HAL target backend. + InlineStatic = 3, }; // Program execution model specifying scheduling behavior. ExecutionModel executionModel = ExecutionModel::AsyncInternal;
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index 6bea87a..11e3c0e 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
@@ -10,6 +10,7 @@ #include "iree/compiler/Bindings/TFLite/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h" #include "iree/compiler/Dialect/Stream/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Dialect/VM/Transforms/Passes.h" @@ -127,6 +128,10 @@ case SchedulingOptions::ExecutionModel::AsyncExternal: IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions); break; + case SchedulingOptions::ExecutionModel::InlineStatic: + IREE::HAL::Inline::buildHALInlineStaticTransformPassPipeline( + passManager, executableOptions); + break; } IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
diff --git a/compiler/src/iree/compiler/Tools/BUILD b/compiler/src/iree/compiler/Tools/BUILD index 0ff9e15..2bff338 100644 --- a/compiler/src/iree/compiler/Tools/BUILD +++ b/compiler/src/iree/compiler/Tools/BUILD
@@ -50,6 +50,8 @@ "//compiler/src/iree/compiler/Dialect/Flow/Transforms", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/HAL/Transforms", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/IR:HALInlineDialect", + "//compiler/src/iree/compiler/Dialect/Modules/HAL/Inline/Transforms", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Stream/Transforms", "//compiler/src/iree/compiler/Dialect/Util/IR",
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index 531ddaa..43da3e5 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -93,8 +93,8 @@ iree::compiler::Dialect::Flow::Transforms iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::HAL::Transforms - iree::compiler::Dialect::VMVX::IR::VMVXDialect - iree::compiler::Dialect::VMVX::Transforms + iree::compiler::Dialect::Modules::HAL::Inline::IR::HALInlineDialect + iree::compiler::Dialect::Modules::HAL::Inline::Transforms iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Stream::Transforms iree::compiler::Dialect::Util::IR @@ -104,6 +104,8 @@ iree::compiler::Dialect::VM::IR iree::compiler::Dialect::VM::Target::init_targets iree::compiler::Dialect::VM::Transforms + iree::compiler::Dialect::VMVX::IR::VMVXDialect + iree::compiler::Dialect::VMVX::Transforms iree::compiler::Dialect::Vulkan::IR iree::compiler::ConstEval iree::compiler::Pipelines
diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h index b2c1308..5cfa88e 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
@@ -21,6 +21,7 @@ #include "iree/compiler/Codegen/Interfaces/Interfaces.h" #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/IR/HALInlineDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilExternalModels.h" @@ -38,6 +39,7 @@ registry.insert<IREE::Codegen::IREECodegenDialect, IREE::Flow::FlowDialect, IREE::HAL::HALDialect, + IREE::HAL::Inline::HALInlineDialect, IREE::LinalgExt::IREELinalgExtDialect, mlir::linalg::transform::LinalgTransformDialect, IREE::Stream::StreamDialect,
diff --git a/compiler/src/iree/compiler/Tools/init_iree_passes.h b/compiler/src/iree/compiler/Tools/init_iree_passes.h index 8192cac..a01be67 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_passes.h +++ b/compiler/src/iree/compiler/Tools/init_iree_passes.h
@@ -20,6 +20,7 @@ #include "iree/compiler/ConstEval/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "iree/compiler/Dialect/Modules/HAL/Inline/Transforms/Passes.h" #include "iree/compiler/Dialect/Stream/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Dialect/VM/Analysis/TestPasses.h" @@ -46,6 +47,7 @@ ConstEval::registerConstEvalPasses(); IREE::Flow::registerFlowPasses(); IREE::HAL::registerHALPasses(); + IREE::HAL::Inline::registerHALInlinePasses(); IREE::LinalgExt::registerPasses(); IREE::Stream::registerStreamPasses(); IREE::Util::registerTransformPasses();
diff --git a/runtime/src/iree/modules/hal/inline/BUILD b/runtime/src/iree/modules/hal/inline/BUILD new file mode 100644 index 0000000..9b0994e --- /dev/null +++ b/runtime/src/iree/modules/hal/inline/BUILD
@@ -0,0 +1,34 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_runtime_cc_library( + name = "inline", + srcs = [ + "module.c", + ], + hdrs = [ + "module.h", + ], + textual_hdrs = [ + "exports.inl", + ], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/base:tracing", + "//runtime/src/iree/hal", + "//runtime/src/iree/modules/hal:types", + "//runtime/src/iree/modules/hal/utils:buffer_diagnostics", + "//runtime/src/iree/vm", + ], +)
diff --git a/runtime/src/iree/modules/hal/inline/CMakeLists.txt b/runtime/src/iree/modules/hal/inline/CMakeLists.txt new file mode 100644 index 0000000..64022f8 --- /dev/null +++ b/runtime/src/iree/modules/hal/inline/CMakeLists.txt
@@ -0,0 +1,32 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# runtime/src/iree/modules/hal/inline/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + inline + HDRS + "module.h" + TEXTUAL_HDRS + "exports.inl" + SRCS + "module.c" + DEPS + iree::base + iree::base::tracing + iree::hal + iree::modules::hal::types + iree::modules::hal::utils::buffer_diagnostics + iree::vm + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/modules/hal/inline/exports.inl b/runtime/src/iree/modules/hal/inline/exports.inl new file mode 100644 index 0000000..40f80ce --- /dev/null +++ b/runtime/src/iree/modules/hal/inline/exports.inl
@@ -0,0 +1,45 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// ██ ██ █████ ██████ ███ ██ ██ ███ ██ ██████ +// ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██ +// ██ █ ██ ███████ ██████ ██ ██ ██ ██ ██ ██ ██ ██ ███ +// ██ ███ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ +// ███ ███ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████ +// +//===----------------------------------------------------------------------===// +// +// This file will be auto generated from hal_inline.imports.mlir in the future; +// for now it's modified by hand but with strict alphabetical sorting required. +// The order of these functions must be sorted ascending by name in a way +// compatible with iree_string_view_compare. +// +// Users are meant to `#define EXPORT_FN` to be able to access the information. +// #define EXPORT_FN(name, target_fn, arg_type, ret_type) + +// clang-format off + +EXPORT_FN("buffer.allocate", iree_hal_inline_module_buffer_allocate, iI, rr) +EXPORT_FN("buffer.allocate.initialized", iree_hal_inline_module_buffer_allocate_initialized, irII, rr) +EXPORT_FN("buffer.length", iree_hal_inline_module_buffer_length, r, I) +EXPORT_FN("buffer.storage", iree_hal_inline_module_buffer_storage, r, r) +EXPORT_FN("buffer.subspan", iree_hal_inline_module_buffer_subspan, rII, r) +EXPORT_FN("buffer.wrap", iree_hal_inline_module_buffer_wrap, rII, r) + +EXPORT_FN("buffer_view.assert", iree_hal_inline_module_buffer_view_assert, rriiCID, v) +EXPORT_FN("buffer_view.buffer", iree_hal_inline_module_buffer_view_buffer, r, r) +EXPORT_FN("buffer_view.create", iree_hal_inline_module_buffer_view_create, riiCID, r) +EXPORT_FN("buffer_view.dim", iree_hal_inline_module_buffer_view_dim, ri, I) +EXPORT_FN("buffer_view.element_type", iree_hal_inline_module_buffer_view_element_type, r, i) +EXPORT_FN("buffer_view.encoding_type", iree_hal_inline_module_buffer_view_encoding_type, r, i) +EXPORT_FN("buffer_view.rank", iree_hal_inline_module_buffer_view_rank, r, i) +EXPORT_FN("buffer_view.trace", iree_hal_inline_module_buffer_view_trace, rCrD, v) + +EXPORT_FN("device.query.i64", iree_hal_inline_module_device_query_i64, rr, iI) + +// clang-format on
diff --git a/runtime/src/iree/modules/hal/inline/module.c b/runtime/src/iree/modules/hal/inline/module.c new file mode 100644 index 0000000..65b460c --- /dev/null +++ b/runtime/src/iree/modules/hal/inline/module.c
@@ -0,0 +1,609 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/modules/hal/inline/module.h" + +#include "iree/base/api.h" +#include "iree/base/tracing.h" +#include "iree/hal/api.h" +#include "iree/modules/hal/utils/buffer_diagnostics.h" +#include "iree/vm/api.h" + +#define IREE_HAL_INLINE_MODULE_VERSION_0_0 0x00000000u +#define IREE_HAL_INLINE_MODULE_VERSION_LATEST IREE_HAL_INLINE_MODULE_VERSION_0_0 + +//===----------------------------------------------------------------------===// +// iree_hal_inline_storage_buffer_t +//===----------------------------------------------------------------------===// + +// Inlined VM buffer using a HAL buffer for storage. +// This uses the reference counting of the embedded VM buffer +// to track lifetime combined with a custom allocator to handle +// cleaning up this wrapper when the VM buffer is no longer referenced. +// +// Since the HAL buffer is providing the storage and the VM buffer is just +// pointing into it the critical thing this wrapper does is ensure the HAL +// buffer always outlives the VM buffer. +// +// NOTE: this is allocated each storage query! The assumption is that the +// returned buffer is long-lived (at least per-invocation). This is primarily +// used to get the backing storage of a !hal.buffer that a user passes into an +// invocation and the compiler should CSE such queries. Since users can provide +// their own allocators they can decide if they want to pool small allocations +// to bypass the system allocator. If we wanted to in here we could have a small +// free list we maintained for this purpose at the cost of fixed memory +// consumption. Note that the key requirement is that the returned VM buffer +// may outlive the module so we can't use an arena that has module lifetime. +typedef struct iree_hal_inline_storage_buffer_t { + // Allocator used to allocate this storage buffer. + iree_allocator_t host_allocator; + // HAL buffer backing this storage buffer. + // Retained for the lifetime of this instance so that the + // wrapped vm_buffer is always valid. + iree_hal_buffer_t* hal_buffer; + // Scoped mapping into the buffer. We could make it persistent but because + // we can trivially scope things having this extra information is cheap and + // useful for debugging. + iree_hal_buffer_mapping_t mapping; + // Inline initialized VM buffer wrapping the hal_buffer storage. + // This directly references the memory of the HAL buffer. + // The buffer has a custom allocator that calls back into this + // struct to deallocate the wrapper. + iree_vm_buffer_t vm_buffer; +} iree_hal_inline_storage_buffer_t; + +static void iree_hal_inline_storage_buffer_destroy( + iree_hal_inline_storage_buffer_t* storage); + +static iree_status_t iree_hal_inline_storage_buffer_ctl( + void* self, iree_allocator_command_t command, const void* params, + void** inout_ptr) { + if (command != IREE_ALLOCATOR_COMMAND_FREE) { + return iree_make_status( + IREE_STATUS_FAILED_PRECONDITION, + "allocator can only be used for dropping the wrapper buffer"); + } + iree_hal_inline_storage_buffer_t* storage = + (iree_hal_inline_storage_buffer_t*)self; + iree_hal_inline_storage_buffer_destroy(storage); + return iree_ok_status(); +} + +// Creates a VM buffer wrapper that directly references HAL buffer storage. +// The returned |out_vm_buffer| lifetime will extend the HAL buffer lifetime. +static iree_status_t iree_hal_inline_storage_buffer_create( + iree_hal_buffer_t* hal_buffer, iree_allocator_t host_allocator, + iree_vm_buffer_t** out_vm_buffer) { + IREE_ASSERT_ARGUMENT(hal_buffer); + IREE_ASSERT_ARGUMENT(out_vm_buffer); + *out_vm_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + // Allocate zero-initialized storage wrapper. + iree_hal_inline_storage_buffer_t* storage = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*storage), + (void**)&storage)); + + // Map the HAL buffer into host-accessible memory. It almost always is but + // it's possible the buffer we were passed was allocated on a real device that + // requires mapping. + iree_status_t status = iree_hal_buffer_map_range( + hal_buffer, IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_ANY, 0, + IREE_WHOLE_BUFFER, &storage->mapping); + + // Initializes the VM buffer to reference the mapped memory. + // Since the VM buffer is what we pass back to the VM and gets reference + // counted we pass a custom allocator that lets us know when the VM (or + // user) is done with it. + if (iree_status_is_ok(status)) { + iree_allocator_t self_allocator = { + .self = storage, + .ctl = iree_hal_inline_storage_buffer_ctl, + }; + iree_vm_buffer_initialize( + IREE_VM_BUFFER_ACCESS_ORIGIN_HOST | IREE_VM_BUFFER_ACCESS_MUTABLE, + storage->mapping.contents, self_allocator, &storage->vm_buffer); + } + + if (iree_status_is_ok(status)) { + *out_vm_buffer = &storage->vm_buffer; + } else { + iree_hal_inline_storage_buffer_destroy(storage); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_inline_storage_buffer_destroy( + iree_hal_inline_storage_buffer_t* storage) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_allocator_t host_allocator = storage->host_allocator; + iree_hal_buffer_unmap_range(&storage->mapping); + iree_hal_buffer_release(storage->hal_buffer); + iree_allocator_free(host_allocator, storage); + IREE_TRACE_ZONE_END(z0); +} + +//===----------------------------------------------------------------------===// +// Module type definitions +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_inline_module_t { + iree_allocator_t host_allocator; + iree_hal_allocator_t* device_allocator; + iree_hal_inline_module_flags_t flags; + // TODO(benvanik): types. +} iree_hal_inline_module_t; + +#define IREE_HAL_INLINE_MODULE_CAST(module) \ + (iree_hal_inline_module_t*)((uint8_t*)(module) + \ + iree_vm_native_module_size()); + +typedef struct iree_hal_inline_module_state_t { + iree_allocator_t host_allocator; + iree_hal_allocator_t* device_allocator; + iree_hal_inline_module_flags_t flags; +} iree_hal_inline_module_state_t; + +static void IREE_API_PTR iree_hal_inline_module_destroy(void* base_module) { + iree_hal_inline_module_t* module = IREE_HAL_INLINE_MODULE_CAST(base_module); + iree_hal_allocator_release(module->device_allocator); + module->device_allocator = NULL; +} + +static iree_status_t IREE_API_PTR +iree_hal_inline_module_alloc_state(void* self, iree_allocator_t host_allocator, + iree_vm_module_state_t** out_module_state) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_inline_module_t* module = IREE_HAL_INLINE_MODULE_CAST(self); + iree_hal_inline_module_state_t* state = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, sizeof(*state), (void**)&state)); + memset(state, 0, sizeof(*state)); + state->host_allocator = host_allocator; + state->device_allocator = module->device_allocator; + iree_hal_allocator_retain(state->device_allocator); + state->flags = module->flags; + + *out_module_state = (iree_vm_module_state_t*)state; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void IREE_API_PTR iree_hal_inline_module_free_state( + void* self, iree_vm_module_state_t* module_state) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_inline_module_state_t* state = + (iree_hal_inline_module_state_t*)module_state; + iree_hal_allocator_release(state->device_allocator); + state->device_allocator = NULL; + iree_allocator_free(state->host_allocator, state); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t IREE_API_PTR iree_hal_inline_module_notify( + void* self, iree_vm_module_state_t* module_state, iree_vm_signal_t signal) { + switch (signal) { + case IREE_VM_SIGNAL_SUSPEND: + case IREE_VM_SIGNAL_LOW_MEMORY: + default: + return iree_ok_status(); + } +} + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +// Casts a VM value to a C host size. +static iree_host_size_t iree_hal_cast_host_size(int64_t value) { + // TODO(benvanik): make this return status and check for overflow if host + // size is 32-bits. + return (iree_host_size_t)value; +} + +// Casts a VM value to a HAL device size. +static iree_device_size_t iree_hal_cast_device_size(int64_t value) { + // TODO(benvanik): make this return status and check for overflow if device + // size is 32-bits. + return (iree_device_size_t)value; +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_t +//===----------------------------------------------------------------------===// + +static iree_status_t iree_hal_inline_module_buffer_allocate_with_storage( + iree_hal_allocator_t* device_allocator, iree_hal_buffer_params_t params, + iree_device_size_t allocation_size, iree_const_byte_span_t initial_data, + iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer, + iree_vm_buffer_t** out_storage) { + // We could optimize this to create both at the same time and avoid the extra + // storage allocation by having a custom iree_hal_buffer_t type or a way to + // allocate additional data in the iree_hal_buffer_params_t that we stashed + // the storage in. Today this is all intentionally simple and something we can + // change in the runtime without impacting the compiler/artifacts. + + // Try allocating the buffer first + iree_hal_buffer_t* buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( + device_allocator, params, allocation_size, initial_data, &buffer)); + + // Map and retain the HAL buffer and return a VM buffer that is usable as if + // it were a native iree_vm_buffer_t. + iree_vm_buffer_t* storage = NULL; + iree_status_t status = + iree_hal_inline_storage_buffer_create(buffer, host_allocator, &storage); + if (!iree_status_is_ok(status)) { + iree_hal_buffer_release(buffer); + return status; + } + + *out_buffer = buffer; + *out_storage = storage; + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_allocate, // + iree_hal_inline_module_state_t, // + iI, rr) { + iree_device_size_t minimum_alignment = iree_hal_cast_device_size(args->i0); + iree_device_size_t allocation_size = iree_hal_cast_device_size(args->i1); + + const iree_hal_buffer_params_t params = { + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | + IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST, + .min_alignment = minimum_alignment, + }; + iree_hal_buffer_t* buffer = NULL; + iree_vm_buffer_t* storage = NULL; + IREE_RETURN_IF_ERROR(iree_hal_inline_module_buffer_allocate_with_storage( + state->device_allocator, params, allocation_size, + iree_const_byte_span_empty(), state->host_allocator, &buffer, &storage)); + + rets->r0 = iree_hal_buffer_move_ref(buffer); + rets->r1 = iree_vm_buffer_move_ref(storage); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_allocate_initialized, // + iree_hal_inline_module_state_t, // + irII, rr) { + iree_device_size_t minimum_alignment = iree_hal_cast_device_size(args->i0); + iree_vm_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &source_buffer)); + iree_device_size_t source_offset = iree_hal_cast_device_size(args->i2); + iree_device_size_t source_length = iree_hal_cast_device_size(args->i3); + + iree_const_byte_span_t initial_data = iree_const_byte_span_empty(); + IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(source_buffer, source_offset, + source_length, 1, &initial_data)); + + const iree_hal_buffer_params_t params = { + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | + IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST, + .min_alignment = minimum_alignment, + }; + iree_hal_buffer_t* buffer = NULL; + iree_vm_buffer_t* storage = NULL; + IREE_RETURN_IF_ERROR(iree_hal_inline_module_buffer_allocate_with_storage( + state->device_allocator, params, source_length, initial_data, + state->host_allocator, &buffer, &storage)); + + rets->r0 = iree_hal_buffer_move_ref(buffer); + rets->r1 = iree_vm_buffer_move_ref(storage); + return iree_ok_status(); + + return iree_make_status(IREE_STATUS_UNIMPLEMENTED); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_wrap, // + iree_hal_inline_module_state_t, // + rII, r) { + iree_vm_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r0, &source_buffer)); + iree_device_size_t source_offset = iree_hal_cast_device_size(args->i1); + iree_device_size_t source_length = iree_hal_cast_device_size(args->i2); + + // TODO(benvanik): implement buffer wrapping. + // We don't emit this on the compiler today but could if we wanted to return + // constants/variables from the program without copies. + // + // We could do this by having a custom iree_hal_buffer_t type that retains + // the vm buffer, like `iree_hal_external_vm_buffer_t`. + // We may then want to expose this wrap method on the public module API so + // that users can pass in buffers like this. + // + // hal_inline.buffer.storage would need to switch based on type and return + // the underlying wrapped vm.buffer. + (void)source_buffer; + (void)source_offset; + (void)source_length; + + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "vm->hal buffer wrapping not yet implemented"); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_subspan, // + iree_hal_inline_module_state_t, // + rII, r) { + iree_hal_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer)); + iree_device_size_t source_offset = iree_hal_cast_device_size(args->i1); + iree_device_size_t length = iree_hal_cast_device_size(args->i2); + + iree_hal_buffer_t* subspan_buffer = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_subspan(source_buffer, source_offset, length, + &subspan_buffer), + "invalid subspan of an existing buffer (source_offset=%" PRIdsz + ", length=%" PRIdsz ")", + source_offset, length); + + rets->r0 = iree_hal_buffer_move_ref(subspan_buffer); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_length, // + iree_hal_inline_module_state_t, // + r, I) { + iree_hal_buffer_t* buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &buffer)); + rets->i0 = (int64_t)iree_hal_buffer_byte_length(buffer); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_storage, // + iree_hal_inline_module_state_t, // + r, r) { + iree_hal_buffer_t* hal_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &hal_buffer)); + + // Map and retain the HAL buffer and return a VM buffer that is usable as if + // it were a native iree_vm_buffer_t. + iree_vm_buffer_t* vm_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_inline_storage_buffer_create( + hal_buffer, state->host_allocator, &vm_buffer)); + + rets->r0 = iree_vm_buffer_move_ref(vm_buffer); + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_t +//===----------------------------------------------------------------------===// + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_create, // + iree_hal_inline_module_state_t, // + riiCID, r) { + iree_hal_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer)); + iree_hal_element_type_t element_type = (iree_hal_element_type_t)args->i1; + iree_hal_encoding_type_t encoding_type = (iree_hal_encoding_type_t)args->i2; + iree_host_size_t shape_rank = 0; + iree_hal_dim_t* shape_dims = NULL; + // TODO(benvanik): avoid the cast/alloca if not required. + IREE_VM_ABI_VLA_STACK_CAST(args, a3_count, a3, iree_hal_dim_t, 128, + &shape_rank, &shape_dims); + iree_hal_buffer_view_t* buffer_view = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create( + source_buffer, shape_rank, shape_dims, element_type, encoding_type, + state->host_allocator, &buffer_view)); + rets->r0 = iree_hal_buffer_view_move_ref(buffer_view); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_assert, // + iree_hal_inline_module_state_t, // + rriiCID, v) { + iree_host_size_t expected_shape_rank = 0; + iree_hal_dim_t* expected_shape_dims = NULL; + // TODO(benvanik): avoid the cast/alloca if not required. + IREE_VM_ABI_VLA_STACK_CAST(args, a4_count, a4, iree_hal_dim_t, 128, + &expected_shape_rank, &expected_shape_dims); + return iree_hal_modules_buffer_view_assert( + args->r0, args->r1, (iree_hal_element_type_t)args->i2, + (iree_hal_encoding_type_t)args->i3, expected_shape_rank, + expected_shape_dims); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_buffer, // + iree_hal_inline_module_state_t, // + r, r) { + iree_hal_buffer_view_t* buffer_view = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_check_deref(args->r0, &buffer_view)); + rets->r0 = + iree_hal_buffer_retain_ref(iree_hal_buffer_view_buffer(buffer_view)); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_element_type, // + iree_hal_inline_module_state_t, // + r, i) { + iree_hal_buffer_view_t* buffer_view = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_check_deref(args->r0, &buffer_view)); + rets->i0 = (uint32_t)iree_hal_buffer_view_element_type(buffer_view); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_encoding_type, // + iree_hal_inline_module_state_t, // + r, i) { + iree_hal_buffer_view_t* buffer_view = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_check_deref(args->r0, &buffer_view)); + rets->i0 = (uint32_t)iree_hal_buffer_view_encoding_type(buffer_view); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_rank, // + iree_hal_inline_module_state_t, // + r, i) { + iree_hal_buffer_view_t* buffer_view = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_check_deref(args->r0, &buffer_view)); + rets->i0 = (iree_vm_size_t)iree_hal_buffer_view_shape_rank(buffer_view); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_dim, // + iree_hal_inline_module_state_t, // + ri, I) { + iree_hal_buffer_view_t* buffer_view = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_check_deref(args->r0, &buffer_view)); + iree_vm_size_t index = (iree_vm_size_t)args->i1; + rets->i0 = (int64_t)iree_hal_buffer_view_shape_dim(buffer_view, index); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_trace, // + iree_hal_inline_module_state_t, // + rCrD, v) { + return iree_hal_modules_buffer_view_trace(args->r0, args->a1_count, args->a1, + state->host_allocator); +} + +//===----------------------------------------------------------------------===// +// iree_hal_device_t +//===----------------------------------------------------------------------===// + +IREE_VM_ABI_EXPORT(iree_hal_inline_module_device_query_i64, // + iree_hal_inline_module_state_t, // + rr, iI) { + iree_vm_buffer_t* category = NULL; + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r0, &category)); + iree_string_view_t category_str = iree_vm_buffer_as_string(category); + iree_vm_buffer_t* key = NULL; + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &key)); + iree_string_view_t key_str = iree_vm_buffer_as_string(key); + + // TODO(benvanik): allow injection of a query function on the module. This + // would let us extend the queryable configuration with either synthetic + // properties or user-provided ones. For now we could at least provide + // compile-time configuration (like hosting architecture) but nothing dynamic + // (like cache sizes). + // The full HAL asks iree_hal_device_t but we don't have that here: + // iree_hal_device_query_i64(device, category_str, key_str, &value); + (void)category_str; + (void)key_str; + + int64_t value = 0; + iree_status_t query_status = iree_status_from_code(IREE_STATUS_NOT_FOUND); + rets->i0 = iree_status_consume_code(query_status) == IREE_STATUS_OK ? 1 : 0; + rets->i1 = value; + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// VM module interface implementation +//===----------------------------------------------------------------------===// + +// NOTE: this must match the ordering of the iree_hal_inline_module_exports_ +// table. +static const iree_vm_native_function_ptr_t iree_hal_inline_module_funcs_[] = { +#define EXPORT_FN(name, target_fn, arg_types, ret_types) \ + { \ + .shim = (iree_vm_native_function_shim_t) \ + iree_vm_shim_##arg_types##_##ret_types, \ + .target = (iree_vm_native_function_target_t)(target_fn), \ + }, +#include "iree/modules/hal/inline/exports.inl" // IWYU pragma: keep +#undef EXPORT_FN +}; + +// NOTE: 0 length, but can't express that in C. +static const iree_vm_native_import_descriptor_t + iree_hal_inline_module_imports_[1]; + +static const iree_vm_native_export_descriptor_t + iree_hal_inline_module_exports_[] = { +#define EXPORT_FN(name, target_fn, arg_types, ret_types) \ + { \ + .local_name = iree_string_view_literal(name), \ + .calling_convention = \ + iree_string_view_literal("0" #arg_types "_" #ret_types), \ + .attr_count = 0, \ + .attrs = NULL, \ + }, +#include "iree/modules/hal/inline/exports.inl" // IWYU pragma: keep +#undef EXPORT_FN +}; +static_assert(IREE_ARRAYSIZE(iree_hal_inline_module_funcs_) == + IREE_ARRAYSIZE(iree_hal_inline_module_exports_), + "function pointer table must be 1:1 with exports"); + +static const iree_vm_native_module_descriptor_t + iree_hal_inline_module_descriptor_ = { + .name = iree_string_view_literal("hal_inline"), + .version = IREE_HAL_INLINE_MODULE_VERSION_LATEST, + .attr_count = 0, + .attrs = NULL, + .dependency_count = 0, + .dependencies = NULL, + .import_count = 0, // workaround for 0-length C struct + .imports = iree_hal_inline_module_imports_, + .export_count = IREE_ARRAYSIZE(iree_hal_inline_module_exports_), + .exports = iree_hal_inline_module_exports_, + .function_count = IREE_ARRAYSIZE(iree_hal_inline_module_funcs_), + .functions = iree_hal_inline_module_funcs_, +}; + +IREE_API_EXPORT iree_status_t iree_hal_inline_module_create( + iree_vm_instance_t* instance, iree_hal_inline_module_flags_t flags, + iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator, + iree_vm_module_t** out_module) { + IREE_ASSERT_ARGUMENT(instance); + IREE_ASSERT_ARGUMENT(device_allocator); + IREE_ASSERT_ARGUMENT(out_module); + *out_module = NULL; + + // Setup the interface with the functions we implement ourselves. Any function + // we omit will be handled by the base native module. + static const iree_vm_module_t interface = { + .destroy = iree_hal_inline_module_destroy, + .alloc_state = iree_hal_inline_module_alloc_state, + .free_state = iree_hal_inline_module_free_state, + .notify = iree_hal_inline_module_notify, + }; + + // Allocate shared module state. + iree_host_size_t total_size = + iree_vm_native_module_size() + sizeof(iree_hal_inline_module_t); + iree_vm_module_t* base_module = NULL; + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(host_allocator, total_size, (void**)&base_module)); + memset(base_module, 0, total_size); + iree_status_t status = iree_vm_native_module_initialize( + &interface, &iree_hal_inline_module_descriptor_, instance, host_allocator, + base_module); + if (!iree_status_is_ok(status)) { + iree_allocator_free(host_allocator, base_module); + return status; + } + + iree_hal_inline_module_t* module = IREE_HAL_INLINE_MODULE_CAST(base_module); + module->host_allocator = host_allocator; + module->device_allocator = device_allocator; + iree_hal_allocator_retain(module->device_allocator); + module->flags = flags; + + *out_module = base_module; + return iree_ok_status(); +}
diff --git a/runtime/src/iree/modules/hal/inline/module.h b/runtime/src/iree/modules/hal/inline/module.h new file mode 100644 index 0000000..f8e881d --- /dev/null +++ b/runtime/src/iree/modules/hal/inline/module.h
@@ -0,0 +1,39 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_MODULES_HAL_INLINE_MODULE_H_ +#define IREE_MODULES_HAL_INLINE_MODULE_H_ + +#include <stdint.h> + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/modules/hal/types.h" +#include "iree/vm/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +enum iree_hal_inline_module_flag_bits_t { + IREE_HAL_INLINE_MODULE_FLAG_NONE = 0u, +}; +typedef uint32_t iree_hal_inline_module_flags_t; + +// Creates the inline HAL module for local execution. +// This provides ABI compatibility with the full HAL implementation in a much +// smaller footprint. The given |device_allocator| will be used for buffer +// allocations. +IREE_API_EXPORT iree_status_t iree_hal_inline_module_create( + iree_vm_instance_t* instance, iree_hal_inline_module_flags_t flags, + iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator, + iree_vm_module_t** out_module); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_MODULES_HAL_INLINE_MODULE_H_
diff --git a/runtime/src/iree/modules/vmvx/BUILD b/runtime/src/iree/modules/vmvx/BUILD index aeef1e5..05ba13a 100644 --- a/runtime/src/iree/modules/vmvx/BUILD +++ b/runtime/src/iree/modules/vmvx/BUILD
@@ -20,6 +20,9 @@ hdrs = [ "module.h", ], + defines = [ + "IREE_HAVE_VMVX_MODULE", + ], textual_hdrs = [ "exports.inl", ],
diff --git a/runtime/src/iree/modules/vmvx/CMakeLists.txt b/runtime/src/iree/modules/vmvx/CMakeLists.txt index 3f49441..97d5239 100644 --- a/runtime/src/iree/modules/vmvx/CMakeLists.txt +++ b/runtime/src/iree/modules/vmvx/CMakeLists.txt
@@ -18,6 +18,8 @@ "exports.inl" SRCS "module.c" + DEFINES + "IREE_HAVE_VMVX_MODULE" DEPS iree::base iree::base::tracing
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD index 4f97fa5..605c1fb 100644 --- a/runtime/src/iree/tooling/BUILD +++ b/runtime/src/iree/tooling/BUILD
@@ -24,6 +24,7 @@ "//runtime/src/iree/base/internal:flags", "//runtime/src/iree/hal", "//runtime/src/iree/modules/hal", + "//runtime/src/iree/modules/hal/inline", "//runtime/src/iree/vm", "//runtime/src/iree/vm:bytecode_module", ],
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt index e829f1e..f36aee9 100644 --- a/runtime/src/iree/tooling/CMakeLists.txt +++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -25,6 +25,7 @@ iree::base::tracing iree::hal iree::modules::hal + iree::modules::hal::inline iree::vm iree::vm::bytecode_module PUBLIC @@ -166,3 +167,11 @@ endif() ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### + +# We're co-opting the VMVX module loader option for this as the inline-static +# model is essentially just an inlined loader. +# These tooling targets are intended for iree-* tools and not end-user binaries +# where binary size or dependency constraints matter. +if(IREE_HAL_EXECUTABLE_LOADER_VMVX_MODULE) + target_link_libraries(iree_tooling_context_util INTERFACE iree_modules_vmvx_vmvx) +endif()
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c index 8883bde..85e4d7d 100644 --- a/runtime/src/iree/tooling/context_util.c +++ b/runtime/src/iree/tooling/context_util.c
@@ -13,10 +13,15 @@ #include "iree/base/internal/file_io.h" #include "iree/base/internal/flags.h" #include "iree/base/tracing.h" +#include "iree/modules/hal/inline/module.h" #include "iree/modules/hal/module.h" #include "iree/tooling/device_util.h" #include "iree/vm/bytecode_module.h" +#if defined(IREE_HAVE_VMVX_MODULE) +#include "iree/modules/vmvx/module.h" +#endif // IREE_HAVE_VMVX_MODULE + //===----------------------------------------------------------------------===// // Module loading //===----------------------------------------------------------------------===// @@ -142,6 +147,50 @@ return status; } +static iree_status_t iree_tooling_load_hal_inline_module( + iree_vm_instance_t* instance, iree_allocator_t host_allocator, + iree_vm_module_t** out_module, + iree_hal_allocator_t** out_device_allocator) { + IREE_ASSERT_ARGUMENT(instance); + IREE_ASSERT_ARGUMENT(out_module); + IREE_ASSERT_ARGUMENT(out_device_allocator); + if (*out_device_allocator) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "inline HAL module cannot be used with other " + "primary HAL module types"); + } + *out_module = NULL; + *out_device_allocator = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + // Register required types before creating the module. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_module_register_inline_types(instance)); + + // Create default heap device allocator. + iree_hal_allocator_t* device_allocator = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_create_inline_device_allocator_from_flags( + host_allocator, &device_allocator)); + + // Create the module; it's immutable and can be reused but we don't do that in + // this tooling. + iree_hal_inline_module_flags_t flags = IREE_HAL_INLINE_MODULE_FLAG_NONE; + iree_vm_module_t* module = NULL; + iree_status_t status = iree_hal_inline_module_create( + instance, flags, device_allocator, host_allocator, &module); + + if (iree_status_is_ok(status)) { + *out_module = module; + *out_device_allocator = device_allocator; + } else { + iree_hal_allocator_release(device_allocator); + iree_vm_module_release(module); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + //===----------------------------------------------------------------------===// // Module management //===----------------------------------------------------------------------===// @@ -222,6 +271,13 @@ IREE_RETURN_IF_ERROR(iree_tooling_load_hal_async_module( state->instance, state->default_device_uri, state->host_allocator, &module, &state->device, &state->device_allocator)); + } else if (iree_string_view_equal(dependency->name, IREE_SV("hal_inline"))) { + IREE_RETURN_IF_ERROR(iree_tooling_load_hal_inline_module( + state->instance, state->host_allocator, &module, + &state->device_allocator)); + } else if (iree_string_view_equal(dependency->name, IREE_SV("vmvx"))) { + IREE_RETURN_IF_ERROR(iree_vmvx_module_create( + state->instance, state->host_allocator, &module)); } else if (iree_all_bits_set(dependency->flags, IREE_VM_MODULE_DEPENDENCY_FLAG_REQUIRED)) { // Required but not found; fail.
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index 9605efa..57707d7 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c
@@ -14,6 +14,7 @@ IREE_VM_ABI_DEFINE_SHIM(r, iii); IREE_VM_ABI_DEFINE_SHIM(r, iiii); IREE_VM_ABI_DEFINE_SHIM(r, r); +IREE_VM_ABI_DEFINE_SHIM(r, rI); IREE_VM_ABI_DEFINE_SHIM(r, v); IREE_VM_ABI_DEFINE_SHIM(rCiD, i); IREE_VM_ABI_DEFINE_SHIM(rCrD, v); @@ -45,6 +46,7 @@ IREE_VM_ABI_DEFINE_SHIM(rr, r); IREE_VM_ABI_DEFINE_SHIM(rr, v); IREE_VM_ABI_DEFINE_SHIM(rr, ii); +IREE_VM_ABI_DEFINE_SHIM(rr, iI); IREE_VM_ABI_DEFINE_SHIM(rrr, iI); IREE_VM_ABI_DEFINE_SHIM(rrCirIID, r); IREE_VM_ABI_DEFINE_SHIM(rriCiD, v); @@ -55,6 +57,7 @@ IREE_VM_ABI_DEFINE_SHIM(rrirCID, v); IREE_VM_ABI_DEFINE_SHIM(rrirI, v); IREE_VM_ABI_DEFINE_SHIM(rrIrII, v); +IREE_VM_ABI_DEFINE_SHIM(rrIii, v); IREE_VM_ABI_DEFINE_SHIM(rrrIii, v); IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r); IREE_VM_ABI_DEFINE_SHIM(rIrrr, v); @@ -62,6 +65,8 @@ IREE_VM_ABI_DEFINE_SHIM(CrID, r); IREE_VM_ABI_DEFINE_SHIM(CrD, r); IREE_VM_ABI_DEFINE_SHIM(iCrD, i); +IREE_VM_ABI_DEFINE_SHIM(iI, rr); +IREE_VM_ABI_DEFINE_SHIM(irII, rr); IREE_VM_ABI_DEFINE_SHIM(v, i); IREE_VM_ABI_DEFINE_SHIM(v, r); IREE_VM_ABI_DEFINE_SHIM(v, v);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index 6e27ff9..f7f3f38 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h
@@ -346,6 +346,14 @@ int64_t i5; }); +IREE_VM_ABI_FIXED_STRUCT(rrIii, { + iree_vm_ref_t r0; + iree_vm_ref_t r1; + int64_t i2; + int32_t i3; + int32_t i4; +}); + IREE_VM_ABI_FIXED_STRUCT(rrrIii, { iree_vm_ref_t r0; iree_vm_ref_t r1; @@ -522,6 +530,7 @@ IREE_VM_ABI_DECLARE_SHIM(r, iii); IREE_VM_ABI_DECLARE_SHIM(r, iiii); IREE_VM_ABI_DECLARE_SHIM(r, r); +IREE_VM_ABI_DECLARE_SHIM(r, rI); IREE_VM_ABI_DECLARE_SHIM(r, v); IREE_VM_ABI_DECLARE_SHIM(rCiD, i); IREE_VM_ABI_DECLARE_SHIM(rCrD, v); @@ -553,6 +562,7 @@ IREE_VM_ABI_DECLARE_SHIM(rr, r); IREE_VM_ABI_DECLARE_SHIM(rr, v); IREE_VM_ABI_DECLARE_SHIM(rr, ii); +IREE_VM_ABI_DECLARE_SHIM(rr, iI); IREE_VM_ABI_DECLARE_SHIM(rrr, iI); IREE_VM_ABI_DECLARE_SHIM(rrCirIID, r); IREE_VM_ABI_DECLARE_SHIM(rriCiD, v); @@ -563,6 +573,7 @@ IREE_VM_ABI_DECLARE_SHIM(rrirCID, v); IREE_VM_ABI_DECLARE_SHIM(rrirI, v); IREE_VM_ABI_DECLARE_SHIM(rrIrII, v); +IREE_VM_ABI_DECLARE_SHIM(rrIii, v); IREE_VM_ABI_DECLARE_SHIM(rrrIii, v); IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r); IREE_VM_ABI_DECLARE_SHIM(rIrrr, v); @@ -570,6 +581,8 @@ IREE_VM_ABI_DECLARE_SHIM(CrID, r); IREE_VM_ABI_DECLARE_SHIM(CrD, r); IREE_VM_ABI_DECLARE_SHIM(iCrD, i); +IREE_VM_ABI_DECLARE_SHIM(iI, rr); +IREE_VM_ABI_DECLARE_SHIM(irII, rr); IREE_VM_ABI_DECLARE_SHIM(v, i); IREE_VM_ABI_DECLARE_SHIM(v, r); IREE_VM_ABI_DECLARE_SHIM(v, v);