Add a pass for generating splat archives during compilation (#16353)
This pass is useful for benchmarking and debugging models with
parameterized weights without needing to download multiple gigabytes of
weights first.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index b37c32e..e53f7cd 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -190,6 +190,16 @@
exportParametersOptions));
}
+ if (!transformOptions.options.splatParameterArchiveExportPath.empty()) {
+ IREE::IO::Parameters::GenerateSplatParameterArchivePassOptions
+ generateSplatOptions;
+ generateSplatOptions.archivePath =
+ transformOptions.options.splatParameterArchiveExportPath;
+ pipeline.addPass(
+ IREE::IO::Parameters::createGenerateSplatParameterArchivePass(
+ generateSplatOptions));
+ }
+
if (transformOptions.options.numericPrecisionReduction) {
pipeline.addPass(createInferNumericNarrowingPass());
pipeline.addPass(createOptimizeNumericsPass());
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp
new file mode 100644
index 0000000..16dde66
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp
@@ -0,0 +1,101 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/base/api.h"
+#include "iree/io/formats/irpa/irpa_builder.h"
+#include "iree/tooling/parameter_util.h"
+
+#include "iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h"
+#include "llvm/Support/FileOutputBuffer.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+// Converts a failing runtime |status| into an MLIR diagnostic attached to |op|.
+//
+// Returns success() without emitting anything when |status| is OK. Otherwise
+// the status is formatted to text, handed to iree_status_ignore, and the
+// result of emitting "<failureMessage>: <status text>" on |op| is returned.
+LogicalResult handleRuntimeError(Operation *op, iree_status_t status,
+                                 StringRef failureMessage) {
+  if (iree_status_is_ok(status))
+    return success();
+  std::string message;
+  message.resize(512);
+  iree_host_size_t buffer_length = 0;
+  // Try a fixed-size buffer first; if formatting reports failure, grow the
+  // buffer to the length reported by the first call (plus one byte) and retry.
+  if (!iree_status_format(status, message.size(), &message[0],
+                          &buffer_length)) {
+    message.resize(buffer_length + 1);
+    iree_status_format(status, message.size(), &message[0], &buffer_length);
+  }
+  // Trim the string down to the formatted length.
+  message.resize(buffer_length);
+  iree_status_ignore(status);
+  // Separate the caller's context from the status text so the two do not run
+  // together in the diagnostic.
+  return op->emitError() << failureMessage << ": " << message;
+}
+
+// Writes the archive header described by |builder| into |fileBuffer| and hands
+// back the handles needed to keep working with the file.
+//
+// The buffer's memory is wrapped in a file handle, a writable stream is opened
+// over it at offset 0, and a new parameter index is populated while the
+// builder writes. On success the handles are stored to |output_file_handle|,
+// |output_stream|, and |output_built_index|. On failure an error is emitted on
+// |op|, the handles acquired so far are released, and the outputs are left
+// untouched.
+LogicalResult
+writeParameterIndex(Operation *op, iree_allocator_t allocator,
+                    iree_io_parameter_archive_builder_t &builder,
+                    std::unique_ptr<llvm::FileOutputBuffer> &fileBuffer,
+                    iree_io_file_handle_t **output_file_handle,
+                    iree_io_stream_t **output_stream,
+                    iree_io_parameter_index_t **output_built_index) {
+  // Expose the output buffer to the archive builder as a writable file. The
+  // release callback is null because the mapping stays managed by the
+  // unique_ptr.
+  iree_byte_span_t mappedBytes = iree_make_byte_span(
+      fileBuffer->getBufferStart(), fileBuffer->getBufferSize());
+  iree_io_file_handle_t *fileHandle = nullptr;
+  iree_status_t status = iree_io_file_handle_wrap_host_allocation(
+      IREE_IO_FILE_ACCESS_WRITE, mappedBytes,
+      iree_io_file_handle_release_callback_null(), allocator, &fileHandle);
+  LogicalResult wrapped =
+      handleRuntimeError(op, status, "Failed to open output parameter archive");
+  if (failed(wrapped)) {
+    iree_io_file_handle_release(fileHandle);
+    return failure();
+  }
+
+  // Open a writable stream starting at the beginning of the file.
+  iree_io_stream_t *stream = nullptr;
+  status = iree_io_stream_open(IREE_IO_STREAM_MODE_WRITABLE, fileHandle,
+                               /*file_offset=*/0, allocator, &stream);
+  LogicalResult opened = handleRuntimeError(
+      op, status, "Failed to create I/O stream to output file");
+  if (failed(opened)) {
+    iree_io_file_handle_release(fileHandle);
+    iree_io_stream_release(stream);
+    return failure();
+  }
+
+  // Create the index that gets populated during building; it is what lets
+  // callers find the storage ranges of non-metadata parameters.
+  iree_io_parameter_index_t *index = nullptr;
+  status = iree_io_parameter_index_create(allocator, &index);
+  LogicalResult allocated =
+      handleRuntimeError(op, status, "Failed to allocate parameter index");
+  if (failed(allocated)) {
+    iree_io_file_handle_release(fileHandle);
+    iree_io_stream_release(stream);
+    iree_io_parameter_index_release(index);
+    return failure();
+  }
+
+  // Commit the archive header to the file and fill in the index referencing
+  // it, which tells callers where file contents need to be copied.
+  status = iree_io_parameter_archive_builder_write(
+      &builder, fileHandle, /*file_offset=*/0, stream, index);
+  LogicalResult written = handleRuntimeError(
+      op, status, "Failed to write parameter index header to output file");
+  if (failed(written)) {
+    iree_io_file_handle_release(fileHandle);
+    iree_io_stream_release(stream);
+    iree_io_parameter_index_release(index);
+    return failure();
+  }
+
+  // Everything succeeded: transfer the handles to the caller.
+  *output_file_handle = fileHandle;
+  *output_stream = stream;
+  *output_built_index = index;
+  return success();
+}
+
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h
new file mode 100644
index 0000000..1345e68
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h
@@ -0,0 +1,56 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_MODULES_IO_PARAMETERS_TRANSFORMS_ARCHIVEUTILS_H_
+#define IREE_COMPILER_MODULES_IO_PARAMETERS_TRANSFORMS_ARCHIVEUTILS_H_
+
+#include "iree/base/api.h"
+#include "iree/io/formats/irpa/irpa_builder.h"
+#include "iree/tooling/parameter_util.h"
+
+#include "llvm/Support/FileOutputBuffer.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+// llvm::raw_ostream adapter that forwards written bytes to an
+// iree_io_stream_t, for use when serializing constants.
+//
+// The adapter retains |stream| on construction and releases it on
+// destruction. Stream write failures are checked with IREE_CHECK_OK.
+class iree_io_stream_ostream : public llvm::raw_ostream {
+public:
+  explicit iree_io_stream_ostream(iree_io_stream_t *stream) : stream(stream) {
+    iree_io_stream_retain(stream);
+  }
+  ~iree_io_stream_ostream() override {
+    // raw_ostream buffers output and expects subclasses to drain that buffer
+    // in their own destructor; do so while |stream| is still retained so no
+    // pending bytes are dropped when a caller forgets to flush.
+    flush();
+    iree_io_stream_release(stream);
+  }
+
+private:
+  // Returns the current offset of the wrapped stream.
+  uint64_t current_pos() const override {
+    return iree_io_stream_offset(stream);
+  }
+  // Forwards |size| bytes starting at |ptr| to the wrapped stream.
+  void write_impl(const char *ptr, size_t size) override {
+    IREE_CHECK_OK(iree_io_stream_write(stream, size, ptr));
+  }
+  iree_io_stream_t *stream = NULL;
+};
+
+// Helper to interpret iree status messages and print the error message.
+//
+// Returns success() when |status| is OK. Otherwise the status is formatted to
+// text and passed to iree_status_ignore, and the result of emitting an error
+// on |op| that combines |failureMessage| with the formatted text is returned.
+LogicalResult handleRuntimeError(Operation *op, iree_status_t status,
+ StringRef failureMessage);
+
+// Helper to write the parameter index constructed in the archive |builder|
+// to the given |fileBuffer|. Populates a file, stream, and index handle on
+// success for further writing of the data segments. The file, stream, and
+// index handles must be released by the caller if this succeeds. On failure
+// an error is emitted on |op|, any handles acquired along the way are
+// released, and the output pointers are not written.
+LogicalResult
+writeParameterIndex(Operation *op, iree_allocator_t allocator,
+ iree_io_parameter_archive_builder_t &builder,
+ std::unique_ptr<llvm::FileOutputBuffer> &fileBuffer,
+ iree_io_file_handle_t **output_file_handle,
+ iree_io_stream_t **output_stream,
+ iree_io_parameter_index_t **output_built_index);
+
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
+
+#endif // IREE_COMPILER_MODULES_IO_PARAMETERS_TRANSFORMS_ARCHIVEUTILS_H_
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel
index 8793f8c..caef38d 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel
@@ -15,10 +15,13 @@
iree_compiler_cc_library(
name = "Transforms",
srcs = [
+ "ArchiveUtils.cpp",
"ExportParameters.cpp",
+ "GenerateSplatParameterArchive.cpp",
"Passes.cpp",
],
hdrs = [
+ "ArchiveUtils.h",
"Passes.h",
"Passes.h.inc",
],
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt
index c943d2c..0a3ac29 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt
@@ -14,10 +14,13 @@
NAME
Transforms
HDRS
+ "ArchiveUtils.h"
"Passes.h"
"Passes.h.inc"
SRCS
+ "ArchiveUtils.cpp"
"ExportParameters.cpp"
+ "GenerateSplatParameterArchive.cpp"
"Passes.cpp"
DEPS
::PassesIncGen
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
index 38ede9a..4f7eeaa 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h"
#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Error.h"
@@ -78,41 +79,6 @@
}
}
-// Wrapper around iree_io_stream for use when serializing constants.
-class iree_io_stream_ostream : public llvm::raw_ostream {
-public:
- explicit iree_io_stream_ostream(iree_io_stream_t *stream) : stream(stream) {
- iree_io_stream_retain(stream);
- }
- ~iree_io_stream_ostream() override { iree_io_stream_release(stream); }
-
-private:
- uint64_t current_pos() const override {
- return iree_io_stream_offset(stream);
- }
- void write_impl(const char *ptr, size_t size) override {
- IREE_CHECK_OK(iree_io_stream_write(stream, size, ptr));
- }
- iree_io_stream_t *stream = NULL;
-};
-
-static LogicalResult handleRuntimeError(ModuleOp moduleOp, iree_status_t status,
- StringRef failureMessage) {
- if (iree_status_is_ok(status))
- return success();
- std::string message;
- message.resize(512);
- iree_host_size_t buffer_length;
- if (!iree_status_format(status, message.size(), &message[0],
- &buffer_length)) {
- message.resize(buffer_length + 1);
- iree_status_format(status, message.size(), &message[0], &buffer_length);
- }
- message.resize(buffer_length);
- iree_status_ignore(status);
- return moduleOp.emitError() << failureMessage << message;
-}
-
struct ExportParametersPass
: public IREE::IO::Parameters::impl::ExportParametersPassBase<
ExportParametersPass> {
@@ -203,57 +169,22 @@
<< llvm::errorToErrorCode(FileOrErr.takeError()).message();
return signalPassFailure();
}
- std::unique_ptr<llvm::FileOutputBuffer> FileBuffer = std::move(*FileOrErr);
+ std::unique_ptr<llvm::FileOutputBuffer> fileBuffer = std::move(*FileOrErr);
- // Wrap the output file for use with the parameter archive builder.
iree_io_file_handle_t *target_file_handle = NULL;
- iree_byte_span_t file_contents = iree_make_byte_span(
- FileBuffer->getBufferStart(), FileBuffer->getBufferSize());
- // Release callback is a no-op, the mapping is managed by the unique_ptr.
- iree_status_t status = iree_io_file_handle_wrap_host_allocation(
- IREE_IO_FILE_ACCESS_WRITE, file_contents,
- iree_io_file_handle_release_callback_null(), host_allocator,
- &target_file_handle);
- auto releaseFileExit = llvm::make_scope_exit(
- [&]() { return iree_io_file_handle_release(target_file_handle); });
- if (failed(handleRuntimeError(moduleOp, status,
- "Failed to open output parameter archive"))) {
- return signalPassFailure();
- }
-
- // Wrap the target file in a stream.
iree_io_stream_t *target_stream = NULL;
- status =
- iree_io_stream_open(IREE_IO_STREAM_MODE_WRITABLE, target_file_handle,
- /*file_offset=*/0, host_allocator, &target_stream);
- auto releaseStreamExit = llvm::make_scope_exit(
- [&]() { return iree_io_stream_release(target_stream); });
- if (failed(handleRuntimeError(
- moduleOp, status, "Failed to create I/O stream to output file"))) {
- return signalPassFailure();
- }
-
- // Allocate an index we'll populate during building to allow us to get the
- // storage ranges of non-metadata parameters.
iree_io_parameter_index_t *built_index = NULL;
- status = iree_io_parameter_index_create(host_allocator, &built_index);
- auto releaseIndexExit = llvm::make_scope_exit(
- [&]() { return iree_io_parameter_index_release(built_index); });
- if (failed(handleRuntimeError(moduleOp, status,
- "Failed to allocate parameter index"))) {
+ if (failed(writeParameterIndex(moduleOp, host_allocator, builder,
+ fileBuffer, &target_file_handle,
+ &target_stream, &built_index))) {
return signalPassFailure();
}
- // Commit the archive header to the file and produce an index referencing
- // it. This will allow us to know where to copy file contents.
- status = iree_io_parameter_archive_builder_write(
- &builder, target_file_handle, /*file_offset=*/0, target_stream,
- built_index);
- if (failed(handleRuntimeError(
- moduleOp, status,
- "Failed to write parameter index header to output file"))) {
- return signalPassFailure();
- }
+ auto releaseFileExit = llvm::make_scope_exit([&]() -> void {
+ iree_io_stream_release(target_stream);
+ iree_io_parameter_index_release(built_index);
+ iree_io_file_handle_release(target_file_handle);
+ });
StringAttr scopeAttr = parameterScope.empty()
? StringAttr()
@@ -266,7 +197,7 @@
StringRef name = constantGlobal.getSymName();
const iree_io_parameter_index_entry_t *target_entry = NULL;
- status = iree_io_parameter_index_lookup(
+ iree_status_t status = iree_io_parameter_index_lookup(
built_index,
iree_string_view_t{name.data(),
static_cast<iree_host_size_t>(name.size())},
@@ -304,7 +235,7 @@
constantGlobal.setInitialValueAttr(param);
}
// Commit the written file.
- llvm::Error maybeCommit = FileBuffer->commit();
+ llvm::Error maybeCommit = fileBuffer->commit();
if (maybeCommit) {
InFlightDiagnostic errorStream =
moduleOp.emitError() << "Failed to commit archive with error: ";
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp
new file mode 100644
index 0000000..d270372
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp
@@ -0,0 +1,141 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/base/api.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/io/formats/irpa/irpa_builder.h"
+#include "iree/tooling/parameter_util.h"
+
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h"
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FileOutputBuffer.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Support/FileUtilities.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+#define GEN_PASS_DEF_GENERATESPLATPARAMETERARCHIVEPASS
+#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h.inc"
+
+namespace {
+
+// Pass that writes a parameter archive to |archivePath| containing one
+// zero-filled splat entry for every util.global whose initial value is a
+// named parameter attribute. The pass only reads the module.
+struct GenerateSplatParameterArchivePass
+    : public IREE::IO::Parameters::impl::GenerateSplatParameterArchivePassBase<
+          GenerateSplatParameterArchivePass> {
+  using IREE::IO::Parameters::impl::GenerateSplatParameterArchivePassBase<
+      GenerateSplatParameterArchivePass>::GenerateSplatParameterArchivePassBase;
+
+  // Collects all parameter-backed globals, sizes the archive from them, and
+  // writes and commits the archive file. Signals pass failure if any runtime
+  // call or file operation fails.
+  void runOnOperation() override {
+    // Nothing to do if no path specified.
+    if (archivePath.empty()) {
+      return;
+    }
+
+    ModuleOp moduleOp = getOperation();
+
+    iree_allocator_t host_allocator = iree_allocator_system();
+
+    // Create the parameter archive builder.
+    iree_io_parameter_archive_builder_t builder;
+    iree_io_parameter_archive_builder_initialize(host_allocator, &builder);
+
+    // Tear the builder down on every exit path below.
+    auto deinitializeExit = llvm::make_scope_exit([&]() {
+      return iree_io_parameter_archive_builder_deinitialize(&builder);
+    });
+
+    bool hasParameter = false;
+    // Walk the globals in the module.
+    for (auto global : moduleOp.getOps<IREE::Util::GlobalOp>()) {
+      // Look for globals backed by parameters.
+      auto initialValueAttr = global.getInitialValueAttr();
+      if (!initialValueAttr) {
+        continue;
+      }
+      auto parameterAttr =
+          dyn_cast<IREE::Stream::NamedParameterAttr>(initialValueAttr);
+      if (!parameterAttr) {
+        continue;
+      }
+
+      // Note that the scope is not a part of the parameter archive. If the
+      // module includes multiple scopes, multiple copies of the splat archive
+      // would need to be passed in with all possible scopes.
+      std::string parameterName = parameterAttr.getKey().str();
+      iree_io_physical_size_t storageSize = parameterAttr.getStorageSize();
+
+      // Add a zero-splat entry to the builder for this global: a one-byte
+      // zero pattern repeated over the parameter's full storage size.
+      char c0 = 0;
+      iree_status_t status = iree_io_parameter_archive_builder_add_splat_entry(
+          &builder,
+          iree_string_view_t{
+              parameterName.data(),
+              static_cast<iree_host_size_t>(parameterName.size())},
+          /*metadata=*/iree_const_byte_span_empty(),
+          /*pattern=*/&c0, /*pattern_length=*/1, /*data_length=*/storageSize);
+      if (failed(handleRuntimeError(moduleOp, status,
+                                    "Failed to add splat entry for global"))) {
+        return signalPassFailure();
+      }
+      hasParameter = true;
+    }
+
+    // Early exit if no parameter backed globals present.
+    if (!hasParameter) {
+      return;
+    }
+
+    // Open a file of sufficient size (now that we know it) for writing.
+    iree_io_physical_size_t archive_length =
+        iree_io_parameter_archive_builder_total_size(&builder);
+
+    auto FileOrErr =
+        llvm::FileOutputBuffer::create(archivePath, archive_length);
+    if (!FileOrErr) {
+      moduleOp.emitError()
+          << "Failed to create file output buffer at " << archivePath
+          << " with error: "
+          << llvm::errorToErrorCode(FileOrErr.takeError()).message();
+      return signalPassFailure();
+    }
+    std::unique_ptr<llvm::FileOutputBuffer> fileBuffer = std::move(*FileOrErr);
+
+    // Write the archive header into the buffer. On failure the helper has
+    // already released whatever it acquired, so nothing is released here.
+    iree_io_file_handle_t *target_file_handle = NULL;
+    iree_io_stream_t *target_stream = NULL;
+    iree_io_parameter_index_t *built_index = NULL;
+    if (failed(writeParameterIndex(moduleOp, host_allocator, builder,
+                                   fileBuffer, &target_file_handle,
+                                   &target_stream, &built_index))) {
+      return signalPassFailure();
+    }
+
+    // The handles are owned by this function from here on; release them on
+    // every remaining exit path.
+    auto releaseFileExit = llvm::make_scope_exit([&]() -> void {
+      iree_io_stream_release(target_stream);
+      iree_io_parameter_index_release(built_index);
+      iree_io_file_handle_release(target_file_handle);
+    });
+
+    // Commit the written file. No further data is written after the header.
+    llvm::Error maybeCommit = fileBuffer->commit();
+    if (maybeCommit) {
+      InFlightDiagnostic errorStream =
+          moduleOp.emitError() << "Failed to commit archive with error: ";
+      llvm::handleAllErrors(std::move(maybeCommit),
+                            [&](const llvm::ErrorInfoBase &PE) {
+                              errorStream << PE.message() << "\n";
+                            });
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td
index 1894fcc..61b2247 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td
@@ -30,4 +30,15 @@
];
}
+// Module-level pass that writes a parameter archive containing one zero-filled
+// splat entry for every util.global initialized from a named parameter. The
+// pass does nothing when `archive-path` is left at its empty default.
+def GenerateSplatParameterArchivePass :
+ Pass<"iree-io-generate-splat-parameter-archive", "mlir::ModuleOp"> {
+ let summary = "Generates a .irpa file with splat entries for all parameters";
+ let dependentDialects = [];
+ let options = [
+ Option<"archivePath", "archive-path", "std::string",
+ /*default=*/"",
+ "Path to write the parameter archive to.">
+ ];
+}
+
#endif // IREE_MODULES_IO_PARAMETERS_PASSES
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel
index 96c6c6d..7b6620f 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel
@@ -17,11 +17,13 @@
srcs = enforce_glob(
[
"export_parameters.mlir",
+ "generate_splat_parameter_archive.mlir",
],
include = ["*.mlir"],
),
cfg = "//compiler:lit.cfg.py",
tools = [
+ "//tools:iree-dump-parameters",
"//tools:iree-opt",
"@llvm-project//llvm:FileCheck",
],
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/CMakeLists.txt
index c5dc94f..72297bc 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/CMakeLists.txt
@@ -15,8 +15,10 @@
lit
SRCS
"export_parameters.mlir"
+ "generate_splat_parameter_archive.mlir"
TOOLS
FileCheck
+ iree-dump-parameters
iree-opt
)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir
new file mode 100644
index 0000000..c99d390
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-io-generate-splat-parameter-archive{archive-path="%t.irpa"})" %s | FileCheck %s
+// RUN: iree-dump-parameters --parameters=%t.irpa | FileCheck %s --check-prefix=DUMP
+
+// The pass only reads the module, so each parameter-backed global is expected
+// to be printed back with its original named-parameter initial value.
+// CHECK-LABEL: @parameter_example
+module @parameter_example {
+ // CHECK: util.global private @array_global_0 = #stream.parameter.named<"model"::"global_0">
+ // CHECK: util.global private @dense_global_1 = #stream.parameter.named<"model"::"global_1">
+ // CHECK: util.global private @dense_global_2 = #stream.parameter.named<"model"::"global_2">
+ // CHECK: util.global private @dense_global_3 = #stream.parameter.named<"model"::"global_3">
+ util.global private @array_global_0 = #stream.parameter.named<"model"::"global_0"> : tensor<1x2xi32>
+ util.global private @dense_global_1 = #stream.parameter.named<"model"::"global_1"> : tensor<2x2xi32>
+ util.global private @dense_global_2 = #stream.parameter.named<"model"::"global_2"> : tensor<1x2xi32>
+ util.global private @dense_global_3 = #stream.parameter.named<"model"::"global_3"> : tensor<2x2xi32>
+ func.func @forward(%arg0: tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %cst = arith.constant 0 : i32
+ %3 = util.global.load @array_global_0 : tensor<1x2xi32>
+ %4 = util.global.load @dense_global_1 : tensor<2x2xi32>
+ %5 = util.global.load @dense_global_2 : tensor<1x2xi32>
+ %6 = util.global.load @dense_global_3 : tensor<2x2xi32>
+ %empty = tensor.empty() : tensor<1x2xi32>
+ %fill = linalg.fill ins(%cst : i32) outs(%empty : tensor<1x2xi32>) -> tensor<1x2xi32>
+ %8 = linalg.matmul ins(%arg0, %6 : tensor<1x2xi32>, tensor<2x2xi32>) outs(%fill : tensor<1x2xi32>) -> tensor<1x2xi32>
+ %10 = linalg.add ins(%8, %5 : tensor<1x2xi32>, tensor<1x2xi32>) outs(%empty : tensor<1x2xi32>) -> tensor<1x2xi32>
+ %12 = linalg.matmul ins(%10, %4 : tensor<1x2xi32>, tensor<2x2xi32>) outs(%fill : tensor<1x2xi32>) -> tensor<1x2xi32>
+ %14 = linalg.add ins(%12, %3 : tensor<1x2xi32>, tensor<1x2xi32>) outs(%empty : tensor<1x2xi32>) -> tensor<1x2xi32>
+ return %14 : tensor<1x2xi32>
+ }
+}
+
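+// Verify the generated archive holds one entry per parameter with the expected
+// size: tensor<1x2xi32> occupies 8 bytes and tensor<2x2xi32> occupies 16 bytes.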
+// DUMP: - |{{.*}} - |{{.*}} 8 | `global_0`
+// DUMP: - |{{.*}} - |{{.*}} 16 | `global_1`
+// DUMP: - |{{.*}} - |{{.*}} 8 | `global_2`
+// DUMP: - |{{.*}} - |{{.*}} 16 | `global_3`
+
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index ac8cf2e..f88138a 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -150,6 +150,13 @@
"Minimum size of constants to export to the archive created in "
"`iree-opt-export-parameter-archive-export-file`."),
llvm::cl::cat(category));
+ binder.opt<std::string>(
+ "iree-opt-splat-parameter-archive-export-file",
+ splatParameterArchiveExportPath,
+ llvm::cl::desc(
+ "File path to create a parameter archive of splat values out of all "
+ "parameter backed globals."),
+ llvm::cl::cat(category));
}
void SchedulingOptions::bindOptions(OptionsBinder &binder) {
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h
index 32c22a4..120b81b 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.h
+++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -100,6 +100,10 @@
// Optional scope to use for the created parameter archive.
std::string parameterExportScope = "";
+ // File path to create a splat parameter archive out of all parameters in the
+ // module.
+ std::string splatParameterArchiveExportPath = "";
+
// Minimum size of constants to export as parameters.
int64_t minimumParameterExportSize = 256;
diff --git a/tests/e2e/parameters/BUILD.bazel b/tests/e2e/parameters/BUILD.bazel
index 5d16826..81e36b1 100644
--- a/tests/e2e/parameters/BUILD.bazel
+++ b/tests/e2e/parameters/BUILD.bazel
@@ -17,6 +17,7 @@
srcs = enforce_glob(
[
"export_parameters.mlir",
+ "generate_splat_archive.mlir",
],
include = ["*.mlir"],
),
diff --git a/tests/e2e/parameters/CMakeLists.txt b/tests/e2e/parameters/CMakeLists.txt
index 6c4f504..61ad996 100644
--- a/tests/e2e/parameters/CMakeLists.txt
+++ b/tests/e2e/parameters/CMakeLists.txt
@@ -16,6 +16,7 @@
lit
SRCS
"export_parameters.mlir"
+ "generate_splat_archive.mlir"
TOOLS
FileCheck
iree-compile
diff --git a/tests/e2e/parameters/generate_splat_archive.mlir b/tests/e2e/parameters/generate_splat_archive.mlir
new file mode 100644
index 0000000..85c160d
--- /dev/null
+++ b/tests/e2e/parameters/generate_splat_archive.mlir
@@ -0,0 +1,31 @@
+// End-to-end check of splat archive generation: every global in this module is
+// parameter-backed, the archive is produced while compiling, and the compiled
+// module is then run with that archive bound to the "model" parameter scope
+// (see the RUN lines at the bottom of the file).
+module @parameter_example {
+ util.global private @array_global_0 = #stream.parameter.named<"model"::"global_0"> : tensor<1x2xi32>
+ util.global private @dense_global_1 = #stream.parameter.named<"model"::"global_1"> : tensor<2x2xi32>
+ util.global private @dense_global_2 = #stream.parameter.named<"model"::"global_2"> : tensor<1x2xi32>
+ util.global private @dense_global_3 = #stream.parameter.named<"model"::"global_3"> : tensor<2x2xi32>
+ func.func @forward(%arg0: tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %cst = arith.constant 0 : i32
+ %3 = util.global.load @array_global_0 : tensor<1x2xi32>
+ %4 = util.global.load @dense_global_1 : tensor<2x2xi32>
+ %5 = util.global.load @dense_global_2 : tensor<1x2xi32>
+ %6 = util.global.load @dense_global_3 : tensor<2x2xi32>
+ %empty = tensor.empty() : tensor<1x2xi32>
+ %fill = linalg.fill ins(%cst : i32) outs(%empty : tensor<1x2xi32>) -> tensor<1x2xi32>
+ %8 = linalg.matmul ins(%arg0, %6 : tensor<1x2xi32>, tensor<2x2xi32>) outs(%fill : tensor<1x2xi32>) -> tensor<1x2xi32>
+ %10 = linalg.add ins(%8, %5 : tensor<1x2xi32>, tensor<1x2xi32>) outs(%empty : tensor<1x2xi32>) -> tensor<1x2xi32>
+ %12 = linalg.matmul ins(%10, %4 : tensor<1x2xi32>, tensor<2x2xi32>) outs(%fill : tensor<1x2xi32>) -> tensor<1x2xi32>
+ %14 = linalg.add ins(%12, %3 : tensor<1x2xi32>, tensor<1x2xi32>) outs(%empty : tensor<1x2xi32>) -> tensor<1x2xi32>
+ return %14 : tensor<1x2xi32>
+ }
+}
+
+// RUN: iree-compile %s \
+// RUN: --iree-hal-target-backends=vmvx \
+// RUN: --iree-opt-splat-parameter-archive-export-file=%t.irpa | \
+// RUN: iree-run-module --device=local-task --module=- \
+// RUN: --input=1x2xi32=1 \
+// RUN: --parameters=model=%t.irpa \
+// RUN: --function=forward | FileCheck %s
+
+// CHECK-LABEL: EXEC @forward
+// CHECK: 1x2xi32=[0 0]