Add a pass for parameterizing large constants (#15955)
This adds a pass for constructing a parameter archive out of the globals
present in IR. The expected flow for this pass is to hoist all constants
above a specified size (probably the default inlining size for a
constant in a kernel) into globals + initializers and then replace
them with parameters. This can happen any time before, after, or
in between const-eval and const-expr-hoisting.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index 88d3098..f087d55 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -86,6 +86,7 @@
"//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters/Transforms",
"//compiler/src/iree/compiler/Pipelines:Options",
"//compiler/src/iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index 31d8f49..1dbc19c 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -103,6 +103,7 @@
iree::compiler::Dialect::Util::Analysis::DFX
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
+ iree::compiler::Modules::IO::Parameters::Transforms
iree::compiler::Pipelines::Options
iree::compiler::Utils
PUBLIC
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index 7cbf97f..c2d14ee 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -170,6 +171,25 @@
transformOptions.buildConstEvalPassPipeline(pipeline);
}
+ // Export after const-eval. If the user wants to keep the input constants
+ // as is in the final parameter archive, they will probably want to disable
+ // const-eval, or could run this pass as preprocessing. There might be a
+ // configuration in the future where users want to limit const-eval to smaller
+ // constants that aren't exported and skip it for larger parameters, but this
+ // is a sensible place for the common case of wanting const-eval in the final
+ // artifact + archive.
+ if (!transformOptions.options.parameterArchiveExportPath.empty()) {
+ IREE::IO::Parameters::ExportParametersPassOptions exportParametersOptions;
+ exportParametersOptions.archivePath =
+ transformOptions.options.parameterArchiveExportPath;
+ exportParametersOptions.parameterScope =
+ transformOptions.options.parameterExportScope;
+ exportParametersOptions.minimumSize =
+ transformOptions.options.minimumParameterExportSize;
+ pipeline.addPass(IREE::IO::Parameters::createExportParametersPass(
+ exportParametersOptions));
+ }
+
if (transformOptions.options.numericPrecisionReduction) {
pipeline.addPass(createInferNumericNarrowingPass());
pipeline.addPass(createOptimizeNumericsPass());
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel
new file mode 100644
index 0000000..8793f8c
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel
@@ -0,0 +1,59 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "Transforms",
+ srcs = [
+ "ExportParameters.cpp",
+ "Passes.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ "Passes.h.inc",
+ ],
+ deps = [
+ ":PassesIncGen",
+ "//compiler/src/iree/compiler/Dialect/Stream/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/hal",
+ "//runtime/src/iree/io:parameter_index",
+ "//runtime/src/iree/io:scope_map",
+ "//runtime/src/iree/io/formats/irpa",
+ "//runtime/src/iree/tooling:parameter_util",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+iree_gentbl_cc_library(
+ name = "PassesIncGen",
+ tbl_outs = [
+ (
+ ["--gen-pass-decls"],
+ "Passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "Passes.td",
+ deps = [
+ "@llvm-project//mlir:PassBaseTdFiles",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..c943d2c
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt
@@ -0,0 +1,51 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ Transforms
+ HDRS
+ "Passes.h"
+ "Passes.h.inc"
+ SRCS
+ "ExportParameters.cpp"
+ "Passes.cpp"
+ DEPS
+ ::PassesIncGen
+ LLVMSupport
+ MLIRArithDialect
+ MLIRIR
+ MLIRPass
+ MLIRSupport
+ MLIRTransformUtils
+ MLIRTransforms
+ iree::base
+ iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::IR
+ iree::hal
+ iree::io::formats::irpa
+ iree::io::parameter_index
+ iree::io::scope_map
+ iree::tooling::parameter_util
+ PUBLIC
+)
+
+iree_tablegen_library(
+ NAME
+ PassesIncGen
+ TD_FILE
+ "Passes.td"
+ OUTS
+ --gen-pass-decls Passes.h.inc
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
new file mode 100644
index 0000000..4887829
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
@@ -0,0 +1,321 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/base/api.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/io/formats/irpa/irpa_builder.h"
+#include "iree/tooling/parameter_util.h"
+
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FileOutputBuffer.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/FileUtilities.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+#define GEN_PASS_DEF_EXPORTPARAMETERSPASS
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h.inc"
+
+namespace {
+
+// Hoists every serializable `arith.constant` whose storage size is at least
+// |minimumSize| into its own private, immutable `util.global` whose initial
+// value is the constant's value, then replaces the constant with a
+// `util.global.load` of that global.
+//
+// Only constants whose nearest symbol table is |moduleOp| itself are
+// considered; constants nested under another symbol table are left untouched.
+// A non-positive |minimumSize| hoists every serializable constant.
+static void hoistConstantsIntoGlobals(mlir::ModuleOp moduleOp,
+                                      int64_t minimumSize) {
+  // Pairs a constant with the global created to hold its value.
+  struct HoistedConstant {
+    arith::ConstantOp constant;
+    Util::GlobalOp global;
+  };
+
+  SymbolTable moduleSymbols(moduleOp);
+  IRRewriter rewriter(OpBuilder::atBlockBegin(moduleOp.getBody()));
+
+  // Storage sizes are unsigned. Clamp the threshold so that a negative value
+  // is not converted into a huge unsigned number by the comparison below
+  // (which would silently disable hoisting).
+  const iree_io_physical_size_t minimumStorageSize =
+      minimumSize > 0 ? static_cast<iree_io_physical_size_t>(minimumSize) : 0;
+
+  // Recorded in walk order so that the rewrite below is deterministic; a
+  // pointer-keyed map would iterate in an unspecified order.
+  SmallVector<HoistedConstant> hoistedConstants;
+  moduleOp.walk([&](arith::ConstantOp constant) {
+    // Constants part of a different logical program should not be hoisted.
+    if (SymbolTable::getNearestSymbolTable(constant) != moduleOp) {
+      return;
+    }
+    TypedAttr initialValueAttr = constant.getValue();
+    auto serializableAttr =
+        dyn_cast<IREE::Util::SerializableAttrInterface>(initialValueAttr);
+    if (!serializableAttr) {
+      return;
+    }
+
+    // Check that the serialized size of the attribute is at least as big as
+    // the pass configured minimum storage size.
+    iree_io_physical_size_t storageSize = serializableAttr.getStorageSize();
+    if (storageSize < minimumStorageSize) {
+      return;
+    }
+
+    // Create a new global with initial value equal to the constant. The
+    // symbol table insertion uniques the name if `constant_hoisted` is
+    // already taken.
+    Location loc = constant.getLoc();
+    Util::GlobalOp globalOp = rewriter.create<Util::GlobalOp>(
+        loc, "constant_hoisted", /*isMutable=*/false, constant.getType());
+    moduleSymbols.insert(globalOp);
+    // Attributes are stored uniqued by their contents so this is not a copy.
+    globalOp.setInitialValueAttr(initialValueAttr);
+    SymbolTable::setSymbolVisibility(globalOp,
+                                     SymbolTable::Visibility::Private);
+    hoistedConstants.push_back({constant, globalOp});
+  });
+
+  // Replace all constants with loads of their associated hoisted globals.
+  // This is done after the walk so that no constant is erased while the
+  // module is still being traversed.
+  for (auto [originalConstant, globalOp] : hoistedConstants) {
+    rewriter.setInsertionPointAfterValue(originalConstant);
+    Value load =
+        rewriter.create<Util::GlobalLoadOp>(globalOp->getLoc(), globalOp);
+    rewriter.replaceOp(originalConstant, load);
+  }
+}
+
+// llvm::raw_ostream adapter that forwards all writes to an iree_io_stream_t.
+// Used when serializing constant attributes into the parameter archive.
+//
+// The wrapper retains |stream| for its own lifetime. raw_ostream buffers
+// writes, so bytes only reach the stream on flush(); any bytes still pending
+// when the wrapper is destroyed are flushed before the stream is released.
+class iree_io_stream_ostream : public llvm::raw_ostream {
+public:
+  explicit iree_io_stream_ostream(iree_io_stream_t *stream) : stream(stream) {
+    iree_io_stream_retain(stream);
+  }
+  ~iree_io_stream_ostream() override {
+    // raw_ostream subclasses are responsible for flushing in their own
+    // destructor: by the time the base destructor runs write_impl() can no
+    // longer be dispatched and the base class asserts the buffer is empty.
+    flush();
+    iree_io_stream_release(stream);
+  }
+
+private:
+  // Reports the current offset of the underlying stream (excluding bytes that
+  // raw_ostream still holds in its buffer).
+  uint64_t current_pos() const override {
+    return iree_io_stream_offset(stream);
+  }
+  // Writes |size| bytes from |ptr| to the underlying stream. A failed write
+  // is treated as fatal via IREE_CHECK_OK.
+  void write_impl(const char *ptr, size_t size) override {
+    IREE_CHECK_OK(iree_io_stream_write(stream, size, ptr));
+  }
+  iree_io_stream_t *stream = nullptr;
+};
+
+// Converts a runtime |status| into an MLIR diagnostic on |moduleOp|.
+//
+// Returns success() when |status| is OK. Otherwise formats the status text,
+// consumes the status, and emits an error of the form
+// "<failureMessage>: <status text>"; the returned diagnostic converts to
+// failure().
+static LogicalResult handleRuntimeError(ModuleOp moduleOp, iree_status_t status,
+                                        StringRef failureMessage) {
+  if (iree_status_is_ok(status))
+    return success();
+  // Try a fixed-size buffer first and grow it only if formatting reports that
+  // it did not fit.
+  std::string message;
+  message.resize(512);
+  iree_host_size_t buffer_length = 0;
+  if (!iree_status_format(status, message.size(), &message[0],
+                          &buffer_length)) {
+    // NOTE(review): relies on iree_status_format reporting the required
+    // length in |buffer_length| when the buffer is too small; +1 leaves room
+    // for the NUL terminator.
+    message.resize(buffer_length + 1);
+    iree_status_format(status, message.size(), &message[0], &buffer_length);
+  }
+  // Trim to the formatted length so no padding ends up in the diagnostic.
+  message.resize(buffer_length);
+  // The status has been fully reported; drop it so it is not leaked.
+  iree_status_ignore(status);
+  return moduleOp.emitError() << failureMessage << ": " << message;
+}
+
+// Module pass that moves constant data out of the IR and into a parameter
+// archive written to |archivePath|.
+//
+// The pass:
+//   1. Hoists inline constants of at least |minimumSize| into globals.
+//   2. Collects every immutable `util.global` whose initial value is
+//      serializable and has a storage size of at least |minimumSize|.
+//   3. Lays out an archive with one data entry per collected global (keyed by
+//      the global's symbol name), creates the output file, and serializes
+//      each initial value into its entry.
+//   4. Replaces each exported initial value with a named parameter reference
+//      that uses |parameterScope| when one is set.
+//
+// The pass does nothing when |archivePath| is empty. When the module contains
+// nothing to export, constants may still have been hoisted but no file is
+// created. Any runtime or file error is reported on the module and fails the
+// pass.
+struct ExportParametersPass
+    : public IREE::IO::Parameters::impl::ExportParametersPassBase<
+          ExportParametersPass> {
+  using IREE::IO::Parameters::impl::ExportParametersPassBase<
+      ExportParametersPass>::ExportParametersPassBase;
+
+  void runOnOperation() override {
+    // Nothing to do if no path specified.
+    if (archivePath.empty()) {
+      return;
+    }
+
+    MLIRContext *context = &getContext();
+    ModuleOp moduleOp = getOperation();
+
+    // First hoist all inline constants into their own globals.
+    hoistConstantsIntoGlobals(moduleOp, minimumSize);
+
+    // Storage sizes are unsigned. Clamp the configured threshold so that a
+    // negative value is not converted into a huge unsigned number by the
+    // comparison in the loop below.
+    const int64_t configuredMinimumSize = minimumSize;
+    const iree_io_physical_size_t minimumStorageSize =
+        configuredMinimumSize > 0
+            ? static_cast<iree_io_physical_size_t>(configuredMinimumSize)
+            : 0;
+
+    iree_allocator_t host_allocator = iree_allocator_system();
+
+    // Create the parameter archive builder.
+    iree_io_parameter_archive_builder_t builder;
+    iree_io_parameter_archive_builder_initialize(host_allocator, &builder);
+
+    auto deinitializeExit = llvm::make_scope_exit(
+        [&]() { iree_io_parameter_archive_builder_deinitialize(&builder); });
+
+    SmallVector<IREE::Util::GlobalOp> constantGlobals;
+    // Walk the globals in the module and declare an archive entry for each
+    // one that can be exported.
+    for (auto global : moduleOp.getOps<IREE::Util::GlobalOp>()) {
+      // TODO: Support exporting mutable globals.
+      if (global.getIsMutable()) {
+        continue;
+      }
+      // Only globals initialized with initial values can be parameterized.
+      auto initialValueAttr = global.getInitialValueAttr();
+      if (!initialValueAttr) {
+        continue;
+      }
+
+      // The attribute must be serializable to be turned into a parameter.
+      auto serializableAttr =
+          dyn_cast<IREE::Util::SerializableAttrInterface>(initialValueAttr);
+      if (!serializableAttr) {
+        continue;
+      }
+
+      // Check that the serialized size of the attribute is at least as big as
+      // the pass configured minimum storage size.
+      iree_io_physical_size_t storageSize = serializableAttr.getStorageSize();
+      if (storageSize < minimumStorageSize) {
+        continue;
+      }
+      StringRef name = global.getSymName();
+
+      // Add a data entry to the builder for this global.
+      iree_status_t status = iree_io_parameter_archive_builder_add_data_entry(
+          &builder,
+          iree_string_view_t{name.data(),
+                             static_cast<iree_host_size_t>(name.size())},
+          /*metadata=*/iree_const_byte_span_empty(),
+          /*alignment=*/IREE_IO_PARAMETER_ARCHIVE_DEFAULT_DATA_ALIGNMENT,
+          storageSize);
+      if (failed(handleRuntimeError(moduleOp, status,
+                                    "Failed to add data entry for global"))) {
+        return signalPassFailure();
+      }
+
+      constantGlobals.push_back(global);
+    }
+
+    // Early exit if no parameterizable globals present.
+    if (constantGlobals.empty()) {
+      return;
+    }
+
+    // Open a file of sufficient size (now that we know it) for writing.
+    iree_io_physical_size_t archive_length =
+        iree_io_parameter_archive_builder_total_size(&builder);
+
+    auto FileOrErr =
+        llvm::FileOutputBuffer::create(archivePath, archive_length);
+    if (!FileOrErr) {
+      moduleOp.emitError()
+          << "Failed to create file output buffer at " << archivePath
+          << " with error: "
+          << llvm::errorToErrorCode(FileOrErr.takeError()).message();
+      return signalPassFailure();
+    }
+    std::unique_ptr<llvm::FileOutputBuffer> FileBuffer = std::move(*FileOrErr);
+
+    // Wrap the output file for use with the parameter archive builder.
+    iree_io_file_handle_t *target_file_handle = nullptr;
+    iree_byte_span_t file_contents = iree_make_byte_span(
+        FileBuffer->getBufferStart(), FileBuffer->getBufferSize());
+    // Release callback is a no-op, the mapping is managed by the unique_ptr.
+    iree_status_t status = iree_io_file_handle_wrap_host_allocation(
+        IREE_IO_FILE_ACCESS_WRITE, file_contents,
+        iree_io_file_handle_release_callback_null(), host_allocator,
+        &target_file_handle);
+    auto releaseFileExit = llvm::make_scope_exit(
+        [&]() { iree_io_file_handle_release(target_file_handle); });
+    if (failed(handleRuntimeError(moduleOp, status,
+                                  "Failed to open output parameter archive"))) {
+      return signalPassFailure();
+    }
+
+    // Wrap the target file in a stream.
+    iree_io_stream_t *target_stream = nullptr;
+    status =
+        iree_io_stream_open(IREE_IO_STREAM_MODE_WRITABLE, target_file_handle,
+                            /*file_offset=*/0, host_allocator, &target_stream);
+    auto releaseStreamExit = llvm::make_scope_exit(
+        [&]() { iree_io_stream_release(target_stream); });
+    if (failed(handleRuntimeError(
+            moduleOp, status, "Failed to create I/O stream to output file"))) {
+      return signalPassFailure();
+    }
+
+    // Allocate an index we'll populate during building to allow us to get the
+    // storage ranges of non-metadata parameters.
+    iree_io_parameter_index_t *built_index = nullptr;
+    status = iree_io_parameter_index_create(host_allocator, &built_index);
+    auto releaseIndexExit = llvm::make_scope_exit(
+        [&]() { iree_io_parameter_index_release(built_index); });
+    if (failed(handleRuntimeError(moduleOp, status,
+                                  "Failed to allocate parameter index"))) {
+      return signalPassFailure();
+    }
+
+    // Commit the archive header to the file and produce an index referencing
+    // it. This will allow us to know where to copy file contents.
+    status = iree_io_parameter_archive_builder_write(
+        &builder, target_file_handle, /*file_offset=*/0, target_stream,
+        built_index);
+    if (failed(handleRuntimeError(
+            moduleOp, status,
+            "Failed to write parameter index header to output file"))) {
+      return signalPassFailure();
+    }
+
+    // A null scope attribute produces an unscoped parameter reference.
+    StringAttr scopeAttr = parameterScope.empty()
+                               ? StringAttr()
+                               : StringAttr::get(context, parameterScope);
+    iree_io_stream_ostream llvm_stream(target_stream);
+
+    // Write all of the global contents to the appropriate data storage
+    // segments.
+    for (auto constantGlobal : constantGlobals) {
+      StringRef name = constantGlobal.getSymName();
+
+      // Find where the archive layout placed this global's data.
+      const iree_io_parameter_index_entry_t *target_entry = nullptr;
+      status = iree_io_parameter_index_lookup(
+          built_index,
+          iree_string_view_t{name.data(),
+                             static_cast<iree_host_size_t>(name.size())},
+          &target_entry);
+      if (failed(handleRuntimeError(
+              moduleOp, status,
+              "Failed to look up global in the built parameter index"))) {
+        return signalPassFailure();
+      }
+      status = iree_io_stream_seek(target_stream, IREE_IO_STREAM_SEEK_SET,
+                                   target_entry->storage.file.offset);
+      if (failed(handleRuntimeError(
+              moduleOp, status,
+              "Failed to seek to location of global in index"))) {
+        return signalPassFailure();
+      }
+
+      // Serializability was already verified when the global was collected.
+      auto initialValueAttr = constantGlobal.getInitialValueAttr();
+      auto serializableAttr =
+          cast<IREE::Util::SerializableAttrInterface>(initialValueAttr);
+
+      if (failed(serializableAttr.serializeToStream(constantGlobal.getLoc(),
+                                                    llvm::endianness::native,
+                                                    llvm_stream))) {
+        // Report only the symbol name: printing the op would dump the whole
+        // (potentially very large) constant into the diagnostic.
+        moduleOp.emitError()
+            << "Failed to serialize global `" << name << "`";
+        return signalPassFailure();
+      }
+      // Push buffered bytes to the stream before the next seek moves it.
+      llvm_stream.flush();
+
+      // Now we can just replace the existing initial value with a reference to
+      // the parameter.
+      auto param = IREE::Stream::NamedParameterAttr::get(
+          context, constantGlobal.getType(), scopeAttr,
+          StringAttr::get(context, name), DictionaryAttr());
+      constantGlobal.setInitialValueAttr(param);
+    }
+    // Commit the written file.
+    llvm::Error maybeCommit = FileBuffer->commit();
+    if (maybeCommit) {
+      InFlightDiagnostic errorStream =
+          moduleOp.emitError() << "Failed to commit archive with error: ";
+      llvm::handleAllErrors(std::move(maybeCommit),
+                            [&](const llvm::ErrorInfoBase &PE) {
+                              errorStream << PE.message() << "\n";
+                            });
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.cpp
new file mode 100644
index 0000000..7518a8b
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.cpp
@@ -0,0 +1,23 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h"
+
+#include "mlir/Pass/PassRegistry.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h.inc" // IWYU pragma: export
+} // namespace
+
+// Registers the passes of this library by delegating to the registerPasses()
+// function that the GEN_PASS_REGISTRATION include above generates.
+void registerParametersPasses() { registerPasses(); }
+
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.h b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.h
new file mode 100644
index 0000000..9493d93
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.h
@@ -0,0 +1,32 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_MODULES_IO_PARAMETERS_TRANSFORMS_PASSES_H_
+#define IREE_COMPILER_MODULES_IO_PARAMETERS_TRANSFORMS_PASSES_H_
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+// Pass declarations, option structs (e.g. ExportParametersPassOptions), and
+// factory functions (e.g. createExportParametersPass) are generated from
+// Passes.td and pulled in by the GEN_PASS_DECL include below.
+
+//===----------------------------------------------------------------------===//
+// Register all Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h.inc" // IWYU pragma: keep
+
+void registerParametersPasses();
+
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
+
+#endif // IREE_COMPILER_MODULES_IO_PARAMETERS_TRANSFORMS_PASSES_H_
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td
new file mode 100644
index 0000000..1894fcc
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td
@@ -0,0 +1,33 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_MODULES_IO_PARAMETERS_PASSES
+#define IREE_MODULES_IO_PARAMETERS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+// Module-level pass registered as `iree-io-export-parameters`. Declares the
+// three options the C++ implementation reads: the parameter scope, the
+// archive output path, and the minimum size threshold.
+def ExportParametersPass :
+ Pass<"iree-io-export-parameters", "mlir::ModuleOp"> {
+ let summary = "Moves all inline constants of a minimum size and constant "
+ "initialized globals values to a parameter archive";
+ // Dialects whose attributes/ops the pass creates while rewriting globals.
+ let dependentDialects = [
+ "IREE::Stream::StreamDialect",
+ "IREE::Util::UtilDialect",
+ ];
+ let options = [
+ // Scope attached to every exported parameter reference; empty means the
+ // references are created without a scope.
+ Option<"parameterScope", "scope", "std::string",
+ /*default=*/"",
+ "Optional scope to use for the exported parameters.">,
+ // Output file for the archive; the pass does nothing when this is empty.
+ Option<"archivePath", "archive-path", "std::string",
+ /*default=*/"",
+ "Path to write the parameter archive to.">,
+ // Threshold compared against the serialized storage size of each constant
+ // or global initial value.
+ Option<"minimumSize", "minimum-size", "int64_t",
+ /*default=*/"256",
+ "Minimum size of a serialized global to export.">
+ ];
+}
+
+#endif // IREE_MODULES_IO_PARAMETERS_PASSES
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel
new file mode 100644
index 0000000..96c6c6d
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel
@@ -0,0 +1,28 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "export_parameters.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/CMakeLists.txt
new file mode 100644
index 0000000..c5dc94f
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "export_parameters.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir
new file mode 100644
index 0000000..1a474d0
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir
@@ -0,0 +1,26 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-io-export-parameters{archive-path="%t.irpa" scope=opt minimum-size=0})" %s | FileCheck %s
+
+// CHECK-LABEL: module @parameter_example
+module @parameter_example {
+// CHECK-DAG: util.global private @array_global_0 = #stream.parameter.named<"opt"::"array_global_0"> : tensor<1x2xf32>
+// CHECK-DAG: util.global private @dense_global_1 = #stream.parameter.named<"opt"::"dense_global_1"> : tensor<2x2xf32>
+// CHECK-DAG: util.global private @constant_hoisted = #stream.parameter.named<"opt"::"constant_hoisted"> : tensor<1x2xf32>
+// CHECK-DAG: util.global private @dense_global_2 = #stream.parameter.named<"opt"::"dense_global_2"> : tensor<2x2xf32>
+ util.global private @array_global_0 = dense<[[11.0, 12.0]]> : tensor<1x2xf32>
+ util.global private @dense_global_1 = dense<"0x0000E040000000410000104100002041"> : tensor<2x2xf32>
+ util.global private @dense_global_2 = dense<"0x0000803F000000400000404000008040"> : tensor<2x2xf32>
+ func.func @parameter_example(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %3 = util.global.load @array_global_0 : tensor<1x2xf32>
+ %4 = util.global.load @dense_global_1 : tensor<2x2xf32>
+ %5 = arith.constant dense<"0x0000A0400000C040"> : tensor<1x2xf32>
+ %6 = util.global.load @dense_global_2 : tensor<2x2xf32>
+ %empty = tensor.empty() : tensor<1x2xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %8 = linalg.matmul ins(%arg0, %6 : tensor<1x2xf32>, tensor<2x2xf32>) outs(%fill : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %10 = linalg.add ins(%8, %5 : tensor<1x2xf32>, tensor<1x2xf32>) outs(%empty : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %12 = linalg.matmul ins(%10, %4 : tensor<1x2xf32>, tensor<2x2xf32>) outs(%fill : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %14 = linalg.add ins(%12, %3 : tensor<1x2xf32>, tensor<1x2xf32>) outs(%empty : tensor<1x2xf32>) -> tensor<1x2xf32>
+ return %14 : tensor<1x2xf32>
+ }
+}
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index 359316c..f24086d 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -133,6 +133,23 @@
llvm::cl::desc("Strips debug assertions after any useful "
"information has been extracted."),
llvm::cl::cat(category));
+  // Parameter archive export flags. Export is enabled by setting the file
+  // path; the scope and minimum size only refine what is written to it.
+  binder.opt<std::string>(
+      "iree-opt-parameter-archive-export-file", parameterArchiveExportPath,
+      llvm::cl::desc(
+          "File path to create a parameter archive using any inline global "
+          "constants."),
+      llvm::cl::cat(category));
+  // The descriptions below must name the flag above exactly so that users
+  // can find it from the help text.
+  binder.opt<std::string>(
+      "iree-opt-parameter-archive-export-scope", parameterExportScope,
+      llvm::cl::desc("Scope for parameters in the archive created in "
+                     "`iree-opt-parameter-archive-export-file`."),
+      llvm::cl::cat(category));
+  binder.opt<int64_t>(
+      "iree-opt-minimum-parameter-export-size", minimumParameterExportSize,
+      llvm::cl::desc(
+          "Minimum size of constants to export to the archive created in "
+          "`iree-opt-parameter-archive-export-file`."),
+      llvm::cl::cat(category));
}
void SchedulingOptions::bindOptions(OptionsBinder &binder) {
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h
index 5cfa333..01698b7 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.h
+++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -94,6 +94,15 @@
// allow hoisting. The threshold is 1MB by default.
int64_t constExprMaxSizeIncreaseThreshold = 1024 * 1024;
+ // File path to create a parameter archive out of global initial values.
+ // Parameter export is only added to the pipeline when this is non-empty.
+ std::string parameterArchiveExportPath = "";
+
+ // Optional scope to use for the created parameter archive. Forwarded to the
+ // export pass as the scope of the generated parameter references.
+ std::string parameterExportScope = "";
+
+ // Minimum size of constants to export as parameters. Forwarded to the export
+ // pass, which compares it against each constant's serialized storage size.
+ int64_t minimumParameterExportSize = 256;
+
+
void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<GlobalOptimizationOptions>;
};
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index f44e87c..d2a145d 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -60,6 +60,7 @@
"//compiler/src/iree/compiler/Modules/HAL/Loader/IR:HALLoaderDialect",
"//compiler/src/iree/compiler/Modules/HAL/Loader/Transforms",
"//compiler/src/iree/compiler/Modules/IO/Parameters/IR:IOParametersDialect",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters/Transforms",
"//compiler/src/iree/compiler/Pipelines",
"//compiler/src/iree/compiler/Preprocessing:Passes",
"//llvm-external-projects/iree-dialects:IREEInputDialect",
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt
index b82934c..97a3a92 100644
--- a/compiler/src/iree/compiler/Tools/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -67,6 +67,7 @@
iree::compiler::Modules::HAL::Loader::IR::HALLoaderDialect
iree::compiler::Modules::HAL::Loader::Transforms
iree::compiler::Modules::IO::Parameters::IR::IOParametersDialect
+ iree::compiler::Modules::IO::Parameters::Transforms
iree::compiler::Pipelines
iree::compiler::Preprocessing::Passes
PUBLIC
diff --git a/compiler/src/iree/compiler/Tools/init_iree_passes.h b/compiler/src/iree/compiler/Tools/init_iree_passes.h
index 41f8c54..d9849da 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_passes.h
@@ -29,6 +29,7 @@
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "iree/compiler/Modules/HAL/Inline/Transforms/Passes.h"
#include "iree/compiler/Modules/HAL/Loader/Transforms/Passes.h"
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h"
#include "iree/compiler/Pipelines/Pipelines.h"
#include "iree/compiler/Preprocessing/Passes.h"
@@ -55,6 +56,7 @@
IREE::HAL::registerHALPasses();
IREE::HAL::Inline::registerHALInlinePasses();
IREE::HAL::Loader::registerHALLoaderPasses();
+ IREE::IO::Parameters::registerParametersPasses();
IREE::LinalgExt::registerPasses();
IREE::Stream::registerStreamPasses();
IREE::Util::registerTransformPasses();
diff --git a/tests/e2e/parameters/BUILD.bazel b/tests/e2e/parameters/BUILD.bazel
new file mode 100644
index 0000000..d5d5c1e
--- /dev/null
+++ b/tests/e2e/parameters/BUILD.bazel
@@ -0,0 +1,33 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "export_parameters.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//tests:lit.cfg.py",
+ tags = [
+ "driver=local-task",
+ ],
+ tools = [
+ "//tools:iree-compile",
+ "//tools:iree-dump-parameters",
+ "//tools:iree-run-module",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/tests/e2e/parameters/CMakeLists.txt b/tests/e2e/parameters/CMakeLists.txt
new file mode 100644
index 0000000..ee55f74
--- /dev/null
+++ b/tests/e2e/parameters/CMakeLists.txt
@@ -0,0 +1,28 @@
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_ABOVE_THIS_LINE ###
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# tests/e2e/parameters/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "export_parameters.mlir"
+ TOOLS
+ FileCheck
+ iree-compile
+ iree-dump-parameters
+ iree-run-module
+ LABELS
+ "driver=local-task"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/tests/e2e/parameters/export_parameters.mlir b/tests/e2e/parameters/export_parameters.mlir
new file mode 100644
index 0000000..cbf292e
--- /dev/null
+++ b/tests/e2e/parameters/export_parameters.mlir
@@ -0,0 +1,41 @@
+module @parameter_example {
+ util.global private @array_global_0 = dense<[[11.0, 12.0]]> : tensor<1x2xf32>
+ util.global private @dense_global_1 = dense<"0x0000E040000000410000104100002041"> : tensor<2x2xf32>
+ util.global private @dense_global_2 = dense<"0x0000A0400000C040"> : tensor<1x2xf32>
+ util.global private @dense_global_3 = dense<"0x0000803F000000400000404000008040"> : tensor<2x2xf32>
+ func.func @predict(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %3 = util.global.load @array_global_0 : tensor<1x2xf32>
+ %4 = util.global.load @dense_global_1 : tensor<2x2xf32>
+ %5 = util.global.load @dense_global_2 : tensor<1x2xf32>
+ %6 = util.global.load @dense_global_3 : tensor<2x2xf32>
+ %empty = tensor.empty() : tensor<1x2xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %8 = linalg.matmul ins(%arg0, %6 : tensor<1x2xf32>, tensor<2x2xf32>) outs(%fill : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %10 = linalg.add ins(%8, %5 : tensor<1x2xf32>, tensor<1x2xf32>) outs(%empty : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %12 = linalg.matmul ins(%10, %4 : tensor<1x2xf32>, tensor<2x2xf32>) outs(%fill : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %14 = linalg.add ins(%12, %3 : tensor<1x2xf32>, tensor<1x2xf32>) outs(%empty : tensor<1x2xf32>) -> tensor<1x2xf32>
+ return %14 : tensor<1x2xf32>
+ }
+}
+
+// RUN: iree-compile %s \
+// RUN: --iree-hal-target-backends=vmvx \
+// RUN: --iree-opt-parameter-archive-export-file=%t.irpa \
+// RUN: --iree-opt-parameter-archive-export-scope=compile \
+// RUN: --iree-opt-minimum-parameter-export-size=0 | \
+// RUN: iree-run-module --device=local-task --module=- \
+// RUN: --input=1x2xf32=1.0 \
+// RUN: --parameters=compile=%t.irpa \
+// RUN: --function=predict > %t.txt
+// RUN: iree-dump-parameters --parameters=compile=%t.irpa >> %t.txt
+// RUN: FileCheck %s --input-file=%t.txt
+
+// CHECK-LABEL: EXEC @predict
+// CHECK: 1x2xf32=[182 204]
+
+// CHECK: 512 |{{.*}} 520 |{{.*}} 8 | `constant_hoisted`
+// CHECK: 576 |{{.*}} 584 |{{.*}} 8 | `constant_hoisted_0`
+// CHECK: 640 |{{.*}} 656 |{{.*}} 16 | `constant_hoisted_1`
+// CHECK: 704 |{{.*}} 720 |{{.*}} 16 | `constant_hoisted_2`
+