Adding parameters as a concept to stream/hal/tooling. (#15104)
Parameters are externalized storage for resources that are
asynchronously accessible and device-aware. Parameters can be read or
written on the same device timelines as the operations that consume or
produce them and with locality pinning to ensure memory doesn't need to
move. Parameters are referenced by an optional scope (a file name, a
model name, whatever) and a unique key within that scope, with the scope
being strongly recommended as a way to distinguish sets of parameters
that may exist when multiple model parts are compiled together and would
otherwise collide.
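To illustrate why the scope matters, here is a minimal sketch of an index keyed by the (scope, key) pair. `ParameterIndex` and `ParameterEntry` are hypothetical names used only for illustration and are not the types added by this change:

```cpp
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Hypothetical entry describing where a parameter lives in its backing storage.
struct ParameterEntry {
  uint64_t offset;  // byte offset into the backing storage
  uint64_t length;  // byte length of the parameter
};

// Hypothetical index: parameters are unique per (scope, key), so the same key
// may appear in two scopes without colliding. An empty scope is allowed.
class ParameterIndex {
 public:
  // Returns false if (scope, key) is already registered.
  bool insert(const std::string &scope, const std::string &key,
              ParameterEntry entry) {
    return entries_.emplace(std::make_pair(scope, key), entry).second;
  }

  // Returns the entry for (scope, key), or std::nullopt if it is not present.
  std::optional<ParameterEntry> lookup(const std::string &scope,
                                       const std::string &key) const {
    auto it = entries_.find(std::make_pair(scope, key));
    if (it == entries_.end())
      return std::nullopt;
    return it->second;
  }

 private:
  std::map<std::pair<std::string, std::string>, ParameterEntry> entries_;
};
```

With this shape, registering `a0` under scope `a` and again under scope `b` both succeed, while registering `a::a0` twice is rejected.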
Parameters are provided to programs by a virtual interface and can
support shared parameters (same storage used in multiple contexts, or
outliving a single instantiation in a context), in-memory caches,
memory-mapped files (including directly using the mapped memory for
execution when devices support it), iree_hal_file_t usage for
device-supported I/O, and parameter subsetting for things like runtime
sharding. A basic file cache is implemented to allow for programs to
decide when and where they want to use the parameters without needing to
have bound them to devices at startup time.
Alongside read(+load) and write operations, gather and scatter allow for
batching of large numbers of reads and writes into/from single buffers.
For parameter providers that can batch operations this allows for a
handful (~1-4) of calls out to perform many more operations
(~thousands). Modeling the gather/scatter also gives us a point where we
could extract the mapping and use it to repack files/defrag memory in
the future.
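As a rough sketch of what a gather mapping could look like, the code below packs many source ranges into one target buffer at aligned offsets. `SourceRange`, `GatherSpan`, `planGather`, and the alignment parameter are assumptions made for illustration; the actual ops and provider interface may model this differently:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical description of one parameter read: where it lives in storage.
struct SourceRange {
  uint64_t offset;  // byte offset in the parameter storage
  uint64_t length;  // byte length to read
};

// Hypothetical gather span: one source range mapped into the target buffer.
struct GatherSpan {
  uint64_t sourceOffset;
  uint64_t targetOffset;
  uint64_t length;
};

// Rounds |value| up to the next multiple of |alignment|.
inline uint64_t alignUp(uint64_t value, uint64_t alignment) {
  return (value + alignment - 1) / alignment * alignment;
}

// Packs all |sources| into a single target buffer, placing each one at the
// next |alignment|-aligned offset. The resulting span list is the mapping a
// batching provider could service in a single call. |totalSize| receives the
// required target buffer size in bytes.
std::vector<GatherSpan> planGather(const std::vector<SourceRange> &sources,
                                   uint64_t alignment, uint64_t *totalSize) {
  assert(alignment > 0 && "alignment must be non-zero");
  std::vector<GatherSpan> spans;
  spans.reserve(sources.size());
  uint64_t targetOffset = 0;
  for (const SourceRange &source : sources) {
    targetOffset = alignUp(targetOffset, alignment);
    spans.push_back({source.offset, targetOffset, source.length});
    targetOffset += source.length;
  }
  *totalSize = targetOffset;
  return spans;
}
```

Because the mapping is an explicit list, the same data could later drive repacking a file so that source and target layouts match.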
Parameters are currently defined by the `#stream.parameter.named`
attribute which specifies an optional parameter scope and a scope-unique
key for the parameter along with its logical type. Today these are
intended to be used as default values on global ops (mutable or
immutable) and are hackily processed as if they were constants. Future
changes will allow parameter mutation and storage but what's present
should be enough for inference and training parameter initialization.
Example parameter (here a tensor but parameters can be other types in
the future to act as bags of bits):
```mlir
util.global private @"model.layer-1.kernel" = #stream.parameter.named<"mnist"::"model.layer-1.kernel"> : tensor<784x128xf32>
```
Parameters can optionally have a subrange specified indicating that the
logical tensor is a block of some larger storage. When sharding this can
be used to have an individual shard load a subset of the parameter data:
```mlir
util.global private @"model.layer-1.kernel-shard-0" = #stream.parameter.named<"mnist"::"model.layer-1.kernel", {offset = 0}> : tensor<392x128xf32>
util.global private @"model.layer-1.kernel-shard-1" = #stream.parameter.named<"mnist"::"model.layer-1.kernel", {offset = 200704}> : tensor<392x128xf32>
```
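The offsets in the example above are consistent with byte offsets into dense row-major storage: 392 rows x 128 columns x 4 bytes per f32 is 200704. A small sketch of that arithmetic follows; `ShardRange` and `computeRowShards` are hypothetical helpers for illustration and not part of this change:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical byte range of one shard within the full parameter storage.
struct ShardRange {
  uint64_t offset;  // byte offset of the shard
  uint64_t length;  // byte length of the shard
};

// Splits a dense row-major [rows x cols] tensor into |shardCount| equal
// row-wise shards and returns the byte range of each. Assumes the rows divide
// evenly across the shards.
std::vector<ShardRange> computeRowShards(uint64_t rows, uint64_t cols,
                                         uint64_t elementBytes,
                                         uint64_t shardCount) {
  assert(shardCount > 0 && rows % shardCount == 0 &&
         "rows must divide evenly across shards");
  const uint64_t shardBytes = (rows / shardCount) * cols * elementBytes;
  std::vector<ShardRange> shards;
  shards.reserve(shardCount);
  for (uint64_t i = 0; i < shardCount; ++i) {
    shards.push_back({i * shardBytes, shardBytes});
  }
  return shards;
}
```

For the 784x128xf32 kernel split in two, this yields offsets 0 and 200704, each with a length of 200704 bytes.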
In this initial implementation we err on the side of optimizing for
discrete memory devices (GPUs/etc) by emitting gathers of all
parameters. On unified memory systems where we can zero-copy import
parameters into device memory this is wasteful but it ensures proper
alignment/packing/limited runtime overheads. Setting the resource memory
model to `unified` via the `#stream.resource_config` attribute (helper
flag `--iree-stream-resource-memory-model=unified`) will instead alias
parameter memory where possible at the cost of increased runtime
overhead. Future changes will connect the resource memory model to those
of the devices under compilation and allow for heterogeneous deployments
to treat parameters used exclusively on different devices in whatever
way is best for that device.
Basic tooling support for read-only parameters has been added for
testing by allowing parameter files to be specified on the command line:
```
$ iree-run-module \
--parameter_mode=mmap \
--parameters=some_scope=some/file0.safetensors \
--parameters=other_scope=some/file1.gguf \
--module=...
```
Currently parameters are only usable from the full HAL implementation
and not the inline HAL. The parameter file format and index code were
kept portable so that they could be reused for a lighter-weight feature
set if we wanted to support parameters in the inline HAL, but given that
the cases where the inline HAL is interesting are usually small models
on tiny systems where optimization of parameters is critical to
memory/performance I haven't bothered here.
Since all parameter file formats are terrible a new parameter file
format that is less terrible for our uses will be introduced in future
changes. It's still experimental and not fully wired up but will be
something we can convert other formats into for optimized use as both
immutable constant and mutable variable storage in our tools when direct
compatibility with existing frameworks (without conversion steps) is not
required.
The `iree-dump-parameters` tool can be used to inspect any of the
parameter file formats the tooling can load and extract individual
parameters. It indexes parameters using the same flags as the rest of
the tooling so it can also be useful to see what parameters are actually
available for use without trial and error in other tools. Example
output:
```
$ ../iree-build/tools/iree-dump-parameters.exe --parameters=a=tools/test/parameters_a.safetensors --parameters=runtime/src/iree/io/formats/gguf/testdata/multiple.gguf --extract=a::a0=a0.bin --extract=tensor0=tensor0.bin
//===--------------------------------------------------------------------------------------------------------------===//
// Parameter scope `a` (2 entries, 64 total bytes)
//===------------+------------------+------------------+-----------------------------------------------------------===//
// Start | End | Length | Key
//---------------+------------------+------------------+--------------------------------------------------------------//
120 | 152 | 32 | `a0`
152 | 184 | 32 | `a1`
//===--------------------------------------------------------------------------------------------------------------===//
// Parameter scope `` (3 entries, 72 total bytes)
//===------------+------------------+------------------+-----------------------------------------------------------===//
// Start | End | Length | Key
//---------------+------------------+------------------+--------------------------------------------------------------//
448 | 464 | 16 | `tensor0`
512 | 520 | 8 | `tensor1`
576 | 624 | 48 | `tensor2`
Extracting parameter `a::a0` (32b) to `a0.bin`...
Extracting parameter `tensor0` (16b) to `tensor0.bin`...
```
Progress on #14987.
---------
Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel
index 52653fb..c4bd9ed 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel
@@ -21,6 +21,7 @@
"Patterns.h",
],
deps = [
+ ":Utils",
"//compiler/src/iree/compiler/Dialect/HAL/Conversion",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
@@ -36,3 +37,24 @@
"@llvm-project//mlir:Transforms",
],
)
+
+iree_compiler_cc_library(
+ name = "Utils",
+ srcs = [
+ "Utils.cpp",
+ ],
+ hdrs = [
+ "Utils.h",
+ ],
+ deps = [
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
+ "//compiler/src/iree/compiler/Dialect/HAL/Target",
+ "//compiler/src/iree/compiler/Dialect/Stream/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:SCFDialect",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt
index 92dd0c6..77894b7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt
@@ -18,6 +18,7 @@
SRCS
"Patterns.cpp"
DEPS
+ ::Utils
LLVMSupport
MLIRArithDialect
MLIRFuncDialect
@@ -34,4 +35,24 @@
PUBLIC
)
+iree_cc_library(
+ NAME
+ Utils
+ HDRS
+ "Utils.h"
+ SRCS
+ "Utils.cpp"
+ DEPS
+ LLVMSupport
+ MLIRArithDialect
+ MLIRIR
+ MLIRSCFDialect
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::HAL::IR::HALDialect
+ iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::IR
+ PUBLIC
+)
+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 46e21e6..22bd835 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.h"
+#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
@@ -13,7 +14,6 @@
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -23,315 +23,63 @@
namespace mlir {
namespace iree_compiler {
-static llvm::cl::opt<bool> clExternalResourcesMappable(
- "iree-stream-external-resources-mappable",
- llvm::cl::desc("Allocates external resources as host-visible and mappable. "
- "This can degrade performance and introduce allocation "
- "overhead and staging buffers for readback on the host "
- "should be managed by the calling application instead."),
- llvm::cl::init(false));
-
namespace {
-static Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
- // TODO(benvanik): make this do multi-device lookup and other fancy things.
- auto lookupOp = builder.create<IREE::HAL::ExSharedDeviceOp>(op->getLoc());
- return lookupOp.getResult();
-}
-
// Returns the device queue affinity mask indicating which device queues the
// operations are allowed to execute on.
-static Value buildQueueAffinityMaskFor(Operation *op, Value device,
- OpBuilder &builder) {
+static Value buildQueueAffinityMask(Location loc,
+ IREE::Stream::AffinityAttr affinityAttr,
+ Value device, OpBuilder &builder) {
// Try to find a specified affinity. This may be on the op provided or one of
// its parent regions.
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
if (auto queueAffinityAttr =
llvm::dyn_cast_if_present<IREE::HAL::AffinityQueueAttr>(
affinityAttr)) {
return builder.create<arith::ConstantIntOp>(
- op->getLoc(), queueAffinityAttr.getMask(), 64);
+ loc, queueAffinityAttr.getMask(), 64);
}
// No affinity specified; use default (any) affinity.
- return builder.create<arith::ConstantIntOp>(op->getLoc(), -1, 64);
+ return builder.create<arith::ConstantIntOp>(loc, -1, 64);
}
-static std::tuple<Value, Value>
-lookupDeviceAndQueueAffinityFor(Operation *op, OpBuilder &builder) {
- // NOTE: we have this combined method so that we can reuse any expensive
- // lookups we need to do. Today we aren't duplicating the lookups and don't
- // bother.
+struct ContextResolveOpPattern
+ : public StreamConversionPattern<IREE::Stream::ContextResolveOp> {
+ using StreamConversionPattern::StreamConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::Stream::ContextResolveOp resolveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultTypes = llvm::to_vector(resolveOp.getResultTypes());
+ assert(!resultTypes.empty() && "must have at least one result");
- // Get a device handle used to create resources and schedule work.
- // It may be shared across many mutually-exclusive devices at runtime.
- Value device = lookupDeviceFor(op, builder);
+ // TODO(benvanik): make this do multi-device lookup and other fancy things.
+ Value device =
+ rewriter.create<IREE::HAL::ExSharedDeviceOp>(resolveOp.getLoc());
- // Derive the queue affinity mask from the op and device combination.
- Value queueAffinity = buildQueueAffinityMaskFor(op, device, builder);
-
- return std::make_tuple(device, queueAffinity);
-}
-
-static Value lookupAllocatorFor(Operation *op, OpBuilder &builder) {
- auto device = lookupDeviceFor(op, builder);
- auto allocatorOp =
- builder.create<IREE::HAL::DeviceAllocatorOp>(op->getLoc(), device);
- return allocatorOp.getResult();
-}
-
-static std::tuple<Value, Value>
-lookupAllocatorAndQueueAffinityFor(Operation *op, OpBuilder &builder) {
- // NOTE: we have this combined method so that we can reuse any expensive
- // lookups we need to do. Today we aren't duplicating the lookups and don't
- // bother.
-
- // Get a device handle used to create resources and schedule work.
- // It may be shared across many mutually-exclusive devices at runtime.
- Value device = lookupDeviceFor(op, builder);
-
- // Each device has a single allocator that may itself present multiple.
- Value allocator =
- builder.create<IREE::HAL::DeviceAllocatorOp>(op->getLoc(), device);
-
- // Derive the queue affinity mask from the op and device combination.
- Value queueAffinity = buildQueueAffinityMaskFor(op, device, builder);
-
- return std::make_tuple(allocator, queueAffinity);
-}
-
-// Returns the |timepointFence| or a util.null.
-static Value getOrCreateWaitFence(Location loc, Value timepointFence,
- OpBuilder &builder) {
- if (timepointFence)
- return timepointFence;
- return builder.create<IREE::Util::NullOp>(
- loc, builder.getType<IREE::HAL::FenceType>());
-}
-
-// Finds a !hal.fence bound to |timepoint| via a chain op and returns it if
-// it is usable at the builder insertion point. The chain ops is only used if
-// it is the only consumer of the timepoint and it is removed upon return.
-static Value consumeBoundFence(Value timepoint,
- ConversionPatternRewriter &rewriter) {
- // Must only have one use. We can't consume a fence multiple times.
- if (!timepoint.hasOneUse())
- return nullptr; // >1 use
-
- // The use must be an export to a fence.
- auto chainOp = dyn_cast<IREE::Stream::TimepointChainExternalOp>(
- *timepoint.getUsers().begin());
- if (!chainOp)
- return nullptr; // non-export use
- assert(!chainOp.getExternalValues().empty());
- auto fence = chainOp.getExternalValues().front();
- if (!fence || !llvm::isa<IREE::HAL::FenceType>(fence.getType()))
- return nullptr;
-
- // Try really hard to figure out if the fence can be used. A larger analysis
- // pass running prior to conversion that did some code motion could help
- // ensure the fence SSA value is usable in the places it is needed - for now
- // we just do this local check that satisfies most common programs today. IPO
- // would do something like add the fence as an argument to function calls so
- // that the functions could consume it but inlining is pretty aggressive now.
- if (!IREE::Util::isValueUsableForOp(fence, rewriter.getBlock(),
- rewriter.getInsertionPoint())) {
- return nullptr; // unusable
- }
-
- // Consume the op by erasing it.
- rewriter.eraseOp(chainOp);
-
- return fence; // usable
-}
-
-// Returns the a new fence for |timepoint| or an existing fence if one was
-// associated with an external fence. Returns util.null if no one observes the
-// fence.
-static Value getOrCreateSignalFence(Location loc, Value device, Value timepoint,
- ConversionPatternRewriter &rewriter) {
- // Check to see if anyone is consuming the timepoint - if not then we don't
- // need create a fence.
- if (timepoint.use_empty()) {
- return rewriter.create<IREE::Util::NullOp>(
- loc, rewriter.getType<IREE::HAL::FenceType>());
- }
-
- // Check to see if the timepoint is associated with a fence. In common cases
- // when along ABI boundaries we can usually find an association.
- auto fence = consumeBoundFence(timepoint, rewriter);
- if (fence)
- return fence;
-
- // Create a new fence.
- return rewriter.create<IREE::HAL::FenceCreateOp>(
- loc, rewriter.getType<IREE::HAL::FenceType>(), device,
- IREE::HAL::FenceFlagBitfield::None);
-}
-
-// Scans all of the stream.cmd.* ops in the region to derive a command category.
-static IREE::HAL::CommandCategoryBitfield
-deriveCommandCategories(Region &region) {
- auto bits = IREE::HAL::CommandCategoryBitfield::None;
- for (auto &block : region) {
- for (auto &op : block) {
- if (isa<IREE::Stream::CmdCollectiveOp>(op) ||
- isa<IREE::Stream::CmdCallOp>(op)) {
- // Calls may do anything and collectives may be implemented as either
- // transfers or dispatches.
- bits = bits | IREE::HAL::CommandCategoryBitfield::Dispatch |
- IREE::HAL::CommandCategoryBitfield::Transfer;
- } else if (isa<IREE::Stream::CmdDispatchOp>(op)) {
- bits = bits | IREE::HAL::CommandCategoryBitfield::Dispatch;
- } else {
- bits = bits | IREE::HAL::CommandCategoryBitfield::Transfer;
- }
- for (auto &nestedRegion : op.getRegions()) {
- bits = bits | deriveCommandCategories(nestedRegion);
- }
- }
- }
- return bits;
-}
-
-// Maps a resource type to the corresponding HAL memory types and buffer usage.
-// This will fail if the resource type is not directly mappable to HAL bits.
-// The bits set here are those that must be set for the buffer to be used as the
-// buffer within the program with its defined resource lifetime.
-static LogicalResult
-deriveRequiredResourceBufferBits(Location loc,
- IREE::Stream::ResourceType resourceType,
- IREE::HAL::MemoryTypeBitfield &memoryTypes,
- IREE::HAL::BufferUsageBitfield &bufferUsage) {
- memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
- bufferUsage = IREE::HAL::BufferUsageBitfield::None;
- switch (resourceType.getLifetime()) {
- default:
- return mlir::emitError(loc)
- << "unsupported resource lifetime: "
- << IREE::Stream::stringifyLifetime(resourceType.getLifetime());
- case IREE::Stream::Lifetime::Constant:
- // Device local; copies required to get into external resources.
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
- bufferUsage =
- bufferUsage | IREE::HAL::BufferUsageBitfield::SharingImmutable;
- break;
- case IREE::Stream::Lifetime::Variable:
- // Device local; copies required to get into external resources.
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
- break;
- case IREE::Stream::Lifetime::External:
- // We only require device-visible for external buffers (as we don't today
- // do anything else with them on the host). They may be mappable for user
- // convenience. Ideally they would have been placed in device-local memory
- // but so long as they are device visible the program will execute
- // correctly.
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceVisible;
- break;
- case IREE::Stream::Lifetime::Staging:
- // Host local; copies required to get into device resources.
- // We could vary this based on staging usage (upload/download) by
- // making it device-local|host-visible, but host-local means we have
- // a better chance of mapping it during uploads.
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::HostLocal |
- IREE::HAL::MemoryTypeBitfield::DeviceVisible;
- bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Transfer |
- IREE::HAL::BufferUsageBitfield::Mapping;
- break;
- case IREE::Stream::Lifetime::Transient:
- // Device local; copies required to get into external resources.
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
- break;
- }
-
- // TODO(benvanik): refine usage based on analysis.
- bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Transfer |
- IREE::HAL::BufferUsageBitfield::DispatchStorage;
-
- return success();
-}
-
-// Maps a resource type to the corresponding HAL memory types and buffer usage.
-// This will fail if the resource type is not directly mappable to HAL bits.
-// The bits set here represent the superset of required and allowed bits and
-// are useful for providing buffers back to users via the ABI that may need to
-// be used for more than just what the internal program requires.
-static LogicalResult
-deriveAllowedResourceBufferBits(Location loc,
- IREE::Stream::ResourceType resourceType,
- IREE::HAL::MemoryTypeBitfield &memoryTypes,
- IREE::HAL::BufferUsageBitfield &bufferUsage) {
- memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
- bufferUsage = IREE::HAL::BufferUsageBitfield::None;
- if (failed(deriveRequiredResourceBufferBits(loc, resourceType, memoryTypes,
- bufferUsage))) {
- return failure();
- }
- switch (resourceType.getLifetime()) {
- default:
- break;
- case IREE::Stream::Lifetime::External:
- if (clExternalResourcesMappable) {
- // #yolo; these come from/go to outside the program.
- // Today we assume they are device-local|host-visible just for
- // practical purposes but that does not have to be true. We really
- // want this to be something we analyze and handle on the edges
- // (transferring devices/etc if needed).
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
- IREE::HAL::MemoryTypeBitfield::HostVisible;
- // NOTE: we may not map it but users may after they get them back.
- // Another reason we should annotate this - having a buffer be
- // mappable is potentially expensive (may get a 2nd copy in memory!).
- bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
+ SmallVector<Value> results;
+ if (resultTypes[0].isa<IREE::HAL::DeviceType>()) {
+ results.push_back(device);
+ } else if (resultTypes[0].isa<IREE::HAL::AllocatorType>()) {
+ results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>(
+ resolveOp.getLoc(), device));
} else {
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
+ return rewriter.notifyMatchFailure(
+ resolveOp, "unrecognized context resolve types for a HAL target");
}
- break;
+ if (resultTypes.size() > 1) {
+ if (resultTypes[1].isa<IntegerType>()) {
+ results.push_back(buildQueueAffinityMask(
+ resolveOp.getLoc(), resolveOp.getAffinityAttr(), device, rewriter));
+ } else {
+ return rewriter.notifyMatchFailure(
+ resolveOp,
+ "unrecognized context resolve types for a HAL target (extended)");
+ }
+ }
+
+ rewriter.replaceOp(resolveOp, results);
+ return success();
}
- return success();
-}
-
-class StreamConversionMapping {
-public:
- // Maps the stream dialect |executeOp| to the hal dialect |commandBuffer|
- // value used during recording. Patterns can use this to find the SSA value
- // they need to make hal.command_buffer.* ops.
- void mapCommandBuffer(IREE::Stream::CmdExecuteOp executeOp,
- Value commandBuffer) {
- assert(commandBuffers.insert(std::make_pair(executeOp, commandBuffer))
- .second &&
- "multiple command buffers cannot be registered for the same op");
-
- // Map all ops nested within the command buffer so we can query later.
- executeOp.walk([&](Operation *op) {
- commandBuffers.insert(std::make_pair(op, commandBuffer));
- return WalkResult::advance();
- });
- }
-
- // Looks up a mapped command buffer SSA value that can be used by the given
- // stream.cmd.* op.
- Value lookupCommandBufferFor(Operation *cmdOp) const {
- auto it = commandBuffers.find(cmdOp);
- assert(it != commandBuffers.end() &&
- "command buffer must have been registered during conversion");
- return it->second;
- }
-
-private:
- // Ops within stream.cmd.execute ops -> !hal.command_buffer.
- DenseMap<Operation *, Value> commandBuffers;
-};
-
-template <typename OpT>
-struct StreamConversionPattern : public OpConversionPattern<OpT> {
- StreamConversionPattern(std::shared_ptr<StreamConversionMapping> mapping,
- TypeConverter &typeConverter, MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpConversionPattern<OpT>(typeConverter, context, benefit),
- mapping(std::move(mapping)) {}
-
- std::shared_ptr<StreamConversionMapping> mapping;
};
struct ResourceAllocOpPattern
@@ -1549,6 +1297,8 @@
auto mapping = std::make_shared<StreamConversionMapping>();
+ patterns.insert<ContextResolveOpPattern>(mapping, typeConverter, context);
+
patterns.insert<ResourceAllocOpPattern, ResourceAllocaOpPattern,
ResourceDeallocaOpPattern, ResourceSizeOpPattern,
ResourceTryMapOpPattern, ResourceLoadOpPattern,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
new file mode 100644
index 0000000..8b628da
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
@@ -0,0 +1,267 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h"
+
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/Support/CommandLine.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+static llvm::cl::opt<bool> clExternalResourcesMappable(
+ "iree-stream-external-resources-mappable",
+ llvm::cl::desc("Allocates external resources as host-visible and mappable. "
+ "This can degrade performance and introduce allocation "
+ "overhead and staging buffers for readback on the host "
+ "should be managed by the calling application instead."),
+ llvm::cl::init(false));
+
+namespace mlir::iree_compiler {
+
+Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
+ op->getLoc(),
+ TypeRange{
+ builder.getType<IREE::HAL::DeviceType>(),
+ },
+ affinityAttr);
+ return resolveOp.getResult(0);
+}
+
+std::tuple<Value, Value> lookupDeviceAndQueueAffinityFor(Operation *op,
+ OpBuilder &builder) {
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
+ op->getLoc(),
+ TypeRange{
+ builder.getType<IREE::HAL::DeviceType>(),
+ builder.getI64Type(),
+ },
+ affinityAttr);
+ return std::make_tuple(resolveOp.getResult(0), resolveOp.getResult(1));
+}
+
+Value lookupAllocatorFor(Operation *op, OpBuilder &builder) {
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
+ op->getLoc(),
+ TypeRange{
+ builder.getType<IREE::HAL::AllocatorType>(),
+ },
+ affinityAttr);
+ return resolveOp.getResult(0);
+}
+
+std::tuple<Value, Value>
+lookupAllocatorAndQueueAffinityFor(Operation *op, OpBuilder &builder) {
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
+ op->getLoc(),
+ TypeRange{
+ builder.getType<IREE::HAL::AllocatorType>(),
+ builder.getI64Type(),
+ },
+ affinityAttr);
+ return std::make_tuple(resolveOp.getResult(0), resolveOp.getResult(1));
+}
+
+Value getOrCreateWaitFence(Location loc, Value timepointFence,
+ PatternRewriter &rewriter) {
+ if (timepointFence)
+ return timepointFence;
+ return rewriter.create<IREE::Util::NullOp>(
+ loc, rewriter.getType<IREE::HAL::FenceType>());
+}
+
+// Finds a !hal.fence bound to |timepoint| via a chain op and returns it if
+// it is usable at the builder insertion point. The chain ops is only used if
+// it is the only consumer of the timepoint and it is removed upon return.
+static Value consumeBoundFence(Value timepoint, PatternRewriter &rewriter) {
+ // Must only have one use. We can't consume a fence multiple times.
+ if (!timepoint.hasOneUse())
+ return nullptr; // >1 use
+
+ // The use must be an export to a fence.
+ auto chainOp = dyn_cast<IREE::Stream::TimepointChainExternalOp>(
+ *timepoint.getUsers().begin());
+ if (!chainOp)
+ return nullptr; // non-export use
+ assert(!chainOp.getExternalValues().empty());
+ auto fence = chainOp.getExternalValues().front();
+ if (!fence || !llvm::isa<IREE::HAL::FenceType>(fence.getType()))
+ return nullptr;
+
+ // Try really hard to figure out if the fence can be used. A larger analysis
+ // pass running prior to conversion that did some code motion could help
+ // ensure the fence SSA value is usable in the places it is needed - for now
+ // we just do this local check that satisfies most common programs today. IPO
+ // would do something like add the fence as an argument to function calls so
+ // that the functions could consume it but inlining is pretty aggressive now.
+ if (!IREE::Util::isValueUsableForOp(fence, rewriter.getBlock(),
+ rewriter.getInsertionPoint())) {
+ return nullptr; // unusable
+ }
+
+ // Consume the op by erasing it.
+ rewriter.eraseOp(chainOp);
+
+ return fence; // usable
+}
+
+Value getOrCreateSignalFence(Location loc, Value device, Value timepoint,
+ PatternRewriter &rewriter) {
+ // Check to see if anyone is consuming the timepoint - if not then we don't
+ // need create a fence.
+ if (timepoint.use_empty()) {
+ return rewriter.create<IREE::Util::NullOp>(
+ loc, rewriter.getType<IREE::HAL::FenceType>());
+ }
+
+ // Check to see if the timepoint is associated with a fence. In common cases
+ // when along ABI boundaries we can usually find an association.
+ auto fence = consumeBoundFence(timepoint, rewriter);
+ if (fence)
+ return fence;
+
+ // Create a new fence.
+ return rewriter.create<IREE::HAL::FenceCreateOp>(
+ loc, rewriter.getType<IREE::HAL::FenceType>(), device,
+ IREE::HAL::FenceFlagBitfield::None);
+}
+
+IREE::HAL::CommandCategoryBitfield deriveCommandCategories(Region &region) {
+ auto bits = IREE::HAL::CommandCategoryBitfield::None;
+ for (auto &block : region) {
+ for (auto &op : block) {
+ if (isa<IREE::Stream::CmdCollectiveOp>(op) ||
+ isa<IREE::Stream::CmdCallOp>(op)) {
+ // Calls may do anything and collectives may be implemented as either
+ // transfers or dispatches.
+ bits = bits | IREE::HAL::CommandCategoryBitfield::Dispatch |
+ IREE::HAL::CommandCategoryBitfield::Transfer;
+ } else if (isa<IREE::Stream::CmdDispatchOp>(op)) {
+ bits = bits | IREE::HAL::CommandCategoryBitfield::Dispatch;
+ } else {
+ bits = bits | IREE::HAL::CommandCategoryBitfield::Transfer;
+ }
+ for (auto &nestedRegion : op.getRegions()) {
+ bits = bits | deriveCommandCategories(nestedRegion);
+ }
+ }
+ }
+ return bits;
+}
+
+LogicalResult
+deriveRequiredResourceBufferBits(Location loc,
+ IREE::Stream::ResourceType resourceType,
+ IREE::HAL::MemoryTypeBitfield &memoryTypes,
+ IREE::HAL::BufferUsageBitfield &bufferUsage) {
+ memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
+ bufferUsage = IREE::HAL::BufferUsageBitfield::None;
+ switch (resourceType.getLifetime()) {
+ default:
+ return mlir::emitError(loc)
+ << "unsupported resource lifetime: "
+ << IREE::Stream::stringifyLifetime(resourceType.getLifetime());
+ case IREE::Stream::Lifetime::Constant:
+ // Device local; copies required to get into external resources.
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
+ bufferUsage =
+ bufferUsage | IREE::HAL::BufferUsageBitfield::SharingImmutable;
+ break;
+ case IREE::Stream::Lifetime::Variable:
+ // Device local; copies required to get into external resources.
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
+ break;
+ case IREE::Stream::Lifetime::External:
+ // We only require device-visible for external buffers (as we don't today
+ // do anything else with them on the host). They may be mappable for user
+ // convenience. Ideally they would have been placed in device-local memory
+ // but so long as they are device visible the program will execute
+ // correctly.
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceVisible;
+ break;
+ case IREE::Stream::Lifetime::Staging:
+ // Host local; copies required to get into device resources.
+ // We could vary this based on staging usage (upload/download) by
+ // making it device-local|host-visible, but host-local means we have
+ // a better chance of mapping it during uploads.
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::HostLocal |
+ IREE::HAL::MemoryTypeBitfield::DeviceVisible;
+ bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Transfer |
+ IREE::HAL::BufferUsageBitfield::Mapping;
+ break;
+ case IREE::Stream::Lifetime::Transient:
+ // Device local; copies required to get into external resources.
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
+ break;
+ }
+
+ // TODO(benvanik): refine usage based on analysis.
+ bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Transfer |
+ IREE::HAL::BufferUsageBitfield::DispatchStorage;
+
+ return success();
+}
+
+LogicalResult
+deriveAllowedResourceBufferBits(Location loc,
+ IREE::Stream::ResourceType resourceType,
+ IREE::HAL::MemoryTypeBitfield &memoryTypes,
+ IREE::HAL::BufferUsageBitfield &bufferUsage) {
+ memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
+ bufferUsage = IREE::HAL::BufferUsageBitfield::None;
+ if (failed(deriveRequiredResourceBufferBits(loc, resourceType, memoryTypes,
+ bufferUsage))) {
+ return failure();
+ }
+ switch (resourceType.getLifetime()) {
+ default:
+ break;
+ case IREE::Stream::Lifetime::External:
+ if (clExternalResourcesMappable) {
+ // #yolo; these come from/go to outside the program.
+ // Today we assume they are device-local|host-visible just for
+ // practical purposes but that does not have to be true. We really
+ // want this to be something we analyze and handle on the edges
+ // (transferring devices/etc if needed).
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
+ IREE::HAL::MemoryTypeBitfield::HostVisible;
+ // NOTE: we may not map it but users may after they get them back.
+ // Another reason we should annotate this - having a buffer be
+ // mappable is potentially expensive (may get a 2nd copy in memory!).
+ bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
+ } else {
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
+ }
+ break;
+ }
+ return success();
+}
+
+void StreamConversionMapping::mapCommandBuffer(
+ IREE::Stream::CmdExecuteOp executeOp, Value commandBuffer) {
+ [[maybe_unused]] bool inserted =
+ commandBuffers.insert(std::make_pair(executeOp, commandBuffer)).second;
+ assert(inserted && "multiple command buffers cannot be registered for the same op");
+
+ // Map all ops nested within the command buffer so we can query later.
+ executeOp.walk([&](Operation *op) {
+ commandBuffers.insert(std::make_pair(op, commandBuffer));
+ return WalkResult::advance();
+ });
+}
+
+Value StreamConversionMapping::lookupCommandBufferFor(Operation *cmdOp) const {
+ auto it = commandBuffers.find(cmdOp);
+ assert(it != commandBuffers.end() &&
+ "command buffer must have been registered during conversion");
+ return it->second;
+}
+
+} // namespace mlir::iree_compiler
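The relationship between deriveRequiredResourceBufferBits and deriveAllowedResourceBufferBits can be sketched outside of MLIR as follows. This is a simplified model, not the compiler code: the enum names, the flag values, and the `externalResourcesMappable` parameter (standing in for the `clExternalResourcesMappable` option) are all hypothetical, and the real HAL bitfields have different values and more members.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the HAL bitfields; values are illustrative only.
enum MemoryType : uint32_t {
  kHostVisible = 1u << 0,
  kHostLocal = 1u << 1,
  kDeviceVisible = 1u << 2,
  kDeviceLocal = 1u << 3,
};
enum BufferUsage : uint32_t {
  kTransfer = 1u << 0,
  kMapping = 1u << 1,
  kDispatchStorage = 1u << 2,
  kSharingImmutable = 1u << 3,
};
enum class Lifetime { Constant, Variable, External, Staging, Transient };

struct BufferBits {
  uint32_t memoryTypes = 0;
  uint32_t bufferUsage = 0;
};

// Bits that must be set for a buffer to back a resource of |lifetime|.
BufferBits requiredBits(Lifetime lifetime) {
  BufferBits bits;
  switch (lifetime) {
  case Lifetime::Constant:
    bits.memoryTypes |= kDeviceLocal;
    bits.bufferUsage |= kSharingImmutable;
    break;
  case Lifetime::Variable:
  case Lifetime::Transient:
    bits.memoryTypes |= kDeviceLocal;
    break;
  case Lifetime::External:
    bits.memoryTypes |= kDeviceVisible;
    break;
  case Lifetime::Staging:
    bits.memoryTypes |= kHostLocal | kDeviceVisible;
    bits.bufferUsage |= kTransfer | kMapping;
    break;
  }
  // Every lifetime currently gets transfer and dispatch usage.
  bits.bufferUsage |= kTransfer | kDispatchStorage;
  return bits;
}

// Superset of the required bits; only external resources gain extra bits.
BufferBits allowedBits(Lifetime lifetime, bool externalResourcesMappable) {
  BufferBits bits = requiredBits(lifetime);
  if (lifetime == Lifetime::External) {
    bits.memoryTypes |= kDeviceLocal;
    if (externalResourcesMappable) {
      bits.memoryTypes |= kHostVisible;
      bits.bufferUsage |= kMapping;
    }
  }
  return bits;
}
```

In this model the allowed bits always contain the required bits, which matches the header comment describing them as a superset.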
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h
new file mode 100644
index 0000000..46d969d
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h
@@ -0,0 +1,100 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_CONVERSION_STREAMTOHAL_UTILS_H_
+#define IREE_COMPILER_DIALECT_HAL_CONVERSION_STREAMTOHAL_UTILS_H_
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir::iree_compiler {
+
+// Returns a !hal.device for the affinity specified on |op|.
+Value lookupDeviceFor(Operation *op, OpBuilder &builder);
+
+// Returns a !hal.device and queue affinity i64 for the affinity specified on
+// |op|.
+std::tuple<Value, Value> lookupDeviceAndQueueAffinityFor(Operation *op,
+ OpBuilder &builder);
+
+// Returns the !hal.allocator for the affinity specified on |op|.
+Value lookupAllocatorFor(Operation *op, OpBuilder &builder);
+
+// Returns a !hal.allocator and queue affinity i64 for the affinity specified on
+// |op|.
+std::tuple<Value, Value> lookupAllocatorAndQueueAffinityFor(Operation *op,
+ OpBuilder &builder);
+
+// Returns the |timepointFence| or a util.null if the wait is to be ignored.
+Value getOrCreateWaitFence(Location loc, Value timepointFence,
+ PatternRewriter &rewriter);
+
+// Returns a new fence for |timepoint| or an existing fence if one was
+// associated with an external fence. Returns util.null if no one observes the
+// fence.
+Value getOrCreateSignalFence(Location loc, Value device, Value timepoint,
+ PatternRewriter &rewriter);
+
+// Scans all of the stream.cmd.* ops in the region to derive a command category.
+IREE::HAL::CommandCategoryBitfield deriveCommandCategories(Region &region);
+
+// Maps a resource type to the corresponding HAL memory types and buffer usage.
+// This will fail if the resource type is not directly mappable to HAL bits.
+// The bits set here are those that must be set for the buffer to be used as the
+// buffer within the program with its defined resource lifetime.
+LogicalResult
+deriveRequiredResourceBufferBits(Location loc,
+ IREE::Stream::ResourceType resourceType,
+ IREE::HAL::MemoryTypeBitfield &memoryTypes,
+ IREE::HAL::BufferUsageBitfield &bufferUsage);
+
+// Maps a resource type to the corresponding HAL memory types and buffer usage.
+// This will fail if the resource type is not directly mappable to HAL bits.
+// The bits set here represent the superset of required and allowed bits and
+// are useful for providing buffers back to users via the ABI that may need to
+// be used for more than just what the internal program requires.
+LogicalResult
+deriveAllowedResourceBufferBits(Location loc,
+ IREE::Stream::ResourceType resourceType,
+ IREE::HAL::MemoryTypeBitfield &memoryTypes,
+ IREE::HAL::BufferUsageBitfield &bufferUsage);
+
+class StreamConversionMapping {
+public:
+ // Maps the stream dialect |executeOp| to the hal dialect |commandBuffer|
+ // value used during recording. Patterns can use this to find the SSA value
+ // they need to make hal.command_buffer.* ops.
+ void mapCommandBuffer(IREE::Stream::CmdExecuteOp executeOp,
+ Value commandBuffer);
+
+ // Looks up a mapped command buffer SSA value that can be used by the given
+ // stream.cmd.* op.
+ Value lookupCommandBufferFor(Operation *cmdOp) const;
+
+private:
+ // Ops within stream.cmd.execute ops -> !hal.command_buffer.
+ DenseMap<Operation *, Value> commandBuffers;
+};
+
+template <typename OpT>
+struct StreamConversionPattern : public OpConversionPattern<OpT> {
+ StreamConversionPattern(std::shared_ptr<StreamConversionMapping> mapping,
+ TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<OpT>(typeConverter, context, benefit),
+ mapping(std::move(mapping)) {}
+
+ std::shared_ptr<StreamConversionMapping> mapping;
+};
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_STREAMTOHAL_UTILS_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel
index a8bb9f3..2d6f777 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel
@@ -18,6 +18,7 @@
[
"channel_ops.mlir",
"cmd_ops.mlir",
+ "context_ops.mlir",
"debug_ops.mlir",
"file_ops.mlir",
"resource_ops.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt
index fb7c40f..b273190 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt
@@ -16,6 +16,7 @@
SRCS
"channel_ops.mlir"
"cmd_ops.mlir"
+ "context_ops.mlir"
"debug_ops.mlir"
"file_ops.mlir"
"resource_ops.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
new file mode 100644
index 0000000..659c11b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
@@ -0,0 +1,42 @@
+// RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s
+
+// CHECK-LABEL: @contextResolveAllocator
+func.func @contextResolveAllocator() -> !hal.allocator {
+ // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
+ %allocator = stream.context.resolve : !hal.allocator
+ // CHECK: return %[[ALLOCATOR]]
+ return %allocator : !hal.allocator
+}
+
+// -----
+
+// CHECK-LABEL: @contextResolveDevice
+func.func @contextResolveDevice() -> !hal.device {
+ // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+ %device = stream.context.resolve : !hal.device
+ // CHECK: return %[[DEVICE]]
+ return %device : !hal.device
+}
+
+// -----
+
+// CHECK-LABEL: @contextResolveDeviceQueueAffinityAny
+func.func @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) {
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64
+ %device, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64
+ // CHECK: return %[[DEVICE]], %[[QUEUE_AFFINITY]]
+ return %device, %queue_affinity_any : !hal.device, i64
+}
+
+// -----
+
+// CHECK-LABEL: @contextResolveDeviceQueueAffinity45
+func.func @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) {
+ // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
+ %device, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64
+ // CHECK: return %[[DEVICE]], %[[QUEUE_AFFINITY]]
+ return %device, %queue_affinity_45 : !hal.device, i64
+}
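The constants expected by the two queue affinity tests above (`-1` for `queue<*>` and `48` for `queue<[4, 5]>`) are consistent with the affinity being lowered to a 64-bit mask with one bit per queue ordinal. The sketch below shows that arithmetic; it is an inference from the CHECK lines rather than the actual lowering, and `queueAffinityMask` and `kAnyQueueAffinity` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Hypothetical: all bits set, meaning any queue may be used.
constexpr int64_t kAnyQueueAffinity = -1;

// Hypothetical helper: ORs together one bit per queue ordinal.
int64_t queueAffinityMask(std::initializer_list<unsigned> queueIds) {
  uint64_t mask = 0;
  for (unsigned id : queueIds) {
    assert(id < 64 && "queue ordinal must fit in an i64 mask");
    mask |= uint64_t{1} << id;
  }
  return static_cast<int64_t>(mask);
}
```

Under this assumption queues 4 and 5 give `(1 << 4) | (1 << 5) = 16 + 32 = 48`, matching the `arith.constant 48 : i64` the test checks for.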
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
index 6b69403..dcb52fb 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -56,6 +56,7 @@
"//compiler/src/iree/compiler/Dialect/Util/Conversion",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters/IR:IOParametersDialect",
"//compiler/src/iree/compiler/Utils",
"//runtime/src/iree/schemas/instruments",
"//runtime/src/iree/schemas/instruments:dispatch_def_c_fbs",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 81bee1b..4ceff9f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -67,6 +67,7 @@
iree::compiler::Dialect::Util::Conversion
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
+ iree::compiler::Modules::IO::Parameters::IR::IOParametersDialect
iree::compiler::Utils
iree::schemas::instruments
iree::schemas::instruments::dispatch_def_c_fbs
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
index 7809131..1ac52d3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
@@ -21,6 +21,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -51,6 +52,9 @@
registry.insert<IREE::HAL::HALDialect>();
registry.insert<IREE::Stream::StreamDialect>();
registry.insert<IREE::Util::UtilDialect>();
+
+ // TODO(benvanik): add a registration system for extra dialects?
+ registry.insert<IREE::IO::Parameters::IOParametersDialect>();
}
StringRef getArgument() const override { return "iree-hal-conversion"; }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
index 6c39a05..b36fc41 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
@@ -233,9 +233,8 @@
signaledResources.push_back(barrierOp.getResult());
signaledTimepoints.push_back(barrierOp.getResultTimepoint());
}
- Value joinedTimepoint =
- rewriter.createOrFold<IREE::Stream::TimepointJoinOp>(
- op.getLoc(), timepointType, signaledTimepoints);
+ Value joinedTimepoint = IREE::Stream::TimepointJoinOp::join(
+ op.getLoc(), signaledTimepoints, rewriter);
rewriter.create<IREE::Stream::TimepointChainExternalOp>(
op.getLoc(), joinedTimepoint, ValueRange{adaptor.getSignalFence()},
/*affinity=*/nullptr);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
index cd0495d..dd86c21 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
@@ -292,6 +292,16 @@
let hasCustomAssemblyFormat = 1;
}
+def Stream_MemoryModel_Unified : I32EnumAttrCase<"Unified", 0>;
+def Stream_MemoryModel_Discrete : I32EnumAttrCase<"Discrete", 1>;
+def Stream_MemoryModelAttr :
+ I32EnumAttr<"MemoryModel", "stream resource memory model", [
+ Stream_MemoryModel_Unified,
+ Stream_MemoryModel_Discrete,
+ ]> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+}
+
def Stream_ResourceConfigAttr :
AttrDef<Stream_Dialect, "ResourceConfig", []> {
let mnemonic = "resource_config";
@@ -323,7 +333,9 @@
// Number of bits in `index` values as passed across device boundaries.
"int64_t":$indexBits,
// Fuses bindings that are mutable instead of leaving them split.
- "bool":$aliasMutableBindings
+ "bool":$aliasMutableBindings,
+ // Memory model used for host-device memory access.
+ "IREE::Stream::MemoryModel":$memoryModel
);
let valueType = NoneType;
@@ -365,6 +377,38 @@
"access array attribute"> {}
//===----------------------------------------------------------------------===//
+// Parameter storage attributes
+//===----------------------------------------------------------------------===//
+
+def Stream_NamedParameterAttr :
+ AttrDef<Stream_Dialect, "NamedParameter", [
+ TypedAttrInterface,
+ DeclareAttrInterfaceMethods<Util_SizedStorageAttr, [
+ "getStorageSize",
+ ]>,
+ ]> {
+ let mnemonic = "parameter.named";
+ let summary = [{named parameter referenced by an optional scope and key}];
+ let description = [{
+ Specifies an externally-defined parameter that can be referenced by an
+ optional scope defining a set of parameters and a key uniquely identifying
+ the parameter within its scope.
+ }];
+ let parameters = (ins
+ AttributeSelfTypeParameter<"">:$type,
+ OptionalParameter<"StringAttr">:$scope,
+ AttrParameter<"StringAttr", "">:$key,
+ OptionalParameter<"DictionaryAttr">:$config
+ );
+ let assemblyFormat = [{
+ `<`
+ custom<ParameterReference>($scope, $key)
+ (`,` $config^)?
+ `>`
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// Stream synchronization types
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 31783e6..7a760d1 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -518,8 +518,7 @@
joinTimepoints.push_back(existingTimepoint);
}
llvm::append_range(joinTimepoints, newTimepoints);
- return builder.create<IREE::Stream::TimepointJoinOp>(
- loc, builder.getType<IREE::Stream::TimepointType>(), joinTimepoints);
+ return IREE::Stream::TimepointJoinOp::join(loc, joinTimepoints, builder);
}
// Elides waits that are known to be immediately resolved.
@@ -989,6 +988,189 @@
}
//===----------------------------------------------------------------------===//
+// stream.parameter.load
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct FoldParameterLoadTargetSubview
+ : public OpRewritePattern<ParameterLoadOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ParameterLoadOp op,
+ PatternRewriter &rewriter) const override {
+ auto loadResult = op.getResult();
+ if (!loadResult.hasOneUse())
+ return failure();
+ Operation *user = *loadResult.getUsers().begin();
+
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ bool needsUpdate = false;
+
+ auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset());
+ auto newResultSize = op.getResultSize();
+ if (auto subviewOp = dyn_cast<IREE::Stream::ResourceSubviewOp>(user)) {
+ auto viewSourceOffset = subviewOp.getSourceOffset();
+ auto viewResultSize = subviewOp.getResultSize();
+ if (IREE::Util::tryMoveProducerBefore(viewSourceOffset, op) &&
+ IREE::Util::tryMoveProducerBefore(viewResultSize, op)) {
+ newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subviewOp.getLoc(), newSourceOffset,
+ rewriter.createOrFold<mlir::arith::IndexCastOp>(
+ subviewOp.getLoc(), rewriter.getI64Type(), viewSourceOffset));
+ newResultSize = viewResultSize;
+ rewriter.replaceAllUsesWith(subviewOp.getResult(), op.getResult());
+ needsUpdate = true;
+ }
+ }
+
+ rewriter.restoreInsertionPoint(ip);
+ if (!needsUpdate)
+ return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceOffsetMutable().assign(newSourceOffset);
+ op.getResultSizeMutable().assign(newResultSize);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void ParameterLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<ElideUnusedOp<ParameterLoadOp>>(context);
+ results.insert<FoldParameterLoadTargetSubview>(context);
+ results.insert<ElideImmediateTimepointWait<ParameterLoadOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.parameter.read
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct FoldParameterReadTargetSubview
+ : public OpRewritePattern<ParameterReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ParameterReadOp op,
+ PatternRewriter &rewriter) const override {
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ bool needsUpdate = false;
+ auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset());
+ auto newTargetResource = op.getTarget();
+ auto newTargetSize = op.getTargetSize();
+ auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset());
+ if (auto subviewOp = dyn_cast_or_null<IREE::Stream::ResourceSubviewOp>(
+ newTargetResource.getDefiningOp())) {
+ newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subviewOp.getLoc(), newSourceOffset,
+ rewriter.createOrFold<mlir::arith::IndexCastOp>(
+ subviewOp.getLoc(), rewriter.getI64Type(),
+ subviewOp.getSourceOffset()));
+ newTargetResource = subviewOp.getSource();
+ newTargetSize = subviewOp.getSourceSize();
+ newTargetOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subviewOp.getLoc(), subviewOp.getSourceOffset(), newTargetOffset);
+ needsUpdate = true;
+ }
+ rewriter.restoreInsertionPoint(ip);
+ if (!needsUpdate)
+ return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceOffsetMutable().assign(newSourceOffset);
+ op.getTargetMutable().assign(newTargetResource);
+ op.getTargetSizeMutable().assign(newTargetSize);
+ op.getTargetOffsetMutable().assign(newTargetOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void ParameterReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<ElideUnusedOp<ParameterReadOp>>(context);
+ results.insert<FoldParameterReadTargetSubview>(context);
+ results.insert<ElideImmediateTimepointWait<ParameterReadOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.parameter.write
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct FoldParameterWriteSourceSubview
+ : public OpRewritePattern<ParameterWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ParameterWriteOp op,
+ PatternRewriter &rewriter) const override {
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ bool needsUpdate = false;
+ auto newSourceResource = op.getSource();
+ auto newSourceSize = op.getSourceSize();
+ auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset());
+ auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset());
+ if (auto subviewOp = dyn_cast_or_null<IREE::Stream::ResourceSubviewOp>(
+ newSourceResource.getDefiningOp())) {
+ newSourceResource = subviewOp.getSource();
+ newSourceSize = subviewOp.getSourceSize();
+ newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subviewOp.getLoc(), subviewOp.getSourceOffset(), newSourceOffset);
+ newTargetOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subviewOp.getLoc(), newTargetOffset,
+ rewriter.createOrFold<mlir::arith::IndexCastOp>(
+ subviewOp.getLoc(), rewriter.getI64Type(),
+ subviewOp.getSourceOffset()));
+ needsUpdate = true;
+ }
+ rewriter.restoreInsertionPoint(ip);
+ if (!needsUpdate)
+ return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceMutable().assign(newSourceResource);
+ op.getSourceSizeMutable().assign(newSourceSize);
+ op.getSourceOffsetMutable().assign(newSourceOffset);
+ op.getTargetOffsetMutable().assign(newTargetOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void ParameterWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<ElideUnusedOp<ParameterWriteOp>>(context);
+ results.insert<FoldParameterWriteSourceSubview>(context);
+ results.insert<ElideImmediateTimepointWait<ParameterWriteOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.parameter.gather
+//===----------------------------------------------------------------------===//
+
+void ParameterGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<ElideUnusedOp<ParameterGatherOp>>(context);
+ results.insert<ElideImmediateTimepointWait<ParameterGatherOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.parameter.scatter
+//===----------------------------------------------------------------------===//
+
+void ParameterScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<ElideUnusedOp<ParameterScatterOp>>(context);
+ results.insert<ElideImmediateTimepointWait<ParameterScatterOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
// stream.file.read
//===----------------------------------------------------------------------===//
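The offset arithmetic behind FoldParameterLoadTargetSubview can be checked with a small byte-level model: loading a range and then taking a subview of the result yields the same bytes as loading directly at the summed offset with the subview's size. The sketch below models parameters and resources as plain byte vectors; every name in it is hypothetical, and it ignores timepoints, affinities, and the single-use requirement of the real pattern.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Bytes = std::vector<uint8_t>;

// Returns |size| bytes of |data| starting at |offset|.
Bytes slice(const Bytes &data, size_t offset, size_t size) {
  assert(offset + size <= data.size() && "range out of bounds");
  auto begin = data.begin() + static_cast<std::ptrdiff_t>(offset);
  return Bytes(begin, begin + static_cast<std::ptrdiff_t>(size));
}

// Unfolded form: load [sourceOffset, sourceOffset + loadSize) from the
// parameter, then take a subview of the loaded resource.
Bytes loadThenSubview(const Bytes &parameter, size_t sourceOffset,
                      size_t loadSize, size_t viewOffset, size_t viewSize) {
  Bytes loaded = slice(parameter, sourceOffset, loadSize);
  return slice(loaded, viewOffset, viewSize);
}

// Folded form: a single load at the combined offset with the view's size.
Bytes foldedLoad(const Bytes &parameter, size_t sourceOffset,
                 size_t viewOffset, size_t viewSize) {
  return slice(parameter, sourceOffset + viewOffset, viewSize);
}

// Test helper: bytes 0, 1, 2, ..., n - 1.
Bytes iotaBytes(size_t n) {
  Bytes bytes(n);
  for (size_t i = 0; i < n; ++i)
    bytes[i] = static_cast<uint8_t>(i);
  return bytes;
}
```

The folded form avoids loading bytes that the subview would discard, which is presumably the motivation for the pattern.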
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index b69246b..a2a46ac 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -391,6 +391,200 @@
}
//===----------------------------------------------------------------------===//
+// custom<ParameterGatherOperations>(
+// $source_scope, $source_keys, $source_offsets,
+// $target, type($target), $target_size, $target_offsets, $target_lengths)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseParameterGatherOperations(
+ OpAsmParser &parser, StringAttr &sourceScopeAttr, ArrayAttr &sourceKeysAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &sourceOffsets,
+ OpAsmParser::UnresolvedOperand &target, Type &targetType,
+ OpAsmParser::UnresolvedOperand &targetSize,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &targetOffsets,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &targetLengths) {
+ auto builder = parser.getBuilder();
+ SmallVector<Attribute> sourceKeyAttrs;
+ do {
+ StringAttr rowSourceScopeAttr;
+ StringAttr sourceKeyAttr;
+ OpAsmParser::UnresolvedOperand sourceOffset;
+ OpAsmParser::UnresolvedOperand targetOffset;
+ OpAsmParser::UnresolvedOperand targetLength;
+ OpAsmParser::UnresolvedOperand rowTarget;
+ Type rowTargetType;
+ OpAsmParser::UnresolvedOperand rowTargetSize;
+ if (failed(parseParameterReference(parser, rowSourceScopeAttr,
+ sourceKeyAttr)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(sourceOffset)) ||
+ failed(parser.parseRSquare()) || failed(parser.parseArrow()) ||
+ failed(parser.parseOperand(rowTarget)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(targetOffset)) ||
+ failed(parser.parseKeyword("for")) ||
+ failed(parser.parseOperand(targetLength)) ||
+ failed(parser.parseRSquare()) ||
+ failed(parser.parseColonType(rowTargetType)) ||
+ failed(parser.parseLBrace()) ||
+ failed(parser.parseOperand(rowTargetSize)) ||
+ failed(parser.parseRBrace())) {
+ return failure();
+ }
+ if (!targetType) {
+ sourceScopeAttr = rowSourceScopeAttr;
+ target = rowTarget;
+ targetType = rowTargetType;
+ targetSize = rowTargetSize;
+ } else if (rowSourceScopeAttr != sourceScopeAttr ||
+ rowTarget.name != target.name || rowTargetType != targetType ||
+ rowTargetSize.name != targetSize.name) {
+ return parser.emitError(
+ parser.getCurrentLocation(),
+ "each operation must use the same scope and target resource");
+ }
+ sourceKeyAttrs.push_back(sourceKeyAttr);
+ sourceOffsets.push_back(sourceOffset);
+ targetOffsets.push_back(targetOffset);
+ targetLengths.push_back(targetLength);
+ } while (succeeded(parser.parseOptionalComma()));
+ sourceKeysAttr = builder.getArrayAttr(sourceKeyAttrs);
+ return success();
+}
+
+static void printParameterGatherOperations(
+ OpAsmPrinter &p, Operation *op, StringAttr sourceScopeAttr,
+ ArrayAttr sourceKeysAttr, ValueRange sourceOffsets, Value target,
+ Type targetType, Value targetSize, ValueRange targetOffsets,
+ ValueRange targetLengths) {
+ p.increaseIndent();
+ p.printNewline();
+ llvm::interleave(
+ llvm::zip_equal(sourceKeysAttr.getAsRange<StringAttr>(), sourceOffsets,
+ targetOffsets, targetLengths),
+ [&](std::tuple<StringAttr, Value, Value, Value> it) {
+ auto [sourceKeyAttr, sourceOffset, targetOffset, targetLength] = it;
+ printParameterReference(p, op, sourceScopeAttr, sourceKeyAttr);
+ p << "[";
+ p.printOperand(sourceOffset);
+ p << "] -> ";
+ p.printOperand(target);
+ p << "[";
+ p.printOperand(targetOffset);
+ p << " for ";
+ p.printOperand(targetLength);
+ p << "] : ";
+ p.printType(targetType);
+ p << "{";
+ p.printOperand(targetSize);
+ p << "}";
+ },
+ [&]() {
+ p << ',';
+ p.printNewline();
+ });
+ p.decreaseIndent();
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ParameterScatterOperations>(
+// $source, type($source), $source_size, $source_offsets, $source_lengths,
+// $target_scope, $target_keys, $target_offsets)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseParameterScatterOperations(
+ OpAsmParser &parser, OpAsmParser::UnresolvedOperand &source,
+ Type &sourceType, OpAsmParser::UnresolvedOperand &sourceSize,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &sourceOffsets,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &sourceLengths,
+ StringAttr &targetScopeAttr, ArrayAttr &targetKeysAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &targetOffsets) {
+ auto builder = parser.getBuilder();
+ SmallVector<Attribute> targetKeyAttrs;
+ do {
+ OpAsmParser::UnresolvedOperand rowSource;
+ Type rowSourceType;
+ OpAsmParser::UnresolvedOperand rowSourceSize;
+ OpAsmParser::UnresolvedOperand sourceOffset;
+ OpAsmParser::UnresolvedOperand sourceLength;
+ StringAttr rowTargetScopeAttr;
+ StringAttr targetKeyAttr;
+ OpAsmParser::UnresolvedOperand targetOffset;
+ if (failed(parser.parseOperand(rowSource)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(sourceOffset)) ||
+ failed(parser.parseKeyword("for")) ||
+ failed(parser.parseOperand(sourceLength)) ||
+ failed(parser.parseRSquare()) ||
+ failed(parser.parseColonType(rowSourceType)) ||
+ failed(parser.parseLBrace()) ||
+ failed(parser.parseOperand(rowSourceSize)) ||
+ failed(parser.parseRBrace()) || failed(parser.parseArrow()) ||
+ failed(parseParameterReference(parser, rowTargetScopeAttr,
+ targetKeyAttr)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(targetOffset)) ||
+ failed(parser.parseRSquare())) {
+ return failure();
+ }
+ if (!sourceType) {
+ source = rowSource;
+ sourceType = rowSourceType;
+ sourceSize = rowSourceSize;
+ targetScopeAttr = rowTargetScopeAttr;
+ } else if (rowSource.name != source.name || rowSourceType != sourceType ||
+ rowSourceSize.name != sourceSize.name ||
+ rowTargetScopeAttr != targetScopeAttr) {
+ return parser.emitError(
+ parser.getCurrentLocation(),
+ "each operation must use the same source resource and scope");
+ }
+ sourceOffsets.push_back(sourceOffset);
+ sourceLengths.push_back(sourceLength);
+ targetKeyAttrs.push_back(targetKeyAttr);
+ targetOffsets.push_back(targetOffset);
+ } while (succeeded(parser.parseOptionalComma()));
+ targetKeysAttr = builder.getArrayAttr(targetKeyAttrs);
+ return success();
+}
+
+static void printParameterScatterOperations(
+ OpAsmPrinter &p, Operation *op, Value source, Type sourceType,
+ Value sourceSize, ValueRange sourceOffsets, ValueRange sourceLengths,
+ StringAttr targetScopeAttr, ArrayAttr targetKeysAttr,
+ ValueRange targetOffsets) {
+ p.increaseIndent();
+ p.printNewline();
+ llvm::interleave(
+ llvm::zip_equal(sourceOffsets, sourceLengths,
+ targetKeysAttr.getAsRange<StringAttr>(), targetOffsets),
+ [&](std::tuple<Value, Value, StringAttr, Value> it) {
+ auto [sourceOffset, sourceLength, targetKeyAttr, targetOffset] = it;
+ p.printOperand(source);
+ p << "[";
+ p.printOperand(sourceOffset);
+ p << " for ";
+ p.printOperand(sourceLength);
+ p << "] : ";
+ p.printType(sourceType);
+ p << "{";
+ p.printOperand(sourceSize);
+ p << "} -> ";
+ printParameterReference(p, op, targetScopeAttr, targetKeyAttr);
+ p << "[";
+ p.printOperand(targetOffset);
+ p << "]";
+ },
+ [&]() {
+ p << ',';
+ p.printNewline();
+ });
+ p.decreaseIndent();
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
// custom<ResourceRegion>($operands, type($operands), $operand_sizes,
// type($results), $result_sizes,
// $tied_operands, $body)
@@ -944,6 +1138,9 @@
if (op.getResultSizes().size() != count || op.getValues().size() != count) {
return op.emitOpError() << "mismatched constant/result counts";
}
+ if (!llvm::all_equal(op.getResults().getTypes())) {
+ return op.emitOpError() << "all results must be of the same type";
+ }
return success();
}
@@ -1003,6 +1200,74 @@
}
//===----------------------------------------------------------------------===//
+// stream.parameter.load
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// stream.parameter.read
+//===----------------------------------------------------------------------===//
+
+LogicalResult ParameterReadOp::verify() {
+ ParameterReadOp op = *this;
+ if (failed(verifyOpValueSizes(op, op.getTarget(), op.getTargetSize()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.parameter.write
+//===----------------------------------------------------------------------===//
+
+LogicalResult ParameterWriteOp::verify() {
+ ParameterWriteOp op = *this;
+ if (failed(verifyOpValueSizes(op, op.getSource(), op.getSourceSize()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.parameter.gather
+//===----------------------------------------------------------------------===//
+
+LogicalResult ParameterGatherOp::verify() {
+ ParameterGatherOp op = *this;
+ size_t expectedCount = op.getSourceKeys().size();
+ if (op.getSourceOffsets().size() != expectedCount ||
+ op.getTargetOffsets().size() != expectedCount ||
+ op.getTargetLengths().size() != expectedCount) {
+ return op.emitOpError()
+ << "requires that the source keys, source offsets, target offsets, "
+ "and target lengths are all 1:1";
+ }
+ if (failed(verifyOpValueSizes(op, op.getTarget(), op.getTargetSize()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.parameter.scatter
+//===----------------------------------------------------------------------===//
+
+LogicalResult ParameterScatterOp::verify() {
+ ParameterScatterOp op = *this;
+ size_t expectedCount = op.getTargetKeys().size();
+ if (op.getSourceOffsets().size() != expectedCount ||
+ op.getSourceLengths().size() != expectedCount ||
+ op.getTargetOffsets().size() != expectedCount) {
+ return op.emitOpError()
+ << "requires that the source offsets, source lengths, target keys, "
+ "and target offsets are all 1:1";
+ }
+ if (failed(verifyOpValueSizes(op, op.getSource(), op.getSourceSize()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// stream.file.constant
//===----------------------------------------------------------------------===//
@@ -3191,6 +3456,23 @@
return success();
}
+// static
+Value TimepointJoinOp::join(Location loc, ValueRange timepoints,
+ OpBuilder &builder) {
+ assert(!timepoints.empty() && "must have at least one timepoint");
+ if (timepoints.size() == 1)
+ return timepoints.front();
+ return builder.create<IREE::Stream::TimepointJoinOp>(
+ loc, builder.getType<IREE::Stream::TimepointType>(), timepoints);
+}
+
+// static
+Value TimepointJoinOp::join(ValueRange timepoints, OpBuilder &builder) {
+ return join(builder.getFusedLoc(llvm::to_vector(llvm::map_range(
+ timepoints, [](Value value) { return value.getLoc(); }))),
+ timepoints, builder);
+}
+
//===----------------------------------------------------------------------===//
// stream.timepoint.barrier
//===----------------------------------------------------------------------===//
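Note on the scatter assembly above: each row is printed as `source[offset for length] : type{size} -> "scope"::"key"[offset]`, with rows separated by a comma and a newline. The following is a minimal standalone sketch of that layout, not the MLIR printer itself. The function name `formatScatterRows` and the use of plain strings in place of SSA values are assumptions made for illustration, and the sketch omits the indentation the real printer adds.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of IREE): formats scatter rows from
// already-printed operand names instead of real mlir::Value handles.
std::string formatScatterRows(const std::string &source,
                              const std::string &sourceType,
                              const std::string &sourceSize,
                              const std::string &targetScope,
                              const std::vector<std::string> &sourceOffsets,
                              const std::vector<std::string> &sourceLengths,
                              const std::vector<std::string> &targetKeys,
                              const std::vector<std::string> &targetOffsets) {
  // Mirrors the verifier's requirement that all per-row lists are 1:1.
  assert(sourceOffsets.size() == targetKeys.size() &&
         sourceLengths.size() == targetKeys.size() &&
         targetOffsets.size() == targetKeys.size() &&
         "per-row lists must have matching lengths");
  std::string out;
  for (std::size_t i = 0; i < targetKeys.size(); ++i) {
    if (i != 0) out += ",\n";  // row separator: comma then newline
    out += source + "[" + sourceOffsets[i] + " for " + sourceLengths[i] + "] : ";
    out += sourceType + "{" + sourceSize + "} -> ";
    // In this sketch an empty scope string means "no scope".
    if (!targetScope.empty()) out += "\"" + targetScope + "\"::";
    out += "\"" + targetKeys[i] + "\"[" + targetOffsets[i] + "]";
  }
  return out;
}
```

Every row repeats the same source, type, size, and scope, which is why the parser rejects rows that disagree on any of them.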
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 988cf86..1389eb2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -23,6 +23,57 @@
Stream_Op<mnemonic, !listconcat(traits, [Pure])>;
//===----------------------------------------------------------------------===//
+// Execution context ops
+//===----------------------------------------------------------------------===//
+
+def OpGroupContextOps : OpDocGroup {
+ let summary = "Execution context ops";
+ let description = [{
+ Operations for interacting with the execution context that stream operations
+ execute within.
+ }];
+}
+
+let opDocGroup = OpGroupContextOps in {
+
+def Stream_ContextResolveOp : Stream_PureOp<"context.resolve", [
+ Stream_AffinityOp,
+]> {
+ let summary = [{resolves low-level context resources based on type}];
+ let description = [{
+ WIP; allows for accessing the implementation details of lower-level dialects
+ such as the HAL. This will likely be reworked in the future to either
+ live inside other dialects, use some op interface instead of having a
+ dedicated op here, or remove the op entirely and make resolution happen
+ explicitly.
+
+ Examples:
+ ```
+ // Returns a HAL device.
+ = stream.context.resolve on(#something) : !hal.device
+ // Returns a HAL device and (optional) queue affinity.
+ = stream.context.resolve on(#something) : !hal.device, i64
+ // Returns a HAL allocator and (optional) queue affinity.
+ = stream.context.resolve on(#something) : !hal.allocator, i64
+ ```
+ }];
+
+ let arguments = (ins
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<AnyType>:$results
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ attr-dict `:` type($results)
+ }];
+}
+
+} // OpGroupContextOps
+
+//===----------------------------------------------------------------------===//
// Generic resource ops
//===----------------------------------------------------------------------===//
@@ -582,6 +633,300 @@
} // OpGroupResourceOps
//===----------------------------------------------------------------------===//
+// Parameter I/O ops
+//===----------------------------------------------------------------------===//
+
+def OpGroupParameterOps : OpDocGroup {
+ let summary = "Resource parameter I/O ops";
+ let description = "Resource parameter I/O ops.";
+}
+
+let opDocGroup = OpGroupParameterOps in {
+
+def Stream_ParameterLoadOp : Stream_PureOp<"parameter.load", [
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Stream_CmdPhaseOp,
+ Stream_TimelineOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{loads a resource from a parameter scope}];
+ let description = [{
+ Asynchronously reads a resource from an external parameter provider and
+ returns the resulting stream resource. Depending on the resource type this
+ may alias existing cached storage, be directly mapped to the parameter
+ origin, or result in a copy as if `stream.resource.alloca` and
+ `stream.parameter.read` had been used.
+ }];
+
+ let arguments = (ins
+ OptionalAttr<StrAttr>:$source_scope,
+ StrAttr:$source_key,
+ I64:$source_offset,
+ Stream_Size:$result_size,
+ Optional<Stream_Timepoint>:$await_timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result,
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`await` `(` $await_timepoint^ `)` `=` `` `>`)?
+ custom<ParameterReference>($source_scope, $source_key)
+ `` `[` $source_offset `]` `:`
+ type($result) `` `{` $result_size `}`
+ `=` `` `>`
+ type($result_timepoint)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return getResultSize(); }
+ SmallVector<Value> getAwaitTimepoints() {
+ if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {};
+ }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ParameterReadOp : Stream_Op<"parameter.read", [
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Stream_CmdPhaseOp,
+ Stream_TimelineOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{reads a resource from a parameter scope}];
+ let description = [{
+ Asynchronously reads a resource from an external parameter provider into the
+ provided target resource range.
+ }];
+
+ let arguments = (ins
+ OptionalAttr<StrAttr>:$source_scope,
+ StrAttr:$source_key,
+ I64:$source_offset,
+ Stream_AnyStreamResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Size:$target_length,
+ Optional<Stream_Timepoint>:$await_timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`await` `(` $await_timepoint^ `)` `=` `` `>`)?
+ custom<ParameterReference>($source_scope, $source_key)
+ `` `[` $source_offset `]` `->`
+ $target `[` $target_offset `for` $target_length `]` `:`
+ type($target) `` `{` $target_size `}`
+ `=` `` `>`
+ type($result_timepoint)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getTargetSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ SmallVector<Value> getAwaitTimepoints() {
+ if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {};
+ }
+ }];
+
+ let hasVerifier = 1;
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ParameterWriteOp : Stream_Op<"parameter.write", [
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Stream_CmdPhaseOp,
+ Stream_TimelineOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{writes a resource to a parameter scope}];
+ let description = [{
+ Asynchronously writes a resource to an external parameter provider from
+ the provided source resource range.
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$source,
+ Stream_Size:$source_size,
+ Stream_Offset:$source_offset,
+ Stream_Size:$source_length,
+ OptionalAttr<StrAttr>:$target_scope,
+ StrAttr:$target_key,
+ I64:$target_offset,
+ Optional<Stream_Timepoint>:$await_timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`await` `(` $await_timepoint^ `)` `=` `` `>`)?
+ $source `[` $source_offset `for` $source_length `]` `:`
+ type($source) `` `{` $source_size `}` `->`
+ custom<ParameterReference>($target_scope, $target_key)
+ `` `[` $target_offset `]`
+ `=` `` `>`
+ type($result_timepoint)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getSourceSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ SmallVector<Value> getAwaitTimepoints() {
+ if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {};
+ }
+ }];
+
+ let hasVerifier = 1;
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ParameterGatherOp : Stream_Op<"parameter.gather", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Stream_CmdPhaseOp,
+ Stream_TimelineOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{gathers multiple resources from a parameter scope}];
+ let description = [{
+ Asynchronously gathers one or more resources into a single target stream
+ resource. This is equivalent to one `stream.parameter.read` per parameter
+ but allows implementations that can batch operations to do so without
+ additional timeline overhead.
+ }];
+
+ let arguments = (ins
+ OptionalAttr<StrAttr>:$source_scope,
+ StrArrayAttr:$source_keys,
+ Variadic<I64>:$source_offsets,
+ Stream_AnyStreamResource:$target,
+ Stream_Size:$target_size,
+ Variadic<Stream_Offset>:$target_offsets,
+ Variadic<Stream_Size>:$target_lengths,
+ Optional<Stream_Timepoint>:$await_timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`await` `(` $await_timepoint^ `)` `=` `` `>`)?
+ `{`
+ custom<ParameterGatherOperations>(
+ $source_scope, $source_keys, $source_offsets,
+ $target, type($target), $target_size, $target_offsets, $target_lengths)
+ `}`
+ `=` `` `>`
+ type($result_timepoint)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getTargetSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ SmallVector<Value> getAwaitTimepoints() {
+ if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {};
+ }
+ }];
+
+ let hasVerifier = 1;
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ParameterScatterOp : Stream_Op<"parameter.scatter", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Stream_CmdPhaseOp,
+ Stream_TimelineOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{scatters multiple resources to a parameter scope}];
+ let description = [{
+ Asynchronously scatters one or more resources from a single source resource
+ into one or more parameters. This is equivalent to one
+ `stream.parameter.write` per parameter but allows implementations that can
+ batch operations to do so without additional timeline overhead.
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$source,
+ Stream_Size:$source_size,
+ Variadic<Stream_Offset>:$source_offsets,
+ Variadic<Stream_Size>:$source_lengths,
+ OptionalAttr<StrAttr>:$target_scope,
+ StrArrayAttr:$target_keys,
+ Variadic<I64>:$target_offsets,
+ Optional<Stream_Timepoint>:$await_timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`await` `(` $await_timepoint^ `)` `=` `` `>`)?
+ `{`
+ custom<ParameterScatterOperations>(
+ $source, type($source), $source_size, $source_offsets, $source_lengths,
+ $target_scope, $target_keys, $target_offsets)
+ `}`
+ `=` `` `>`
+ type($result_timepoint)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getSourceSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ SmallVector<Value> getAwaitTimepoints() {
+ if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {};
+ }
+ }];
+
+ let hasVerifier = 1;
+
+ let hasCanonicalizer = 1;
+}
+
+} // OpGroupParameterOps
+
+//===----------------------------------------------------------------------===//
// File ops
//===----------------------------------------------------------------------===//
@@ -644,9 +989,9 @@
let description = [{
Asynchronously reads a segment of a file into a resource.
- Some implementations this can stream directly from the file into
- device-local memory and should be preferred to manually staging memory
- through host buffers.
+ Some implementations can stream directly from the source file into
+ device-local memory and file ops should be preferred to manually staging
+ memory through host buffers.
}];
let arguments = (ins
@@ -704,8 +1049,8 @@
The file range must be valid within the file as this operation cannot
grow the underlying file storage.
- Some implementations this can stream directly from device-local memory into
- the file and should be preferred to manually staging memory
+ Some implementations can stream directly from device-local memory into the
+ target file and file ops should be preferred to manually staging memory
through host buffers.
}];
@@ -3330,6 +3675,13 @@
attr-dict-with-keyword
}];
+ let extraClassDeclaration = [{
+ // Joins one or more timepoints and returns a new timepoint representing a
+ // point on the timeline where all timepoints have been resolved.
+ static Value join(Location loc, ValueRange timepoints, OpBuilder &builder);
+ static Value join(ValueRange timepoints, OpBuilder &builder);
+ }];
+
let hasVerifier = 1;
let hasCanonicalizer = 1;
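Note on the `join` helpers declared above: a single timepoint is returned unchanged and a join op is only created for two or more. The sketch below models that behavior in plain C++ under stated assumptions: `Timepoint`, `JoinBuilder`, and `joinTimepoints` are hypothetical stand-ins for the SSA value, the `OpBuilder`, and `TimepointJoinOp::join`, and are not IREE types.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for an SSA timepoint value.
struct Timepoint {
  int id;
};

// Hypothetical builder that hands out a fresh timepoint per created join.
struct JoinBuilder {
  int nextId = 100;
  int joinsCreated = 0;
  Timepoint createJoin(const std::vector<Timepoint> &inputs) {
    (void)inputs;  // a real join op would keep its operands
    ++joinsCreated;
    return Timepoint{nextId++};
  }
};

// Same shape as the helper in the diff: reuse a lone timepoint, otherwise
// create one join covering all inputs.
Timepoint joinTimepoints(const std::vector<Timepoint> &timepoints,
                         JoinBuilder &builder) {
  assert(!timepoints.empty() && "must have at least one timepoint");
  if (timepoints.size() == 1) return timepoints.front();
  return builder.createJoin(timepoints);
}
```

Returning the lone timepoint directly avoids emitting a join that canonicalization would presumably fold away later.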
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
index 10e51f2..3dc63f9 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -68,6 +68,51 @@
llvm::cl::desc(
"Fuses bindings that are mutable instead of leaving them split."),
llvm::cl::init(false));
+static llvm::cl::opt<IREE::Stream::MemoryModel> clResourceMemoryModel(
+ "iree-stream-resource-memory-model",
+ llvm::cl::desc("Memory model used for host-device resource memory access."),
+ llvm::cl::values(
+ clEnumValN(IREE::Stream::MemoryModel::Unified, "unified",
+ "Host and device memory are unified and there's "
+ "(practically) no performance cost for cross-access."),
+ clEnumValN(IREE::Stream::MemoryModel::Discrete, "discrete",
+ "Host and device memory are discrete and cross-access is "
+ "expensive.")),
+ llvm::cl::init(IREE::Stream::MemoryModel::Discrete));
+
+//===----------------------------------------------------------------------===//
+// custom<ParameterReference>($scope, $key)
+//===----------------------------------------------------------------------===//
+
+ParseResult parseParameterReference(AsmParser &parser, StringAttr &scopeAttr,
+ StringAttr &keyAttr) {
+ auto builder = parser.getBuilder();
+ StringAttr firstAttr;
+ if (failed(parser.parseCustomAttributeWithFallback(firstAttr,
+ builder.getNoneType()))) {
+ return failure();
+ }
+ if (failed(parser.parseOptionalColon())) {
+ keyAttr = firstAttr;
+ return success();
+ }
+ scopeAttr = firstAttr;
+ if (failed(parser.parseColon()) ||
+ failed(parser.parseCustomAttributeWithFallback(keyAttr,
+ builder.getNoneType()))) {
+ return failure();
+ }
+ return success();
+}
+
+void printParameterReference(AsmPrinter &p, StringAttr scopeAttr,
+ StringAttr keyAttr) {
+ if (scopeAttr) {
+ p << "\"" << scopeAttr.getValue() << "\"";
+ p << "::";
+ }
+ p << "\"" << keyAttr.getValue() << "\"";
+}
//===----------------------------------------------------------------------===//
// #stream.resource_config<...>
@@ -84,34 +129,55 @@
int64_t minBufferRangeAlignment = 0;
int64_t indexBits = 32;
bool aliasMutableBindings = false;
+ auto memoryModel = IREE::Stream::MemoryModel::Discrete;
while (failed(p.parseOptionalRBrace())) {
StringRef key;
- int64_t value = 0;
- if (failed(p.parseKeyword(&key)) || failed(p.parseEqual()) ||
- failed(p.parseInteger(value))) {
+ if (failed(p.parseKeyword(&key)) || failed(p.parseEqual())) {
return {};
}
if (key == "max_allocation_size") {
- maxAllocationSize = value;
+ if (failed(p.parseInteger(maxAllocationSize)))
+ return {};
} else if (key == "min_buffer_offset_alignment") {
- minBufferOffsetAlignment = value;
+ if (failed(p.parseInteger(minBufferOffsetAlignment)))
+ return {};
} else if (key == "max_buffer_range") {
- maxBufferRange = value;
+ if (failed(p.parseInteger(maxBufferRange)))
+ return {};
} else if (key == "min_buffer_range_alignment") {
- minBufferRangeAlignment = value;
+ if (failed(p.parseInteger(minBufferRangeAlignment)))
+ return {};
} else if (key == "index_bits") {
- indexBits = value;
+ if (failed(p.parseInteger(indexBits)))
+ return {};
} else if (key == "alias_mutable_bindings") {
- aliasMutableBindings = (bool)value;
+ StringRef value;
+ if (failed(p.parseKeyword(&value)))
+ return {};
+ if (value == "true")
+ aliasMutableBindings = true;
+ else if (value == "false")
+ aliasMutableBindings = false;
+ else
+ return {};
+ } else if (key == "memory_model") {
+ StringRef value;
+ if (failed(p.parseKeyword(&value)))
+ return {};
+ auto enumValue = symbolizeMemoryModel(value);
+ if (!enumValue.has_value())
+ return {};
+ memoryModel = enumValue.value();
}
(void)p.parseOptionalComma();
}
if (failed(p.parseGreater()))
return {};
- return ResourceConfigAttr::get(
- p.getContext(), maxAllocationSize, minBufferOffsetAlignment,
- maxBufferRange, minBufferRangeAlignment, indexBits, aliasMutableBindings);
+ return ResourceConfigAttr::get(p.getContext(), maxAllocationSize,
+ minBufferOffsetAlignment, maxBufferRange,
+ minBufferRangeAlignment, indexBits,
+ aliasMutableBindings, memoryModel);
}
void ResourceConfigAttr::print(AsmPrinter &p) const {
@@ -123,7 +189,8 @@
os << "max_buffer_range = " << getMaxBufferRange() << ", ";
os << "min_buffer_range_alignment = " << getMinBufferRangeAlignment() << ", ";
os << "index_bits = " << getIndexBits() << ", ";
- os << "alias_mutable_bindings = " << getAliasMutableBindings();
+ os << "alias_mutable_bindings = " << (getAliasMutableBindings() ? "true" : "false") << ", ";
+ os << "memory_model = " << stringifyMemoryModel(getMemoryModel());
os << "}>";
}
@@ -145,7 +212,11 @@
std::max(lhs.getMinBufferRangeAlignment(),
rhs.getMinBufferRangeAlignment()),
std::max(lhs.getIndexBits(), rhs.getIndexBits()),
- rhs.getAliasMutableBindings() && lhs.getAliasMutableBindings());
+ rhs.getAliasMutableBindings() && lhs.getAliasMutableBindings(),
+ (lhs.getMemoryModel() == IREE::Stream::MemoryModel::Unified &&
+ rhs.getMemoryModel() == IREE::Stream::MemoryModel::Unified)
+ ? IREE::Stream::MemoryModel::Unified
+ : IREE::Stream::MemoryModel::Discrete);
}
// static
@@ -157,7 +228,7 @@
return ResourceConfigAttr::get(
context, clResourceMaxAllocationSize, clResourceMinOffsetAlignment,
clResourceMaxRange, clResourceMinOffsetAlignment, clResourceIndexBits,
- clResourceAliasMutableBindings);
+ clResourceAliasMutableBindings, clResourceMemoryModel);
}
// static
@@ -185,6 +256,23 @@
}
//===----------------------------------------------------------------------===//
+// #stream.parameter.named<...>
+//===----------------------------------------------------------------------===//
+
+int64_t NamedParameterAttr::getStorageSize() const {
+ if (auto configAttr = getConfig()) {
+ if (auto lengthAttr = configAttr.getAs<IntegerAttr>("length")) {
+ return lengthAttr.getInt();
+ }
+ }
+ if (auto shapedType = getType().dyn_cast<ShapedType>()) {
+ return IREE::Util::getRoundedPhysicalStorageSize(shapedType);
+ } else {
+ return IREE::Util::getTypePhysicalStorageBitWidth(getType());
+ }
+}
+
+//===----------------------------------------------------------------------===//
// #stream.timepoint<...>
//===----------------------------------------------------------------------===//
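Note on the resource config merge above: alignments and index widths take the maximum of the two sides, `alias_mutable_bindings` is kept only when both sides set it, and the memory model stays `Unified` only when both sides are unified. A minimal sketch of that rule follows. The struct is a hypothetical subset of the attribute's fields and `mergeConfigs` is an illustrative name, not the function in the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

enum class MemoryModel { Unified, Discrete };

// Hypothetical subset of the fields carried by the resource config attribute.
struct ResourceConfig {
  std::int64_t minBufferRangeAlignment;
  std::int64_t indexBits;
  bool aliasMutableBindings;
  MemoryModel memoryModel;
};

// Combines two configs into one that is safe for both.
ResourceConfig mergeConfigs(const ResourceConfig &lhs,
                            const ResourceConfig &rhs) {
  // Sketch-only sanity check; the original code has no such assertion.
  assert(lhs.indexBits > 0 && rhs.indexBits > 0 && "index width must be set");
  ResourceConfig merged;
  merged.minBufferRangeAlignment =
      std::max(lhs.minBufferRangeAlignment, rhs.minBufferRangeAlignment);
  merged.indexBits = std::max(lhs.indexBits, rhs.indexBits);
  merged.aliasMutableBindings =
      lhs.aliasMutableBindings && rhs.aliasMutableBindings;
  // Unified only survives when both sides agree; otherwise assume the
  // more expensive discrete model.
  merged.memoryModel = (lhs.memoryModel == MemoryModel::Unified &&
                        rhs.memoryModel == MemoryModel::Unified)
                           ? MemoryModel::Unified
                           : MemoryModel::Discrete;
  return merged;
}
```

Falling back to `Discrete` on any disagreement is the conservative choice: code that assumes expensive cross-access remains correct on unified hardware, while the reverse may not hold.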
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
index eae1d47..9835f9a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
@@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/IndexSet.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
@@ -57,6 +58,10 @@
} // namespace mlir
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAttr;
+} // namespace mlir::iree_compiler::IREE::Stream
+
// clang-format off: must be included after all LLVM/MLIR headers.
#define GET_ATTRDEF_CLASSES
#include "iree/compiler/Dialect/Stream/IR/StreamAttrs.h.inc" // IWYU pragma: keep
@@ -64,27 +69,16 @@
#include "iree/compiler/Dialect/Stream/IR/StreamAttrInterfaces.h.inc" // IWYU pragma: export
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace Stream {
-
+namespace mlir::iree_compiler::IREE::Stream {
#include "iree/compiler/Dialect/Stream/IR/StreamTypeInterfaces.h.inc" // IWYU pragma: export
-
-} // namespace Stream
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace mlir::iree_compiler::IREE::Stream
// clang-format off: must be included after all LLVM/MLIR headers.
#define GET_TYPEDEF_CLASSES
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h.inc" // IWYU pragma: keep
// clang-format on
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace Stream {
+namespace mlir::iree_compiler::IREE::Stream {
struct AsyncAccessRange {
ResourceAccessBitfield access;
@@ -96,9 +90,20 @@
#include "iree/compiler/Dialect/Stream/IR/StreamOpInterfaces.h.inc" // IWYU pragma: export
-} // namespace Stream
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+//===----------------------------------------------------------------------===//
+// custom<ParameterReference>($scope, $key)
+//===----------------------------------------------------------------------===//
+
+ParseResult parseParameterReference(AsmParser &parser, StringAttr &scopeAttr,
+ StringAttr &keyAttr);
+void printParameterReference(AsmPrinter &p, StringAttr scopeAttr,
+ StringAttr keyAttr);
+static inline void printParameterReference(AsmPrinter &p, Operation *op,
+ StringAttr scopeAttr,
+ StringAttr keyAttr) {
+ printParameterReference(p, scopeAttr, keyAttr);
+}
+
+} // namespace mlir::iree_compiler::IREE::Stream
#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAMTYPES_H_
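Note on the parameter reference syntax declared above: a reference is written as `"scope"::"key"`, or just `"key"` when no scope is given. The sketch below round-trips that text form in plain C++. It assumes keys and scopes contain no quotes or escapes and treats an empty scope string as "no scope"; `ParameterReference`, `formatReferenceText`, and `parseReferenceText` are hypothetical names and not the MLIR parser hooks.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical plain-data form of a parameter reference.
struct ParameterReference {
  std::string scope;  // empty means "no scope" in this sketch
  std::string key;
};

// Prints "scope"::"key", or just "key" when the scope is empty.
std::string formatReferenceText(const ParameterReference &ref) {
  std::string out;
  if (!ref.scope.empty()) out += "\"" + ref.scope + "\"::";
  out += "\"" + ref.key + "\"";
  return out;
}

// Reads one double-quoted string starting at pos and advances pos past it.
static bool readQuoted(const std::string &text, std::size_t &pos,
                       std::string &value) {
  if (pos >= text.size() || text[pos] != '"') return false;
  std::size_t end = text.find('"', pos + 1);
  if (end == std::string::npos) return false;
  value = text.substr(pos + 1, end - pos - 1);
  pos = end + 1;
  return true;
}

// Parses the first string, then decides whether it was a scope or a key
// depending on whether "::" follows, as the parser in the diff does.
bool parseReferenceText(const std::string &text, ParameterReference &ref) {
  std::size_t pos = 0;
  std::string first;
  if (!readQuoted(text, pos, first)) return false;
  if (text.compare(pos, 2, "::") != 0) {
    ref.scope.clear();
    ref.key = first;
    return pos == text.size();
  }
  pos += 2;
  ref.scope = first;
  return readQuoted(text, pos, ref.key) && pos == text.size();
}
```

Because the first string is ambiguous until the separator is seen, the key is assigned last in both branches.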
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel
index e168414..2149890 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel
@@ -22,8 +22,11 @@
"channel_ops.mlir",
"cmd_folding.mlir",
"cmd_ops.mlir",
+ "context_ops.mlir",
"executable_ops.mlir",
"file_ops.mlir",
+ "parameter_folding.mlir",
+ "parameter_ops.mlir",
"resource_folding.mlir",
"resource_ops.mlir",
"tensor_folding.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt
index 4835de2..7c4bf33 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt
@@ -20,8 +20,11 @@
"channel_ops.mlir"
"cmd_folding.mlir"
"cmd_ops.mlir"
+ "context_ops.mlir"
"executable_ops.mlir"
"file_ops.mlir"
+ "parameter_folding.mlir"
+ "parameter_ops.mlir"
"resource_folding.mlir"
"resource_ops.mlir"
"tensor_folding.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir
new file mode 100644
index 0000000..c324ff3
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir
@@ -0,0 +1,12 @@
+// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @context_resolve
+func.func @context_resolve() {
+ // CHECK: = stream.context.resolve : !hal.allocator
+ %allocator = stream.context.resolve : !hal.allocator
+ // CHECK: = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64
+ %device1, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64
+ // CHECK: = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64
+ %device0, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64
+ return
+}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_folding.mlir
new file mode 100644
index 0000000..ddeb54d
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_folding.mlir
@@ -0,0 +1,54 @@
+// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @FoldParameterLoadTargetSubview
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[OFFSET:.+]]: index, %[[LENGTH:.+]]: index)
+func.func @FoldParameterLoadTargetSubview(%wait: !stream.timepoint, %offset: index, %length: index) -> (!stream.resource<constant>, !stream.timepoint) {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ // CHECK-DAG: %[[OFFSET_I64:.+]] = arith.index_cast %[[OFFSET]] : index to i64
+ // CHECK-DAG: %[[PARAMETER_OFFSET:.+]] = arith.addi %[[OFFSET_I64]], %c50_i64
+ // CHECK: %[[RESULT:.+]] = stream.parameter.load await(%[[WAIT]]) => "scope"::"key"[%[[PARAMETER_OFFSET]]] : !stream.resource<constant>{%[[LENGTH]]} => !stream.timepoint
+ %result, %result_timepoint = stream.parameter.load await(%wait) => "scope"::"key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ // CHECK-NOT: stream.resource.subview
+ %subview = stream.resource.subview %result[%offset] : !stream.resource<constant>{%c100} -> !stream.resource<constant>{%length}
+ // CHECK: return %[[RESULT]]
+ return %subview, %result_timepoint : !stream.resource<constant>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @FoldParameterReadTargetSubview
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[TARGET:.+]]: !stream.resource<transient>, %[[OFFSET:.+]]: index, %[[LENGTH:.+]]: index)
+func.func @FoldParameterReadTargetSubview(%wait: !stream.timepoint, %target: !stream.resource<transient>, %offset: index, %length: index) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[OFFSET_I64:.+]] = arith.index_cast %[[OFFSET]] : index to i64
+ // CHECK-DAG: %[[PARAMETER_OFFSET:.+]] = arith.addi %[[OFFSET_I64]], %c50_i64
+ // CHECK-DAG: %[[RESOURCE_OFFSET:.+]] = arith.addi %[[OFFSET]], %c100
+ // CHECK-NOT: stream.resource.subview
+ %subview = stream.resource.subview %target[%offset] : !stream.resource<transient>{%length} -> !stream.resource<transient>{%c300}
+ // CHECK: = stream.parameter.read await(%[[WAIT]]) => "scope"::"key"[%[[PARAMETER_OFFSET]]] -> %[[TARGET]][%[[RESOURCE_OFFSET]] for %c200] : !stream.resource<transient>{%[[LENGTH]]} => !stream.timepoint
+ %timepoint = stream.parameter.read await(%wait) => "scope"::"key"[%c50_i64] -> %subview[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @FoldParameterWriteSourceSubview
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[SOURCE:.+]]: !stream.resource<transient>, %[[OFFSET:.+]]: index, %[[LENGTH:.+]]: index)
+func.func @FoldParameterWriteSourceSubview(%wait: !stream.timepoint, %source: !stream.resource<transient>, %offset: index, %length: index) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[OFFSET_I64:.+]] = arith.index_cast %[[OFFSET]] : index to i64
+ // CHECK-DAG: %[[PARAMETER_OFFSET:.+]] = arith.addi %[[OFFSET_I64]], %c50_i64
+ // CHECK-DAG: %[[RESOURCE_OFFSET:.+]] = arith.addi %[[OFFSET]], %c100
+ // CHECK-NOT: stream.resource.subview
+ %subview = stream.resource.subview %source[%offset] : !stream.resource<transient>{%length} -> !stream.resource<transient>{%c300}
+ // CHECK: = stream.parameter.write await(%[[WAIT]]) => %[[SOURCE]][%[[RESOURCE_OFFSET]] for %c200] : !stream.resource<transient>{%[[LENGTH]]} -> "scope"::"key"[%[[PARAMETER_OFFSET]]] => !stream.timepoint
+ %timepoint = stream.parameter.write await(%wait) => %subview[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint
+ return %timepoint : !stream.timepoint
+}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_ops.mlir
new file mode 100644
index 0000000..0e6ba69
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_ops.mlir
@@ -0,0 +1,139 @@
+// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK: util.global private @parameter_unscoped = #stream.parameter.named<"key"> : tensor<10xf32>
+util.global private @parameter_unscoped = #stream.parameter.named<"key"> : tensor<10xf32>
+// CHECK: util.global private @parameter_scoped = #stream.parameter.named<"scope"::"key"> : tensor<10xf32>
+util.global private @parameter_scoped = #stream.parameter.named<"scope"::"key"> : tensor<10xf32>
+// CHECK: util.global private @parameter_config = #stream.parameter.named<"scope"::"key", {some.config = "hello"}> : tensor<10xf32>
+util.global private @parameter_config = #stream.parameter.named<"scope"::"key", {some.config = "hello"}> : tensor<10xf32>
+
+// -----
+
+// CHECK-LABEL: @parameterLoad
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint)
+func.func @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ // CHECK: = stream.parameter.load await(%[[WAIT]]) => "scope"::"key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ %result, %result_timepoint = stream.parameter.load await(%wait) => "scope"::"key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ return %result, %result_timepoint : !stream.resource<constant>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterLoadNoScope
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint)
+func.func @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ // CHECK: = stream.parameter.load await(%[[WAIT]]) => "key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ %result, %result_timepoint = stream.parameter.load await(%wait) => "key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ return %result, %result_timepoint : !stream.resource<constant>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterRead
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[TARGET:.+]]: !stream.resource<transient>)
+func.func @parameterRead(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK: = stream.parameter.read await(%[[WAIT]]) => "scope"::"key"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint
+ %timepoint = stream.parameter.read await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterWrite
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[SOURCE:.+]]: !stream.resource<transient>)
+func.func @parameterWrite(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK: = stream.parameter.write await(%[[WAIT]]) => %[[SOURCE]][%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint
+ %timepoint = stream.parameter.write await(%wait) => %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterGather
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[TARGET:.+]]: !stream.resource<transient>)
+func.func @parameterGather(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
+ %c52_i64 = arith.constant 52 : i64
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c102 = arith.constant 102 : index
+ %c200 = arith.constant 200 : index
+ %c201 = arith.constant 201 : index
+ %c202 = arith.constant 202 : index
+ %c300 = arith.constant 300 : index
+ // CHECK: = stream.parameter.gather await(%[[WAIT]]) => {
+ // CHECK-NEXT: "scope"::"key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !stream.resource<transient>{%c300},
+ // CHECK-NEXT: "scope"::"key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !stream.resource<transient>{%c300},
+ // CHECK-NEXT: "scope"::"key2"[%c52_i64] -> %[[TARGET]][%c102 for %c202] : !stream.resource<transient>{%c300}
+ // CHECK-NEXT: } => !stream.timepoint
+ %timepoint = stream.parameter.gather await(%wait) => {
+ "scope"::"key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300},
+ "scope"::"key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300},
+ "scope"::"key2"[%c52_i64] -> %target[%c102 for %c202] : !stream.resource<transient>{%c300}
+ } => !stream.timepoint
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterGatherNoScope
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[TARGET:.+]]: !stream.resource<transient>)
+func.func @parameterGatherNoScope(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c200 = arith.constant 200 : index
+ %c201 = arith.constant 201 : index
+ %c300 = arith.constant 300 : index
+ // CHECK: = stream.parameter.gather await(%[[WAIT]]) => {
+ // CHECK-NEXT: "key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !stream.resource<transient>{%c300},
+ // CHECK-NEXT: "key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !stream.resource<transient>{%c300}
+ // CHECK-NEXT: } => !stream.timepoint
+ %timepoint = stream.parameter.gather await(%wait) => {
+ "key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300},
+ "key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300}
+ } => !stream.timepoint
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterScatter
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[SOURCE:.+]]: !stream.resource<transient>)
+func.func @parameterScatter(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
+ %c52_i64 = arith.constant 52 : i64
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c102 = arith.constant 102 : index
+ %c200 = arith.constant 200 : index
+ %c201 = arith.constant 201 : index
+ %c202 = arith.constant 202 : index
+ %c300 = arith.constant 300 : index
+ // CHECK: = stream.parameter.scatter await(%[[WAIT]]) => {
+ // CHECK-NEXT: %[[SOURCE]][%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key0"[%c50_i64],
+ // CHECK-NEXT: %[[SOURCE]][%c101 for %c201] : !stream.resource<transient>{%c300} -> "scope"::"key1"[%c51_i64],
+ // CHECK-NEXT: %[[SOURCE]][%c102 for %c202] : !stream.resource<transient>{%c300} -> "scope"::"key2"[%c52_i64]
+ // CHECK-NEXT: } => !stream.timepoint
+ %timepoint = stream.parameter.scatter await(%wait) => {
+ %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key0"[%c50_i64],
+ %source[%c101 for %c201] : !stream.resource<transient>{%c300} -> "scope"::"key1"[%c51_i64],
+ %source[%c102 for %c202] : !stream.resource<transient>{%c300} -> "scope"::"key2"[%c52_i64]
+ } => !stream.timepoint
+ return %timepoint : !stream.timepoint
+}
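Note on the gather tests above: each entry reads as `"scope"::"key"[source offset] -> %target[target offset for length]`, so a single op batches several ranged copies into one target resource. The following is a minimal host-side sketch of that reading, not the runtime implementation. `ParameterStore`, `GatherEntry`, and `gatherParameters` are hypothetical names introduced only for illustration, and the in-memory map stands in for whatever parameter provider is actually bound.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a parameter provider: (scope, key) -> bytes.
using ParameterStore =
    std::map<std::pair<std::string, std::string>, std::vector<uint8_t>>;

// One gather entry, mirroring
//   "scope"::"key"[sourceOffset] -> %target[targetOffset for length]
struct GatherEntry {
  std::string key;
  uint64_t sourceOffset;
  size_t targetOffset;
  size_t length;
};

// Copies every entry's byte range from the named parameter into |target|.
// All ranges are validated first so a failed gather leaves |target| untouched.
bool gatherParameters(const ParameterStore &store, const std::string &scope,
                      const std::vector<GatherEntry> &entries,
                      std::vector<uint8_t> &target) {
  for (const auto &entry : entries) {
    auto it = store.find(std::make_pair(scope, entry.key));
    if (it == store.end())
      return false;  // unknown parameter
    const auto &source = it->second;
    if (entry.sourceOffset > source.size() ||
        entry.length > source.size() - entry.sourceOffset)
      return false;  // source range out of bounds
    if (entry.targetOffset > target.size() ||
        entry.length > target.size() - entry.targetOffset)
      return false;  // target range out of bounds
  }
  for (const auto &entry : entries) {
    const auto &source = store.at(std::make_pair(scope, entry.key));
    assert(entry.sourceOffset + entry.length <= source.size());
    std::copy_n(source.begin() + static_cast<std::ptrdiff_t>(entry.sourceOffset),
                entry.length,
                target.begin() + static_cast<std::ptrdiff_t>(entry.targetOffset));
  }
  return true;
}
```

A scatter is the same loop with the copy direction reversed. Keeping the whole entry list in one call is what lets a provider that supports batching service many ranges with a few calls.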
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp
index b493916..48c9e2c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp
@@ -143,10 +143,9 @@
}
}
for (auto constantOp : usageInfo.bufferConstantOps) {
- if (auto serializableAttr =
- constantOp.getValue()
- .dyn_cast<IREE::Util::SerializableAttrInterface>()) {
- constantSize += serializableAttr.getStorageSize();
+ if (auto storageAttr =
+ constantOp.getValue().dyn_cast<IREE::Util::SizedStorageAttr>()) {
+ constantSize += storageAttr.getStorageSize();
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
index 87c02fa..de82242 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
@@ -48,11 +48,8 @@
// Returns the length, in bytes, of the constant value prior to alignment or
// padding.
uint64_t getStorageSize() const {
- if (auto serializableAttr =
- llvm::dyn_cast<IREE::Util::SerializableAttrInterface>(value)) {
- return serializableAttr.getStorageSize();
- } else if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value)) {
- return denseAttr.getRawData().size();
+ if (auto storageAttr = dyn_cast<IREE::Util::SizedStorageAttr>(value)) {
+ return storageAttr.getStorageSize();
} else {
assert(false && "invalid constant attr type");
return 0;
@@ -63,33 +60,40 @@
struct PackedSpan {
// Original slice this span represents.
ConstantSlice slice;
- // Byte offset within the storage buffer.
+ // Byte offset within the target storage buffer.
uint64_t offset = 0;
- // Length of the valid data when padded out.
+ // Length of the valid data when padded out in the target storage buffer.
// This is only accounting for the padding of the valid data itself and not
// any additional padding for other spans within the buffer (like start
// offset alignment).
uint64_t length = 0;
};
+// A storage resource backed by packed data or by external parameter storage.
+// A storage resource may be comprised of data packed by the compiler into its
+// at-rest form (allowing for easy single-operation loads/mappings) or gathered
+// from disparate storage locations (more expensive). Degenerate resources
+// may contain only a single packed logical resource, which allows for easier
+// zero-copy parameter loading at the cost of more runtime overhead.
struct StorageResource {
// Fused location of all spans that make up this storage buffer.
Location loc;
// Total size in bytes (including padding).
uint64_t totalSize = 0;
// Constant spans packed into this resource.
- SmallVector<PackedSpan, 8> spans;
+ SmallVector<PackedSpan> spans;
// Packed byte data that must be embedded in the final module.
// It must be written with an alignment as required by the constraints.
- IREE::Util::CompositeAttr data;
+ // If not set then each span may have unique storage.
+ IREE::Util::CompositeAttr packedData;
};
// Buckets |slices| into 1+ storage resources based on |resourceConfig|.
-static SmallVector<StorageResource, 8> bucketValuesIntoStorageResources(
+static SmallVector<StorageResource> bucketValuesIntoStorageResources(
ArrayRef<ConstantSlice> slices,
IREE::Stream::ResourceConfigAttr resourceConfig) {
// TODO(benvanik): replace with a better strategy (best-fit, etc).
- SmallVector<StorageResource, 8> storageBuffers;
+ SmallVector<StorageResource> storageBuffers;
storageBuffers.push_back({UnknownLoc::get(resourceConfig.getContext())});
StorageResource *currentBuffer = &storageBuffers.back();
for (auto slice : slices) {
@@ -178,15 +182,15 @@
offset += tailPadding;
}
- storageBuffer.data = IREE::Util::CompositeAttr::get(context, values);
- assert(storageBuffer.data && "unable to build composite attr");
+ storageBuffer.packedData = IREE::Util::CompositeAttr::get(context, values);
+ assert(storageBuffer.packedData && "unable to build composite attr");
}
// Returns zero or more storage resources and the spans values map into.
// Assume that |slices| have been ordered by prior passes and that order may
// have some performance-sensitivity (constants are grouped by
// locality/lifetime/etc).
-static SmallVector<StorageResource, 8>
+static SmallVector<StorageResource>
computePackingMap(ArrayRef<ConstantSlice> slices,
IREE::Stream::ResourceConfigAttr resourceConfig,
MLIRContext *context) {
@@ -243,26 +247,121 @@
Value resourceSize;
};
-static TimepointResource buildFileRead(
- Location loc, Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
- IREE::Stream::ResourceType resourceType, StorageResource storageResource,
- Value storageResourceSize, Value storageBuffer, Value storageBufferSize,
- IndexSet &indexSet, OpBuilder &builder) {
+struct ParameterSlice {
+ IREE::Stream::NamedParameterAttr parameterAttr;
+ Value sourceOffset;
+ Value sourceLength;
+};
+
+static ParameterSlice getParameterSlice(Location loc, Attribute value,
+ IndexSet &indexSet,
+ OpBuilder &builder) {
+ auto parameterAttr = cast<IREE::Stream::NamedParameterAttr>(value);
+ Value sourceOffset;
+ Value sourceLength;
+ if (auto configAttr = parameterAttr.getConfig()) {
+ if (auto offsetAttr = configAttr.getAs<IntegerAttr>("offset")) {
+ sourceOffset =
+ builder.create<arith::ConstantIntOp>(loc, offsetAttr.getInt(), 64);
+ }
+ if (auto lengthAttr = configAttr.getAs<IntegerAttr>("length")) {
+ sourceLength = indexSet.get(lengthAttr.getInt());
+ }
+ }
+ if (!sourceOffset)
+ sourceOffset = builder.create<arith::ConstantIntOp>(loc, 0, 64);
+ if (!sourceLength)
+ sourceLength = indexSet.get(parameterAttr.getStorageSize());
+ return ParameterSlice{parameterAttr, sourceOffset, sourceLength};
+}
+
+static TimepointResource
+buildParameterLoad(Location loc, Value awaitTimepoint,
+ IREE::Stream::AffinityAttr affinityAttr, Type targetType,
+ Value targetSize, const PackedSpan &packedSpan,
+ IndexSet &indexSet, OpBuilder &builder) {
+ auto parameterSlice =
+ getParameterSlice(loc, packedSpan.slice.value, indexSet, builder);
+ auto loadOp = builder.create<IREE::Stream::ParameterLoadOp>(
+ loc, targetType, builder.getType<IREE::Stream::TimepointType>(),
+ parameterSlice.parameterAttr.getScope(),
+ parameterSlice.parameterAttr.getKey(), parameterSlice.sourceOffset,
+ parameterSlice.sourceLength, awaitTimepoint, affinityAttr);
+ return TimepointResource{loadOp.getResultTimepoint(), loadOp.getResult(),
+ loadOp.getResultSize()};
+}
+
+static TimepointResource
+buildParameterGather(Location loc, Value awaitTimepoint,
+ IREE::Stream::AffinityAttr affinityAttr, Type targetType,
+ Value targetSize, ArrayRef<PackedSpan> packedSpans,
+ IndexSet &indexSet, OpBuilder &builder) {
// Allocate the resulting storage resource of the final resource type.
auto allocOp = builder.create<IREE::Stream::ResourceAllocOp>(
- storageResource.loc, resourceType, storageResourceSize,
+ loc, targetType, targetSize,
+ /*uninitialized=*/builder.getUnitAttr(), affinityAttr);
+
+ // Parameters may be from multiple scopes - bucket by scope and gather from
+ // each in turn.
+ llvm::MapVector<StringAttr, SmallVector<PackedSpan>> scopeSpans;
+ for (auto &packedSpan : packedSpans) {
+ auto parameterAttr =
+ cast<IREE::Stream::NamedParameterAttr>(packedSpan.slice.value);
+ scopeSpans[parameterAttr.getScope()].push_back(packedSpan);
+ }
+
+ // Gather from each unique scope.
+ SmallVector<Value> gatherTimepoints;
+ for (auto &[scope, packedSpans] : scopeSpans) {
+ SmallVector<Attribute> sourceKeys;
+ SmallVector<Value> sourceOffsets;
+ SmallVector<Value> targetOffsets;
+ SmallVector<Value> targetLengths;
+ sourceKeys.reserve(packedSpans.size());
+ for (auto &packedSpan : packedSpans) {
+ auto parameterSlice =
+ getParameterSlice(loc, packedSpan.slice.value, indexSet, builder);
+ sourceKeys.push_back(parameterSlice.parameterAttr.getKey());
+ sourceOffsets.push_back(parameterSlice.sourceOffset);
+ targetOffsets.push_back(indexSet.get(packedSpan.offset));
+ targetLengths.push_back(indexSet.get(packedSpan.length));
+ }
+ auto gatherOp = builder.create<IREE::Stream::ParameterGatherOp>(
+ loc, builder.getType<IREE::Stream::TimepointType>(), scope,
+ builder.getArrayAttr(sourceKeys), sourceOffsets, allocOp.getResult(),
+ allocOp.getResultSize(0), targetOffsets, targetLengths, awaitTimepoint,
+ affinityAttr);
+ gatherTimepoints.push_back(gatherOp.getResultTimepoint());
+ }
+
+ // Wait until all gathers have completed.
+ Value readyTimepoint =
+ IREE::Stream::TimepointJoinOp::join(gatherTimepoints, builder);
+ return TimepointResource{readyTimepoint, allocOp.getResult(),
+ allocOp.getResultSize(0)};
+}
+
+static TimepointResource buildFileRead(Location loc, Value awaitTimepoint,
+ IREE::Stream::AffinityAttr affinityAttr,
+ IREE::Stream::ResourceType resourceType,
+ Value storageResourceSize,
+ Value storageBuffer,
+ Value storageBufferSize,
+ IndexSet &indexSet, OpBuilder &builder) {
+ // Allocate the resulting storage resource of the final resource type.
+ auto allocOp = builder.create<IREE::Stream::ResourceAllocOp>(
+ loc, resourceType, storageResourceSize,
/*uninitialized=*/builder.getUnitAttr(), affinityAttr);
// Create the file backed by the constant resource buffer.
auto fileOp = builder.create<IREE::Stream::FileConstantOp>(
- storageResource.loc, storageBuffer, storageBufferSize, indexSet.get(0),
+ loc, storageBuffer, storageBufferSize, indexSet.get(0),
storageResourceSize, affinityAttr);
// Issue asynchronous file read into the buffer.
- auto zeroI64 =
- builder.create<arith::ConstantIntOp>(storageResource.loc, 0, 64);
+ auto zeroI64 = builder.create<arith::ConstantIntOp>(loc, 0, 64);
auto readOp = builder.create<IREE::Stream::FileReadOp>(
- storageResource.loc, fileOp.getResult(), zeroI64, allocOp.getResult(),
+ loc, fileOp.getResult(), zeroI64, allocOp.getResult(),
allocOp.getResultSize(0), indexSet.get(0), storageResourceSize,
awaitTimepoint, affinityAttr);
@@ -276,14 +375,14 @@
// Returns a timepoint indicating the operation has completed.
static TimepointResource buildTryMapConstantResource(
Location loc, Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
- IREE::Stream::ResourceType resourceType, StorageResource storageResource,
- Value storageResourceSize, Value storageBuffer, Value storageBufferSize,
- IndexSet &indexSet, OpBuilder &builder) {
+ IREE::Stream::ResourceType resourceType, Value storageResourceSize,
+ Value storageBuffer, Value storageBufferSize, IndexSet &indexSet,
+ OpBuilder &builder) {
// Try mapping; this may fail if the device can't use the storage buffer as
// the type of resource requested.
auto tryMapOp = builder.create<IREE::Stream::ResourceTryMapOp>(
- storageResource.loc, builder.getI1Type(), resourceType, storageBuffer,
- indexSet.get(0), storageResourceSize, affinityAttr);
+ loc, builder.getI1Type(), resourceType, storageBuffer, indexSet.get(0),
+ storageResourceSize, affinityAttr);
// If we are able to directly map the resources then we don't need to wait.
// Otherwise we need to stage the storage buffer into memory via the file
@@ -300,8 +399,8 @@
[&](OpBuilder &elseBuilder, Location loc) {
auto readResult =
buildFileRead(loc, awaitTimepoint, affinityAttr, resourceType,
- storageResource, storageResourceSize, storageBuffer,
- storageBufferSize, indexSet, elseBuilder);
+ storageResourceSize, storageBuffer, storageBufferSize,
+ indexSet, elseBuilder);
elseBuilder.create<scf::YieldOp>(loc, ValueRange{
readResult.timepoint,
readResult.resource,
@@ -312,32 +411,14 @@
return TimepointResource{ifTimepoint, ifResource, storageResourceSize};
}
-static Value generateUpload(Value awaitTimepoint,
- IREE::Stream::ResourceConstantsOp constantsOp,
- IREE::Stream::Lifetime lifetime,
- IREE::Stream::ResourceConfigAttr resourceConfig,
- IndexSet &indexSet, OpBuilder &builder) {
- // Gather the slices produced by this constant pooling op.
- SmallVector<ConstantSlice> slices;
- slices.reserve(constantsOp.getResults().size());
- for (auto [result, resultSize, value] :
- llvm::zip_equal(constantsOp.getResults(), constantsOp.getResultSizes(),
- constantsOp.getValues())) {
- auto resourceType =
- llvm::cast<IREE::Stream::ResourceType>(result.getType());
- if (resourceType.getLifetime() != lifetime)
- continue;
- slices.push_back(ConstantSlice{
- result,
- resultSize,
- value,
- });
- }
-
+static Value generateSerializedUpload(
+ Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
+ IREE::Stream::ResourceConfigAttr resourceConfig,
+ ArrayRef<ConstantSlice> slices, IndexSet &indexSet, OpBuilder &builder) {
// Perform the packing of dense values to compute the storage resources we
// will need and where each value will be placed.
auto storageResources =
- computePackingMap(slices, resourceConfig, constantsOp.getContext());
+ computePackingMap(slices, resourceConfig, builder.getContext());
if (storageResources.empty())
return nullptr;
@@ -353,26 +434,25 @@
// them once regardless of how many strategies we emit IR for.
Value currentTimepoint = awaitTimepoint;
for (auto &storageResource : storageResources) {
+ // Serialized resources are stored as packed host data.
Value storageBuffer = builder.create<IREE::Util::BufferConstantOp>(
- storageResource.loc, /*name=*/nullptr, storageResource.data,
+ storageResource.loc, /*name=*/nullptr, storageResource.packedData,
builder.getIndexAttr(resourceConfig.getMinBufferOffsetAlignment()),
/*mimeType=*/nullptr);
- auto resourceSize = indexSet.get(storageResource.totalSize);
// If this is producing constants (vs variables) we can try to go on a
// fast-path where we directly map the constant memory. If producing
// variables then we always need to stage and clone.
TimepointResource uploadedResource;
+ auto resourceSize = indexSet.get(storageResource.totalSize);
if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant) {
uploadedResource = buildTryMapConstantResource(
- constantsOp.getLoc(), currentTimepoint, constantsOp.getAffinityAttr(),
- resourceType, storageResource, resourceSize, storageBuffer,
- resourceSize, indexSet, builder);
+ storageResource.loc, currentTimepoint, affinityAttr, resourceType,
+ resourceSize, storageBuffer, resourceSize, indexSet, builder);
} else {
uploadedResource = buildFileRead(
- constantsOp.getLoc(), currentTimepoint, constantsOp.getAffinityAttr(),
- resourceType, storageResource, resourceSize, storageBuffer,
- resourceSize, indexSet, builder);
+ storageResource.loc, currentTimepoint, affinityAttr, resourceType,
+ resourceSize, storageBuffer, resourceSize, indexSet, builder);
}
for (auto &span : storageResource.spans) {
@@ -390,6 +470,118 @@
return currentTimepoint;
}
+static Value generateParameterUpload(
+ Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
+ IREE::Stream::ResourceConfigAttr resourceConfig,
+ ArrayRef<ConstantSlice> slices, IndexSet &indexSet, OpBuilder &builder) {
+ auto anyResult = slices.front().result;
+ auto resourceType =
+ llvm::cast<IREE::Stream::ResourceType>(anyResult.getType());
+
+ // Perform the packing of dense values to compute the storage resources we
+ // will need and where each value will be placed unless we have a chance to
+ // reuse parameter storage. This is a big switch today (either we try to
+ // emit one resource per parameter for loading _or_ we gather everything) but
+ // could be refined to only try loading large resources while we pack the
+ // small resources. e.g. try to reuse a 1GB parameter but pack 1000 128B
+ // parameters together.
+ SmallVector<StorageResource> storageResources;
+ if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant &&
+ resourceConfig.getMemoryModel() == IREE::Stream::MemoryModel::Unified) {
+ for (auto &slice : slices) {
+ uint64_t sliceSize = slice.getStorageSize();
+ storageResources.push_back(StorageResource{slice.result.getLoc(),
+ sliceSize,
+ {
+ PackedSpan{
+ slice,
+ /*offset=*/0,
+ /*length=*/sliceSize,
+ },
+ }});
+ }
+ } else {
+ storageResources =
+ computePackingMap(slices, resourceConfig, builder.getContext());
+ }
+ if (storageResources.empty())
+ return nullptr;
+
+ // Emit the parameter loads or gathers for each unique resource.
+ SmallVector<Value> uploadTimepoints;
+ for (auto &storageResource : storageResources) {
+ // Parameter-backed resource that we can either load or gather.
+ // Loads are only possible if we are using the parameter as a constant and
+ // it is a single span, as we can't pack externally owned parameters.
+ TimepointResource uploadedResource;
+ auto resourceSize = indexSet.get(storageResource.totalSize);
+ if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant &&
+ storageResource.spans.size() == 1) {
+ uploadedResource = buildParameterLoad(
+ storageResource.loc, awaitTimepoint, affinityAttr, resourceType,
+ resourceSize, storageResource.spans.front(), indexSet, builder);
+ } else {
+ uploadedResource = buildParameterGather(
+ storageResource.loc, awaitTimepoint, affinityAttr, resourceType,
+ resourceSize, storageResource.spans, indexSet, builder);
+ }
+
+ for (auto &span : storageResource.spans) {
+ auto loc = span.slice.result.getLoc();
+ auto subviewOp = builder.create<IREE::Stream::ResourceSubviewOp>(
+ loc, uploadedResource.resource, uploadedResource.resourceSize,
+ indexSet.get(span.offset), span.slice.resultSize);
+ span.slice.result.replaceAllUsesWith(subviewOp.getResult());
+ }
+
+ uploadTimepoints.push_back(uploadedResource.timepoint);
+ }
+
+ // Join on storage timepoints for our transitive dependencies to await.
+ return IREE::Stream::TimepointJoinOp::join(uploadTimepoints, builder);
+}
+
+static Value generateUploads(Value awaitTimepoint,
+ IREE::Stream::ResourceConstantsOp constantsOp,
+ IREE::Stream::ResourceConfigAttr resourceConfig,
+ IndexSet &indexSet, OpBuilder &builder) {
+ // Split the slices based on whether they are sourced from serialized data or
+ // externally-defined parameters.
+ // TODO(benvanik): remove stream.resource.constants and this coupling;
+ // parameters should be handled by a dedicated pass. This is a hack that
+ // allows us to reuse the packing code for performing variable parameter packs
+ // and have everything happen atomically but is pretty terrible.
+ SmallVector<ConstantSlice> serializedSlices;
+ SmallVector<ConstantSlice> parameterSlices;
+ for (auto [result, resultSize, value] :
+ llvm::zip_equal(constantsOp.getResults(), constantsOp.getResultSizes(),
+ constantsOp.getValues())) {
+ auto slice = ConstantSlice{
+ result,
+ resultSize,
+ value,
+ };
+ if (isa<IREE::Stream::NamedParameterAttr>(value)) {
+ parameterSlices.push_back(slice);
+ } else {
+ serializedSlices.push_back(slice);
+ }
+ }
+
+ SmallVector<Value> uploadTimepoints;
+ if (!serializedSlices.empty()) {
+ uploadTimepoints.push_back(generateSerializedUpload(
+ awaitTimepoint, constantsOp.getAffinityAttr(), resourceConfig,
+ serializedSlices, indexSet, builder));
+ }
+ if (!parameterSlices.empty()) {
+ uploadTimepoints.push_back(generateParameterUpload(
+ awaitTimepoint, constantsOp.getAffinityAttr(), resourceConfig,
+ parameterSlices, indexSet, builder));
+ }
+ return IREE::Stream::TimepointJoinOp::join(uploadTimepoints, builder);
+}
+
//===----------------------------------------------------------------------===//
// -iree-stream-pack-constants
//===----------------------------------------------------------------------===//
@@ -426,25 +618,19 @@
auto resourceConfig =
IREE::Stream::ResourceConfigAttr::lookup(constantsOp);
+ // Packing creates a lot of index values given that all the sizes are
+ // statically-known - CSE would collapse them but we use an IndexSet to
+ // reduce the IR churn.
OpBuilder builder(constantsOp);
IndexSet indexSet(constantsOp.getLoc(), builder);
indexSet.populate(constantsOp.getResultSizes());
// Perform upload/processing for immutable and mutable constants.
- Value currentTimepoint =
- builder.create<IREE::Stream::TimepointImmediateOp>(
- constantsOp.getLoc());
- if (auto uploadTimepoint = generateUpload(
- currentTimepoint, constantsOp, IREE::Stream::Lifetime::Constant,
- resourceConfig, indexSet, builder)) {
- currentTimepoint = uploadTimepoint;
- }
- if (auto uploadTimepoint = generateUpload(
- currentTimepoint, constantsOp, IREE::Stream::Lifetime::Variable,
- resourceConfig, indexSet, builder)) {
- currentTimepoint = uploadTimepoint;
- }
- constantsOp.getResultTimepoint().replaceAllUsesWith(currentTimepoint);
+ Value awaitTimepoint = builder.create<IREE::Stream::TimepointImmediateOp>(
+ constantsOp.getLoc());
+ auto uploadTimepoint = generateUploads(awaitTimepoint, constantsOp,
+ resourceConfig, indexSet, builder);
+ constantsOp.getResultTimepoint().replaceAllUsesWith(uploadTimepoint);
constantsOp.erase();
});
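Note on the PackConstants.cpp change above: `generateParameterUpload` makes two decisions (one resource per parameter vs. packing, then load vs. gather) and `buildParameterGather` buckets spans by scope. The following is a simplified standalone sketch of that control flow, not the pass itself. The types and function names are hypothetical stand-ins for the MLIR types in the diff, and the packing branch just concatenates slices; the real `computePackingMap` also applies the alignment and size limits from the resource config.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the stream dialect enums and structs.
enum class Lifetime { Constant, Variable };
enum class MemoryModel { Unified, Discrete };

struct Slice { std::string scope; std::string key; uint64_t size; };
struct Span { Slice slice; uint64_t offset; uint64_t length; };
struct Resource { uint64_t totalSize; std::vector<Span> spans; };

// Constant parameters on unified memory get one resource each so the parameter
// storage may be reused; everything else is packed (simplified here to plain
// concatenation with no alignment or padding).
std::vector<Resource> planResources(Lifetime lifetime, MemoryModel model,
                                    const std::vector<Slice> &slices) {
  std::vector<Resource> resources;
  if (lifetime == Lifetime::Constant && model == MemoryModel::Unified) {
    for (const auto &slice : slices)
      resources.push_back(Resource{slice.size, {Span{slice, 0, slice.size}}});
    return resources;
  }
  if (slices.empty())
    return resources;
  Resource packed{0, {}};
  for (const auto &slice : slices) {
    packed.spans.push_back(Span{slice, packed.totalSize, slice.size});
    packed.totalSize += slice.size;
  }
  resources.push_back(std::move(packed));
  return resources;
}

// A direct load is only chosen for a constant resource with a single span;
// any other resource is allocated and filled with gathers.
bool canLoadDirectly(Lifetime lifetime, const Resource &resource) {
  return lifetime == Lifetime::Constant && resource.spans.size() == 1;
}

// Groups spans by scope in first-seen order (the diff uses llvm::MapVector for
// the same ordering guarantee); each bucket would become one gather op.
using ScopeBuckets = std::vector<std::pair<std::string, std::vector<Span>>>;
ScopeBuckets bucketByScope(const std::vector<Span> &spans) {
  ScopeBuckets buckets;
  for (const auto &span : spans) {
    auto it = std::find_if(buckets.begin(), buckets.end(),
                           [&](const ScopeBuckets::value_type &bucket) {
                             return bucket.first == span.slice.scope;
                           });
    if (it == buckets.end()) {
      buckets.emplace_back(span.slice.scope, std::vector<Span>{});
      it = buckets.end() - 1;
    }
    it->second.push_back(span);
  }
  return buckets;
}
```

In the pass, every gather for a resource awaits the same incoming timepoint and the per-scope result timepoints are joined, so the bucket order only affects IR ordering and not scheduling.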
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp
index cdfb952..ea26f84 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp
@@ -591,9 +591,8 @@
if (newTimepoints.empty()) {
op.getAwaitTimepointMutable().clear();
} else {
- auto newTimepoint = builder.createOrFold<IREE::Stream::TimepointJoinOp>(
- op.getLoc(), builder.getType<IREE::Stream::TimepointType>(),
- newTimepoints.takeVector());
+ auto newTimepoint = IREE::Stream::TimepointJoinOp::join(
+ op.getLoc(), newTimepoints.takeVector(), builder);
op.getAwaitTimepointMutable().assign(newTimepoint);
}
op.getResourceOperandsMutable().assign(newOperands);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
index 367f4aa..6779aae 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
@@ -1137,17 +1137,24 @@
return true;
}
-// Extracts stream.async.constant ops from |executeOp| into their own dedicated
-// stream.resource.constants upload op. The uploaded constants will be captured
-// by the region for use within as if they had still existed in there.
+// Extracts stream.async.constant ops with the given lifetime from |executeOp|
+// into their own dedicated stream.resource.constants upload op. The uploaded
+// constants will be captured by the region for use within, as if they still
+// existed there.
static std::optional<ConstantAllocation>
-extractConstants(IREE::Stream::AsyncExecuteOp executeOp,
- OpBuilder &externalBuilder) {
- // Gather all constant ops from the region, if any.
- auto constantOps =
- llvm::to_vector(executeOp.getOps<IREE::Stream::AsyncConstantOp>());
+extractConstantsWithLifetime(IREE::Stream::AsyncExecuteOp executeOp,
+ IREE::Stream::Lifetime lifetime,
+ OpBuilder &externalBuilder) {
+ auto constantOps = llvm::to_vector(
+ llvm::make_filter_range(executeOp.getOps<IREE::Stream::AsyncConstantOp>(),
+ [&](IREE::Stream::AsyncConstantOp op) {
+ return op.getResult()
+ .getType()
+ .cast<IREE::Stream::ResourceType>()
+ .getLifetime() == lifetime;
+ }));
if (constantOps.empty())
- return std::nullopt;
+ return {};
// Allocate a new constant upload op and insert a subview for each constant.
SmallVector<Location> locs;
@@ -1191,9 +1198,29 @@
allocation.reservations.push_back(reservation);
}
+
return allocation;
}
+// Extracts stream.async.constant ops from |executeOp| into their own dedicated
+// stream.resource.constants upload ops, one per lifetime. The uploaded
+// constants will be captured by the region for use within, as if they still
+// existed there.
+static SmallVector<ConstantAllocation>
+extractConstants(IREE::Stream::AsyncExecuteOp executeOp,
+ OpBuilder &externalBuilder) {
+ SmallVector<ConstantAllocation> allocations;
+ if (auto allocation = extractConstantsWithLifetime(
+ executeOp, IREE::Stream::Lifetime::Constant, externalBuilder)) {
+ allocations.push_back(std::move(allocation).value());
+ }
+ if (auto allocation = extractConstantsWithLifetime(
+ executeOp, IREE::Stream::Lifetime::Variable, externalBuilder)) {
+ allocations.push_back(std::move(allocation).value());
+ }
+ return allocations;
+}
+
//===----------------------------------------------------------------------===//
// Execution region result allocation
//===----------------------------------------------------------------------===//
@@ -1384,10 +1411,10 @@
// op. We'll then capture the result and use that to initialize variables and
// constants within the region. Note that this removes ops from the region and
// as such we want to run it first before we go allocate transients.
- auto constantAllocation = extractConstants(executeOp, externalBuilder);
- if (constantAllocation.has_value()) {
+ auto constantAllocations = extractConstants(executeOp, externalBuilder);
+ for (auto &constantAllocation : constantAllocations) {
bool anyCaptured = false;
- for (auto &reservation : constantAllocation->reservations) {
+ for (auto &reservation : constantAllocation.reservations) {
if (reservation.capturedArg) {
newOperands.push_back(reservation.resource);
newOperandSizes.push_back(reservation.resourceSize);
@@ -1408,7 +1435,7 @@
});
}
- auto awaitTimepoint = constantAllocation->constantsOp.getResultTimepoint();
+ auto awaitTimepoint = constantAllocation.constantsOp.getResultTimepoint();
if (anyCaptured) {
// The execute region must depend on the constant upload as one or more
// constants are used. All this code could be much more clever about
@@ -1421,7 +1448,7 @@
awaitTimepoint.printAsOperand(llvm::dbgs(), *asmState);
llvm::dbgs() << "\n";
});
- for (auto &reservation : constantAllocation->reservations) {
+ for (auto &reservation : constantAllocation.reservations) {
auto resourceRange =
ResourceRange(reservation.capturedArg, reservation.resourceSize);
scope.mapResourceRange(reservation.constantOp, resourceRange,
@@ -1441,7 +1468,7 @@
}
// Replace results of escaping uploads with the upload values.
- for (auto &reservation : constantAllocation->reservations) {
+ for (auto &reservation : constantAllocation.reservations) {
auto result = findTiedYieldResult(reservation.constantOp.getResult());
if (!result)
continue;
@@ -1456,8 +1483,6 @@
llvm::dbgs() << "\n";
});
}
- } else {
- LLVM_DEBUG(llvm::dbgs() << " - no constants found\n");
}
// Compute an updated set of operands/results. After allocation all results
@@ -1658,10 +1683,8 @@
if (newAwaitTimepoints.size() == 1) {
newAwaitTimepoint = newAwaitTimepoints.front();
} else if (newAwaitTimepoints.size() > 1) {
- newAwaitTimepoint =
- executeBuilder.createOrFold<IREE::Stream::TimepointJoinOp>(
- executeOp.getLoc(), newAwaitTimepoints.front().getType(),
- newAwaitTimepoints);
+ newAwaitTimepoint = IREE::Stream::TimepointJoinOp::join(
+ executeOp.getLoc(), newAwaitTimepoints, executeBuilder);
}
// Recreate the execution op with all the new arguments. Note that we drop
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
index abafb17..39319b0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
@@ -13,6 +13,7 @@
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c24 = arith.constant 24 : index
+ %c32 = arith.constant 32 : index
%c128 = arith.constant 128 : index
%c255_i32 = arith.constant 255 : i32
@@ -21,20 +22,25 @@
// CHECK-NEXT: !stream.resource<constant>{%c8} = dense<3> : tensor<8xi8>,
// CHECK-NEXT: !stream.resource<constant>{%c16} = dense<4> : tensor<4x2xi16>
+ // Initialized variables get hoisted into a dedicated op.
+ // CHECK: %[[VAR_RET:.+]], %[[VAR_TIMEPOINT:.+]] = stream.resource.constants :
+ // CHECK-NEXT: !stream.resource<variable>{%c32} = dense<5> : tensor<8xi32>
+
// Remaining ops run in a normal execution region.
// CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute await(%[[OPERAND_TIMEPOINT]])
// CHECK-SAME: => with(%[[OPERAND]]
// CHECK-NEXT: stream.cmd.fill
- %results:3, %result_timepoint = stream.async.execute await(%timepoint) => with(%operand as %capture: !stream.resource<transient>{%size}) -> (!stream.resource<constant>{%c8}, !stream.resource<constant>{%c16}, !stream.resource<transient>{%size}) {
+ %results:4, %result_timepoint = stream.async.execute await(%timepoint) => with(%operand as %capture: !stream.resource<transient>{%size}) -> (!stream.resource<constant>{%c8}, !stream.resource<constant>{%c16}, !stream.resource<variable>{%c32}, !stream.resource<transient>{%size}) {
%0 = stream.async.constant : !stream.resource<constant>{%c8} = dense<3> : tensor<8xi8>
%1 = stream.async.constant : !stream.resource<constant>{%c16} = dense<4> : tensor<4x2xi16>
- %2 = stream.async.fill %c255_i32, %capture[%c0 to %c128 for %c128] : i32 -> %capture as !stream.resource<transient>{%size}
- stream.yield %0, %1, %2 : !stream.resource<constant>{%c8}, !stream.resource<constant>{%c16}, !stream.resource<transient>{%size}
+ %2 = stream.async.constant : !stream.resource<variable>{%c32} = dense<5> : tensor<8xi32>
+ %3 = stream.async.fill %c255_i32, %capture[%c0 to %c128 for %c128] : i32 -> %capture as !stream.resource<transient>{%size}
+ stream.yield %0, %1, %2, %3 : !stream.resource<constant>{%c8}, !stream.resource<constant>{%c16}, !stream.resource<variable>{%c32}, !stream.resource<transient>{%size}
} => !stream.timepoint
- // Join the two async ops (constant upload and execution should overlap).
+ // Join the three async ops (constant upload, variable upload, and execution should overlap).
- // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[CST_TIMEPOINT]], %[[EXEC_TIMEPOINT]])
+ // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[CST_TIMEPOINT]], %[[VAR_TIMEPOINT]], %[[EXEC_TIMEPOINT]])
// CHECK: util.optimization_barrier %[[JOIN]] : !stream.timepoint
util.optimization_barrier %result_timepoint : !stream.timepoint
@@ -42,8 +48,10 @@
util.optimization_barrier %results#0 : !stream.resource<constant>
// CHECK: util.optimization_barrier %[[CST_RETS]]#1
util.optimization_barrier %results#1 : !stream.resource<constant>
+ // CHECK: util.optimization_barrier %[[VAR_RET]]
+ util.optimization_barrier %results#2 : !stream.resource<variable>
// CHECK: util.optimization_barrier %[[OPERAND]]
- util.optimization_barrier %results#2 : !stream.resource<transient>
+ util.optimization_barrier %results#3 : !stream.resource<transient>
return
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
index 3c74cbf..65c9c92 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
@@ -532,9 +532,8 @@
ArrayRef<Attribute> valueAttrs) {
int64_t calculatedLength = 0;
for (auto valueAttr : valueAttrs) {
- if (auto serializableAttr =
- llvm::dyn_cast<SerializableAttrInterface>(valueAttr)) {
- calculatedLength += serializableAttr.getStorageSize();
+ if (auto storageAttr = llvm::dyn_cast<SizedStorageAttr>(valueAttr)) {
+ calculatedLength += storageAttr.getStorageSize();
} else {
return {};
}
@@ -548,9 +547,8 @@
int64_t totalLength, ArrayAttr valueAttrs) {
int64_t calculatedLength = 0;
for (auto valueAttr : valueAttrs) {
- if (auto serializableAttr =
- llvm::dyn_cast<SerializableAttrInterface>(valueAttr)) {
- calculatedLength += serializableAttr.getStorageSize();
+ if (auto storageAttr = llvm::dyn_cast<SizedStorageAttr>(valueAttr)) {
+ calculatedLength += storageAttr.getStorageSize();
} else {
return emitError() << "value is not serializable: " << valueAttr;
}
@@ -650,6 +648,42 @@
}
//===----------------------------------------------------------------------===//
+// SizedStorageAttr implementations
+//===----------------------------------------------------------------------===//
+
+struct SizedStorageDenseElementsAttrModel
+ : public SizedStorageAttr::ExternalModel<SizedStorageDenseElementsAttrModel,
+ DenseIntOrFPElementsAttr> {
+ int64_t getStorageSize(Attribute baseAttr) const {
+ auto attr = llvm::cast<ElementsAttr>(baseAttr);
+ return IREE::Util::getRoundedPhysicalStorageSize(
+ attr.getNumElements(),
+ cast<ShapedType>(attr.getType()).getElementType());
+ }
+};
+
+struct SizedStorageDenseResourceElementsAttrModel
+ : public SizedStorageAttr::ExternalModel<
+ SizedStorageDenseResourceElementsAttrModel,
+ DenseResourceElementsAttr> {
+ int64_t getStorageSize(Attribute baseAttr) const {
+ auto attr = llvm::cast<DenseResourceElementsAttr>(baseAttr);
+ return IREE::Util::getRoundedPhysicalStorageSize(
+ attr.getNumElements(), attr.getType().getElementType());
+ }
+};
+
+// We don't include NUL terminators as it's 2023.
+struct SizedStorageStringAttrModel
+ : public SizedStorageAttr::ExternalModel<SizedStorageStringAttrModel,
+ StringAttr> {
+ int64_t getStorageSize(Attribute baseAttr) const {
+ auto attr = llvm::cast<StringAttr>(baseAttr);
+ return attr.getValue().size();
+ }
+};
+
+//===----------------------------------------------------------------------===//
// SerializableAttrInterface implementations
//===----------------------------------------------------------------------===//
@@ -658,17 +692,10 @@
struct SerializableDenseElementsAttrModel
: public SerializableAttrInterface::ExternalModel<
SerializableDenseElementsAttrModel, DenseIntOrFPElementsAttr> {
- int64_t getStorageSize(Attribute baseAttr) const {
- auto attr = llvm::cast<ElementsAttr>(baseAttr);
- return IREE::Util::getRoundedPhysicalStorageSize(
- attr.getNumElements(),
- cast<ShapedType>(attr.getType()).getElementType());
- }
-
LogicalResult serializeToVector(Attribute baseAttr, Location loc,
llvm::endianness endian,
SmallVectorImpl<char> &buffer) const {
- buffer.resize(getStorageSize(baseAttr));
+ buffer.resize(cast<SizedStorageAttr>(baseAttr).getStorageSize());
return serializeToBuffer(baseAttr, loc, endian, buffer);
}
@@ -684,7 +711,7 @@
llvm::raw_ostream &os) const {
// NOTE: not all ostream implementations handle this but for buffering ones
// it can really help.
- os.reserveExtraSpace(getStorageSize(baseAttr));
+ os.reserveExtraSpace(cast<SizedStorageAttr>(baseAttr).getStorageSize());
auto elementsAttr = llvm::cast<DenseElementsAttr>(baseAttr);
if (elementsAttr.isSplat()) {
@@ -711,16 +738,10 @@
: public SerializableAttrInterface::ExternalModel<
SerializableDenseResourceElementsAttrModel,
DenseResourceElementsAttr> {
- int64_t getStorageSize(Attribute baseAttr) const {
- auto attr = llvm::cast<DenseResourceElementsAttr>(baseAttr);
- return IREE::Util::getRoundedPhysicalStorageSize(
- attr.getNumElements(), attr.getType().getElementType());
- }
-
LogicalResult serializeToVector(Attribute baseAttr, Location loc,
llvm::endianness endian,
SmallVectorImpl<char> &buffer) const {
- buffer.resize(getStorageSize(baseAttr));
+ buffer.resize(cast<SizedStorageAttr>(baseAttr).getStorageSize());
return serializeToBuffer(baseAttr, loc, endian, buffer);
}
@@ -747,7 +768,7 @@
"values or pass --iree-util-zero-fill-elided-attrs for "
"testing and expect invalid execution results";
}
- os.write_zeros(getStorageSize(baseAttr));
+ os.write_zeros(cast<SizedStorageAttr>(baseAttr).getStorageSize());
return success();
}
@@ -761,15 +782,10 @@
struct SerializableStringAttrModel
: public SerializableAttrInterface::ExternalModel<
SerializableStringAttrModel, StringAttr> {
- int64_t getStorageSize(Attribute baseAttr) const {
- auto attr = llvm::cast<StringAttr>(baseAttr);
- return attr.getValue().size();
- }
-
LogicalResult serializeToVector(Attribute baseAttr, Location loc,
llvm::endianness endian,
SmallVectorImpl<char> &buffer) const {
- buffer.resize(getStorageSize(baseAttr));
+ buffer.resize(cast<SizedStorageAttr>(baseAttr).getStorageSize());
return serializeToBuffer(baseAttr, loc, endian, buffer);
}
@@ -785,7 +801,7 @@
llvm::raw_ostream &os) const {
// NOTE: not all ostream implementations handle this but for buffering ones
// it can really help.
- os.reserveExtraSpace(getStorageSize(baseAttr));
+ os.reserveExtraSpace(cast<SizedStorageAttr>(baseAttr).getStorageSize());
auto stringAttr = llvm::cast<StringAttr>(baseAttr);
os.write(stringAttr.data(), stringAttr.size());
return success();
@@ -813,13 +829,17 @@
// up in the stack - things that end up here are generally already in a target
// encoding.
auto &context = *getContext();
- DenseIntElementsAttr::attachInterface<SerializableDenseElementsAttrModel>(
+ DenseIntElementsAttr::attachInterface<SizedStorageDenseElementsAttrModel,
+ SerializableDenseElementsAttrModel>(
context);
- DenseFPElementsAttr::attachInterface<SerializableDenseElementsAttrModel>(
+ DenseFPElementsAttr::attachInterface<SizedStorageDenseElementsAttrModel,
+ SerializableDenseElementsAttrModel>(
context);
DenseResourceElementsAttr::attachInterface<
+ SizedStorageDenseResourceElementsAttrModel,
SerializableDenseResourceElementsAttrModel>(context);
- StringAttr::attachInterface<SerializableStringAttrModel>(context);
+ StringAttr::attachInterface<SizedStorageStringAttrModel,
+ SerializableStringAttrModel>(context);
}
} // namespace Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td
index b6c6ac5..1c856ef 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td
@@ -19,6 +19,9 @@
def Util_BytePatternAttr : AttrDef<Util_Dialect, "BytePattern", [
TypedAttrInterface,
+ DeclareAttrInterfaceMethods<Util_SizedStorageAttr, [
+ "getStorageSize",
+ ]>,
DeclareAttrInterfaceMethods<Util_SerializableAttrInterface, [
"serializeToBuffer",
"serializeToStream",
@@ -64,6 +67,9 @@
//===----------------------------------------------------------------------===//
def Util_CompositeAttr : AttrDef<Util_Dialect, "Composite", [
+ DeclareAttrInterfaceMethods<Util_SizedStorageAttr, [
+ "getStorageSize",
+ ]>,
DeclareAttrInterfaceMethods<Util_SerializableAttrInterface, [
"serializeToBuffer",
"serializeToStream",
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index 4a8ade0..0ec0d49 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -915,6 +915,28 @@
];
}
+def Util_SizedStorageAttr : AttrInterface<"SizedStorageAttr"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+ let description = [{
+ Interface used to query storage requirements for an attribute that is backed
+ by physical storage (memory, disk, or external API).
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the storage size in bytes required by the attribute value.
+ If the value is sub-byte aligned the storage size will be rounded up to
+ the next whole byte.
+ }],
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getStorageSize",
+ /*args=*/(ins)
+ >,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// IREE::Util::ReferenceTypeInterface
//===----------------------------------------------------------------------===//
@@ -1088,7 +1110,10 @@
// IREE::Util::SerializableAttrInterface
//===----------------------------------------------------------------------===//
-def Util_SerializableAttrInterface : AttrInterface<"SerializableAttrInterface"> {
+def Util_SerializableAttrInterface :
+ AttrInterface<"SerializableAttrInterface", [
+ Util_SizedStorageAttr,
+ ]> {
let cppNamespace = "::mlir::iree_compiler::IREE::Util";
let description = [{
@@ -1102,17 +1127,6 @@
let methods = [
InterfaceMethod<
/*desc=*/[{
- Returns the storage size in bytes required by the serialized value.
- Any of the serialization methods will write precisely this number of
- bytes. If the value is sub-byte aligned the storage size will be rounded
- up to the next whole byte.
- }],
- /*retTy=*/"int64_t",
- /*methodName=*/"getStorageSize",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*desc=*/[{
Serializes the attribute to the given byte vector. The vector will be
resized to the total storage size upon return.
}],
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index ea55cd4..2b1f87e 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -812,8 +812,8 @@
// During A->B->C dialect conversion, the type may not be legal so be
// defensive.
auto operand = getOperand();
- if (auto sizeAwareType = llvm::dyn_cast<IREE::Util::SizeAwareTypeInterface>(
- operand.getType())) {
+ if (auto sizeAwareType =
+ dyn_cast<IREE::Util::SizeAwareTypeInterface>(operand.getType())) {
Operation *op = this->getOperation();
if (auto sizeValue = sizeAwareType.findSizeValue(operand, op->getBlock(),
Block::iterator(op))) {
@@ -824,11 +824,10 @@
// If the source is a constant then we can calculate that immediately.
if (auto constantOp = dyn_cast_or_null<IREE::Util::BufferConstantOp>(
operand.getDefiningOp())) {
- if (auto attr =
- llvm::dyn_cast_if_present<IREE::Util::SerializableAttrInterface>(
- constantOp.getValue())) {
- return IntegerAttr::get(IndexType::get(attr.getContext()),
- attr.getStorageSize());
+ if (auto storageAttr = dyn_cast_if_present<IREE::Util::SizedStorageAttr>(
+ constantOp.getValue())) {
+ return IntegerAttr::get(IndexType::get(storageAttr.getContext()),
+ storageAttr.getStorageSize());
}
}
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel
index cbe37c1..5edd21e 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel
@@ -96,7 +96,7 @@
)
iree_tablegen_doc(
- name = "HALInlineDialecDocGen",
+ name = "HALInlineDialectDocGen",
tbl_outs = [
(
[
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/CMakeLists.txt b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/CMakeLists.txt
index 48c56e8..0b1ea0e 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/CMakeLists.txt
@@ -69,7 +69,7 @@
iree_tablegen_doc(
NAME
- HALInlineDialecDocGen
+ HALInlineDialectDocGen
TD_FILE
"HALInlineOps.td"
OUTS
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel
index 0cd3ed7..be7f37c 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel
@@ -96,7 +96,7 @@
)
iree_tablegen_doc(
- name = "HALLoaderDialecDocGen",
+ name = "HALLoaderDialectDocGen",
tbl_outs = [
(
[
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/CMakeLists.txt b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/CMakeLists.txt
index eb4840e..f9117c6 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/CMakeLists.txt
@@ -69,7 +69,7 @@
iree_tablegen_doc(
NAME
- HALLoaderDialecDocGen
+ HALLoaderDialectDocGen
TD_FILE
"HALLoaderOps.td"
OUTS
diff --git a/compiler/src/iree/compiler/Modules/IO/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/BUILD.bazel
new file mode 100644
index 0000000..522ca5d
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/BUILD.bazel
@@ -0,0 +1,11 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/CMakeLists.txt
new file mode 100644
index 0000000..92497d8
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/CMakeLists.txt
@@ -0,0 +1,13 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/BUILD.bazel
new file mode 100644
index 0000000..40a4c17
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/BUILD.bazel
@@ -0,0 +1,22 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/embed_data:build_defs.bzl", "c_embed_data")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+c_embed_data(
+ name = "io_parameters_imports",
+ srcs = ["io_parameters.imports.mlir"],
+ c_file_output = "io_parameters.imports.c",
+ flatten = True,
+ h_file_output = "io_parameters.imports.h",
+ identifier = "iree_io_parameters_imports",
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/CMakeLists.txt
new file mode 100644
index 0000000..ddd20fe
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/CMakeLists.txt
@@ -0,0 +1,28 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_c_embed_data(
+ NAME
+ io_parameters_imports
+ SRCS
+ "io_parameters.imports.mlir"
+ C_FILE_OUTPUT
+ "io_parameters.imports.c"
+ H_FILE_OUTPUT
+ "io_parameters.imports.h"
+ IDENTIFIER
+ "iree_io_parameters_imports"
+ FLATTEN
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/BUILD.bazel
new file mode 100644
index 0000000..522ca5d
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/BUILD.bazel
@@ -0,0 +1,11 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/CMakeLists.txt
new file mode 100644
index 0000000..f1f66bd
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/CMakeLists.txt
@@ -0,0 +1,13 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/BUILD.bazel
new file mode 100644
index 0000000..4d14361
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/BUILD.bazel
@@ -0,0 +1,36 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "ParamsToVM",
+ srcs = [
+ "Patterns.cpp",
+ ],
+ hdrs = [
+ "Patterns.h",
+ ],
+ deps = [
+ "//compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//compiler/src/iree/compiler/Dialect/VM/Conversion",
+ "//compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM",
+ "//compiler/src/iree/compiler/Dialect/VM/IR",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters/IR",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/CMakeLists.txt
new file mode 100644
index 0000000..f203753
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/CMakeLists.txt
@@ -0,0 +1,35 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/BUILD.bazel#
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ ParamsToVM
+ HDRS
+ "Patterns.h"
+ SRCS
+ "Patterns.cpp"
+ DEPS
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRIR
+ MLIRPass
+ MLIRTransforms
+ iree::compiler::Dialect::HAL::Conversion::HALToVM
+ iree::compiler::Dialect::Util::IR
+ iree::compiler::Dialect::VM::Conversion
+ iree::compiler::Dialect::VM::Conversion::StandardToVM
+ iree::compiler::Dialect::VM::IR
+ iree::compiler::Modules::IO::Parameters::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.cpp
new file mode 100644
index 0000000..953503f
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.cpp
@@ -0,0 +1,312 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.h"
+
+#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
+#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+static Value getStringRodata(Location loc, StringAttr attr,
+ OpBuilder &builder) {
+ if (!attr) {
+ return builder.create<IREE::VM::ConstRefZeroOp>(
+ loc, IREE::VM::RefType::get(builder.getType<IREE::VM::BufferType>()));
+ }
+ return builder.create<IREE::VM::RodataInlineOp>(loc, attr);
+}
+
+struct LoadOpConversion
+ : public OpConversionPattern<IREE::IO::Parameters::LoadOp> {
+ LoadOpConversion(MLIRContext *context, SymbolTable &importSymbols,
+ TypeConverter &typeConverter, StringRef importName)
+ : OpConversionPattern(context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+ LogicalResult
+ matchAndRewrite(IREE::IO::Parameters::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ loadOp, importOp.getSymNameAttr(),
+ IREE::VM::RefType::get(loadOp.getResult().getType()),
+ ValueRange{
+ adaptor.getDevice(),
+ adaptor.getQueueAffinity(),
+ adaptor.getWaitFence(),
+ adaptor.getSignalFence(),
+ getStringRodata(loadOp.getLoc(), adaptor.getSourceScopeAttr(),
+ rewriter),
+ getStringRodata(loadOp.getLoc(), adaptor.getSourceKeyAttr(),
+ rewriter),
+ adaptor.getSourceOffset(),
+ rewriter.create<IREE::VM::ConstI32Op>(
+ loadOp.getLoc(), (uint32_t)adaptor.getMemoryTypes()),
+ rewriter.create<IREE::VM::ConstI32Op>(
+ loadOp.getLoc(), (uint32_t)adaptor.getBufferUsage()),
+ adaptor.getLength(),
+ });
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
+struct ReadOpConversion
+ : public OpConversionPattern<IREE::IO::Parameters::ReadOp> {
+ ReadOpConversion(MLIRContext *context, SymbolTable &importSymbols,
+ TypeConverter &typeConverter, StringRef importName)
+ : OpConversionPattern(context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+ LogicalResult
+ matchAndRewrite(IREE::IO::Parameters::ReadOp readOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ readOp, importOp.getSymNameAttr(), TypeRange{},
+ ValueRange{
+ adaptor.getDevice(),
+ adaptor.getQueueAffinity(),
+ adaptor.getWaitFence(),
+ adaptor.getSignalFence(),
+ getStringRodata(readOp.getLoc(), adaptor.getSourceScopeAttr(),
+ rewriter),
+ getStringRodata(readOp.getLoc(), adaptor.getSourceKeyAttr(),
+ rewriter),
+ adaptor.getSourceOffset(),
+ adaptor.getTargetBuffer(),
+ adaptor.getTargetOffset(),
+ adaptor.getLength(),
+ });
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
+struct WriteOpConversion
+ : public OpConversionPattern<IREE::IO::Parameters::WriteOp> {
+ WriteOpConversion(MLIRContext *context, SymbolTable &importSymbols,
+ TypeConverter &typeConverter, StringRef importName)
+ : OpConversionPattern(context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+ LogicalResult
+ matchAndRewrite(IREE::IO::Parameters::WriteOp writeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ writeOp, importOp.getSymNameAttr(), TypeRange{},
+ ValueRange{
+ adaptor.getDevice(),
+ adaptor.getQueueAffinity(),
+ adaptor.getWaitFence(),
+ adaptor.getSignalFence(),
+ getStringRodata(writeOp.getLoc(), adaptor.getTargetScopeAttr(),
+ rewriter),
+ getStringRodata(writeOp.getLoc(), adaptor.getTargetKeyAttr(),
+ rewriter),
+ adaptor.getTargetOffset(),
+ adaptor.getSourceBuffer(),
+ adaptor.getSourceOffset(),
+ adaptor.getLength(),
+ });
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
+// TODO(benvanik): make a vm.rodata.table or something that returns the
+// offset/length and data buffers. We could then do a whole-program analysis to
+// build a single data table with multiple views into it.
+static std::pair<Value, Value> buildKeyTable(Location loc, ArrayAttr keysAttr,
+ OpBuilder &builder) {
+ SmallVector<int32_t> table;
+ SmallVector<Attribute> dataAttrs;
+ size_t dataSize = 0;
+ for (auto key : keysAttr.getAsRange<StringAttr>()) {
+ table.push_back(dataSize);
+ table.push_back(key.size());
+ dataAttrs.push_back(key);
+ dataSize += key.size();
+ }
+ Value tableRodata = builder.create<IREE::VM::RodataInlineOp>(
+ loc, IREE::VM::RefType::get(builder.getType<IREE::VM::BufferType>()),
+ builder.getI32VectorAttr(table));
+ Value stringRodata = builder.create<IREE::VM::RodataInlineOp>(
+ loc, IREE::VM::RefType::get(builder.getType<IREE::VM::BufferType>()),
+ IREE::Util::CompositeAttr::get(builder.getContext(), dataAttrs));
+ return {tableRodata, stringRodata};
+}
+
+static Value buildIndirectSpans(Location loc, ValueRange parameterOffsets,
+ ValueRange bufferOffsets,
+ ValueRange bufferLengths, OpBuilder &builder) {
+ // Build the rodata containing all constant values and the list of dynamic
+ // updates we'll need to perform. We assume that 95-100% of values are
+ // constant and optimize for that - if this changes we can make this more
+ // sophisticated to reduce binary size and runtime overhead.
+ SmallVector<std::pair<size_t, Value>> dynamicUpdates;
+ SmallVector<int64_t> values;
+ auto recordValue = [&](Value value) {
+ APInt constantValue;
+ if (matchPattern(value, m_ConstantInt(&constantValue))) {
+ values.push_back(constantValue.getZExtValue());
+ } else {
+ dynamicUpdates.push_back(std::make_pair(values.size(), value));
+ values.push_back(0);
+ }
+ };
+ for (auto [parameterOffset, bufferOffset, bufferLength] :
+ llvm::zip_equal(parameterOffsets, bufferOffsets, bufferLengths)) {
+ recordValue(parameterOffset);
+ recordValue(bufferOffset);
+ recordValue(bufferLength);
+ }
+ Value rodataBuffer = builder.create<IREE::VM::RodataInlineOp>(
+ loc, IREE::VM::RefType::get(builder.getType<IREE::VM::BufferType>()),
+ builder.getI64VectorAttr(values));
+ if (dynamicUpdates.empty()) {
+ // Fast-path for all constant data.
+ return rodataBuffer;
+ }
+
+ // Clone the rodata so we can mutate it.
+ Value rodataSize = builder.create<IREE::VM::BufferLengthOp>(
+ loc, builder.getI64Type(), rodataBuffer);
+ Value clonedBuffer = builder.create<IREE::VM::BufferCloneOp>(
+ loc, IREE::VM::RefType::get(builder.getType<IREE::VM::BufferType>()),
+ rodataBuffer, builder.create<IREE::VM::ConstI32ZeroOp>(loc), rodataSize,
+ builder.create<IREE::VM::ConstI32Op>(loc, sizeof(uint32_t)));
+
+ // Perform all updates.
+ for (auto [index, value] : dynamicUpdates) {
+ builder.create<IREE::VM::BufferStoreI64Op>(
+ loc, clonedBuffer, builder.create<IREE::VM::ConstI64Op>(loc, index),
+ value);
+ }
+
+ return clonedBuffer;
+}
+
+struct GatherOpConversion
+ : public OpConversionPattern<IREE::IO::Parameters::GatherOp> {
+ GatherOpConversion(MLIRContext *context, SymbolTable &importSymbols,
+ TypeConverter &typeConverter, StringRef importName)
+ : OpConversionPattern(context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+ LogicalResult
+ matchAndRewrite(IREE::IO::Parameters::GatherOp gatherOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto [keyTable, keyData] =
+ buildKeyTable(gatherOp.getLoc(), adaptor.getSourceKeysAttr(), rewriter);
+ auto spans = buildIndirectSpans(
+ gatherOp.getLoc(), adaptor.getSourceOffsets(),
+ adaptor.getTargetOffsets(), adaptor.getTargetLengths(), rewriter);
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ gatherOp, importOp.getSymNameAttr(), TypeRange{},
+ ValueRange{
+ adaptor.getDevice(),
+ adaptor.getQueueAffinity(),
+ adaptor.getWaitFence(),
+ adaptor.getSignalFence(),
+ getStringRodata(gatherOp.getLoc(), adaptor.getSourceScopeAttr(),
+ rewriter),
+ adaptor.getTargetBuffer(),
+ keyTable,
+ keyData,
+ spans,
+ });
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
+struct ScatterOpConversion
+ : public OpConversionPattern<IREE::IO::Parameters::ScatterOp> {
+ ScatterOpConversion(MLIRContext *context, SymbolTable &importSymbols,
+ TypeConverter &typeConverter, StringRef importName)
+ : OpConversionPattern(context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+ LogicalResult
+ matchAndRewrite(IREE::IO::Parameters::ScatterOp scatterOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto [keyTable, keyData] = buildKeyTable(
+ scatterOp.getLoc(), adaptor.getTargetKeysAttr(), rewriter);
+ auto spans = buildIndirectSpans(
+ scatterOp.getLoc(), adaptor.getTargetOffsets(),
+ adaptor.getSourceOffsets(), adaptor.getSourceLengths(), rewriter);
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ scatterOp, importOp.getSymNameAttr(), TypeRange{},
+ ValueRange{
+ adaptor.getDevice(),
+ adaptor.getQueueAffinity(),
+ adaptor.getWaitFence(),
+ adaptor.getSignalFence(),
+ getStringRodata(scatterOp.getLoc(), adaptor.getTargetScopeAttr(),
+ rewriter),
+ adaptor.getSourceBuffer(),
+ keyTable,
+ keyData,
+ spans,
+ });
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
+} // namespace
+
+void populateIOParametersToVMPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ SymbolTable &importSymbols,
+ RewritePatternSet &patterns) {
+ patterns.insert<LoadOpConversion>(context, importSymbols, typeConverter,
+ "io_parameters.load");
+ patterns.insert<ReadOpConversion>(context, importSymbols, typeConverter,
+ "io_parameters.read");
+ patterns.insert<WriteOpConversion>(context, importSymbols, typeConverter,
+ "io_parameters.write");
+ patterns.insert<GatherOpConversion>(context, importSymbols, typeConverter,
+ "io_parameters.gather");
+ patterns.insert<ScatterOpConversion>(context, importSymbols, typeConverter,
+ "io_parameters.scatter");
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.h b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.h
new file mode 100644
index 0000000..b338992
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.h
@@ -0,0 +1,26 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_MODULES_IO_PARAMETERS_CONVERSION_PARAMSTOVM_PATTERNS_H_
+#define IREE_COMPILER_MODULES_IO_PARAMETERS_CONVERSION_PARAMSTOVM_PATTERNS_H_
+
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+// Populates conversion patterns from the io_parameters dialect to the VM
+// dialect.
+void populateIOParametersToVMPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ SymbolTable &importSymbols,
+ RewritePatternSet &patterns);
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_MODULES_IO_PARAMETERS_CONVERSION_PARAMSTOVM_PATTERNS_H_
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel
new file mode 100644
index 0000000..06483fc
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel
@@ -0,0 +1,28 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "parameter_ops.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/CMakeLists.txt
new file mode 100644
index 0000000..f4db472
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel#
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "parameter_ops.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/parameter_ops.mlir
new file mode 100644
index 0000000..de9b662
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/parameter_ops.mlir
@@ -0,0 +1,134 @@
+// RUN: iree-opt --iree-vm-target-index-bits=64 --split-input-file \
+// RUN: --iree-vm-conversion --mlir-print-local-scope %s | FileCheck %s
+
+// CHECK-LABEL: @parameterLoad
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[QUEUE_AFFINITY:.+]]: i64, %[[WAIT:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL:.+]]: !vm.ref<!hal.fence>)
+func.func @parameterLoad(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence) -> !hal.buffer {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ // CHECK-DAG: %[[SCOPE:.+]] = vm.rodata.inline {{.+}} = "scope"
+ // CHECK-DAG: %[[KEY:.+]] = vm.rodata.inline {{.+}} = "key"
+ // CHECK: %[[TARGET_BUFFER:.+]] = vm.call @io_parameters.load
+ // CHECK-SAME: (%[[DEVICE]], %[[QUEUE_AFFINITY]], %[[WAIT]], %[[SIGNAL]],
+ // CHECK-SAME: %[[SCOPE]], %[[KEY]], %c50, %c48, %c527363, %c100)
+ %target_buffer = io_parameters.load<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) source("scope"::"key")[%c50_i64] type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") : !hal.buffer{%c100}
+ // CHECK: return %[[TARGET_BUFFER]]
+ return %target_buffer : !hal.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @parameterLoadNoScope
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[QUEUE_AFFINITY:.+]]: i64, %[[WAIT:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL:.+]]: !vm.ref<!hal.fence>)
+func.func @parameterLoadNoScope(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence) -> !hal.buffer {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ // CHECK-DAG: %[[SCOPE:.+]] = vm.const.ref.zero : !vm.buffer
+ // CHECK-DAG: %[[KEY:.+]] = vm.rodata.inline {{.+}} = "key"
+ // CHECK: %[[TARGET_BUFFER:.+]] = vm.call @io_parameters.load
+ // CHECK-SAME: (%[[DEVICE]], %[[QUEUE_AFFINITY]], %[[WAIT]], %[[SIGNAL]],
+ // CHECK-SAME: %[[SCOPE]], %[[KEY]], %c50, %c48, %c527363, %c100)
+ %target_buffer = io_parameters.load<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) source("key")[%c50_i64] type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") : !hal.buffer{%c100}
+ // CHECK: return %[[TARGET_BUFFER]]
+ return %target_buffer : !hal.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @parameterRead
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[QUEUE_AFFINITY:.+]]: i64, %[[WAIT:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL:.+]]: !vm.ref<!hal.fence>, %[[TARGET_BUFFER:.+]]: !vm.ref<!hal.buffer>)
+func.func @parameterRead(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence, %target_buffer: !hal.buffer) {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[SCOPE:.+]] = vm.rodata.inline {{.+}} = "scope"
+ // CHECK-DAG: %[[KEY:.+]] = vm.rodata.inline {{.+}} = "key"
+ // CHECK: vm.call @io_parameters.read
+ // CHECK-SAME: (%[[DEVICE]], %[[QUEUE_AFFINITY]], %[[WAIT]], %[[SIGNAL]],
+ // CHECK-SAME: %[[SCOPE]], %[[KEY]], %c50, %[[TARGET_BUFFER]], %c100, %c200)
+ io_parameters.read<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) source("scope"::"key")[%c50_i64] target(%target_buffer : !hal.buffer)[%c100] length(%c200)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @parameterWrite
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[QUEUE_AFFINITY:.+]]: i64, %[[WAIT:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL:.+]]: !vm.ref<!hal.fence>, %[[SOURCE_BUFFER:.+]]: !vm.ref<!hal.buffer>)
+func.func @parameterWrite(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence, %source_buffer: !hal.buffer) {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[SCOPE:.+]] = vm.rodata.inline {{.+}} = "scope"
+ // CHECK-DAG: %[[KEY:.+]] = vm.rodata.inline {{.+}} = "key"
+ // CHECK: vm.call @io_parameters.write
+ // CHECK-SAME: (%[[DEVICE]], %[[QUEUE_AFFINITY]], %[[WAIT]], %[[SIGNAL]],
+ // CHECK-SAME: %[[SCOPE]], %[[KEY]], %c50, %[[SOURCE_BUFFER]], %c100, %c200)
+ io_parameters.write<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) source(%source_buffer : !hal.buffer)[%c100] target("scope"::"key")[%c50_i64] length(%c200)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @parameterGather
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[QUEUE_AFFINITY:.+]]: i64, %[[WAIT:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL:.+]]: !vm.ref<!hal.fence>, %[[TARGET_BUFFER:.+]]: !vm.ref<!hal.buffer>)
+func.func @parameterGather(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence, %target_buffer: !hal.buffer) {
+ %c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
+ %c52_i64 = arith.constant 52 : i64
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c102 = arith.constant 102 : index
+ %c200 = arith.constant 200 : index
+ %c201 = arith.constant 201 : index
+ %c202 = arith.constant 202 : index
+ // CHECK-DAG: %[[KEY_TABLE:.+]] = vm.rodata.inline : !vm.buffer = dense<[0, 4, 4, 4, 8, 4]> : vector<6xi32>
+ // CHECK-DAG: %[[KEY_DATA:.+]] = vm.rodata.inline : !vm.buffer = #util.composite<12xi8, [
+ // CHECK-NEXT: "key0",
+ // CHECK-NEXT: "key1",
+ // CHECK-NEXT: "key2",
+ // CHECK-NEXT: ]>
+ // CHECK-DAG: %[[SPANS:.+]] = vm.rodata.inline : !vm.buffer = dense<[50, 100, 200, 51, 101, 201, 52, 102, 202]> : vector<9xi64>
+ // CHECK-DAG: %[[SCOPE:.+]] = vm.rodata.inline {{.+}} = "scope"
+ // CHECK: vm.call @io_parameters.gather
+ // CHECK-SAME: (%[[DEVICE]], %[[QUEUE_AFFINITY]], %[[WAIT]], %[[SIGNAL]],
+ // CHECK-SAME: %[[SCOPE]], %[[TARGET_BUFFER]], %[[KEY_TABLE]], %[[KEY_DATA]], %[[SPANS]])
+ io_parameters.gather<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) {
+ "scope"::"key0"[%c50_i64] -> %target_buffer[%c100 for %c200] : !hal.buffer,
+ "scope"::"key1"[%c51_i64] -> %target_buffer[%c101 for %c201] : !hal.buffer,
+ "scope"::"key2"[%c52_i64] -> %target_buffer[%c102 for %c202] : !hal.buffer
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @parameterScatter
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[QUEUE_AFFINITY:.+]]: i64, %[[WAIT:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL:.+]]: !vm.ref<!hal.fence>, %[[SOURCE_BUFFER:.+]]: !vm.ref<!hal.buffer>)
+func.func @parameterScatter(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence, %source_buffer: !hal.buffer) {
+ %c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
+ %c52_i64 = arith.constant 52 : i64
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c102 = arith.constant 102 : index
+ %c200 = arith.constant 200 : index
+ %c201 = arith.constant 201 : index
+ %c202 = arith.constant 202 : index
+ // CHECK-DAG: %[[KEY_TABLE:.+]] = vm.rodata.inline : !vm.buffer = dense<[0, 4, 4, 4, 8, 4]> : vector<6xi32>
+ // CHECK-DAG: %[[KEY_DATA:.+]] = vm.rodata.inline : !vm.buffer = #util.composite<12xi8, [
+ // CHECK-NEXT: "key0",
+ // CHECK-NEXT: "key1",
+ // CHECK-NEXT: "key2",
+ // CHECK-NEXT: ]>
+ // CHECK-DAG: %[[SPANS:.+]] = vm.rodata.inline : !vm.buffer = dense<[50, 100, 200, 51, 101, 201, 52, 102, 202]> : vector<9xi64>
+ // CHECK-DAG: %[[SCOPE:.+]] = vm.rodata.inline {{.+}} = "scope"
+ // CHECK: vm.call @io_parameters.scatter
+ // CHECK-SAME: (%[[DEVICE]], %[[QUEUE_AFFINITY]], %[[WAIT]], %[[SIGNAL]],
+ // CHECK-SAME: %[[SCOPE]], %[[SOURCE_BUFFER]], %[[KEY_TABLE]], %[[KEY_DATA]], %[[SPANS]])
+ io_parameters.scatter<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) {
+ %source_buffer[%c100 for %c200] : !hal.buffer -> "scope"::"key0"[%c50_i64],
+ %source_buffer[%c101 for %c201] : !hal.buffer -> "scope"::"key1"[%c51_i64],
+ %source_buffer[%c102 for %c202] : !hal.buffer -> "scope"::"key2"[%c52_i64]
+ }
+ return
+}
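The gather and scatter tests above expect a key table of `[0, 4, 4, 4, 8, 4]`, 12 bytes of concatenated key data, and a flat span list of nine i64 values. The following is a minimal standalone sketch of that packing, written without any MLIR dependencies so the layout can be checked in isolation. The names `packKeys`, `packSpans`, and `Span` are hypothetical and exist only for illustration; the real logic lives in `buildKeyTable` and `buildIndirectSpans` and emits `vm.rodata.inline` ops rather than returning vectors.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical mirror of buildKeyTable: produces a flat table of
// (byte offset, byte length) pairs plus the concatenated key bytes.
// Keys are stored back to back with no separators or terminators.
std::pair<std::vector<int32_t>, std::string>
packKeys(const std::vector<std::string> &keys) {
  std::vector<int32_t> table;
  std::string data;
  for (const std::string &key : keys) {
    // The table stores 32-bit values, so the total must fit in int32_t.
    assert(data.size() + key.size() <=
           static_cast<std::size_t>(INT32_MAX));
    table.push_back(static_cast<int32_t>(data.size()));
    table.push_back(static_cast<int32_t>(key.size()));
    data += key;
  }
  return {table, data};
}

// Hypothetical span triple matching the order used by the conversion:
// offset into the parameter, offset into the buffer, length in bytes.
struct Span {
  int64_t parameterOffset;
  int64_t bufferOffset;
  int64_t bufferLength;
};

// Hypothetical mirror of the constant-only path of buildIndirectSpans:
// flattens the triples into a single i64 array, three entries per span.
std::vector<int64_t> packSpans(const std::vector<Span> &spans) {
  std::vector<int64_t> values;
  for (const Span &span : spans) {
    values.push_back(span.parameterOffset);
    values.push_back(span.bufferOffset);
    values.push_back(span.bufferLength);
  }
  return values;
}
```

Calling `packKeys({"key0", "key1", "key2"})` and `packSpans` with the three triples from the tests gives the tables the CHECK lines match. Dynamic (non-constant) span values are not modeled here; in the conversion they are written as zero placeholders and patched into a cloned buffer at runtime.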
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/BUILD.bazel
new file mode 100644
index 0000000..5e1b984
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/BUILD.bazel
@@ -0,0 +1,38 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "StreamToParams",
+ srcs = [
+ "Patterns.cpp",
+ ],
+ hdrs = [
+ "Patterns.h",
+ ],
+ deps = [
+ "//compiler/src/iree/compiler/Dialect/HAL/Conversion",
+ "//compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL:Utils",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
+ "//compiler/src/iree/compiler/Dialect/Stream/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/CMakeLists.txt
new file mode 100644
index 0000000..4059580
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/CMakeLists.txt
@@ -0,0 +1,37 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/BUILD.bazel#
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ StreamToParams
+ HDRS
+ "Patterns.h"
+ SRCS
+ "Patterns.cpp"
+ DEPS
+ LLVMSupport
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRIR
+ MLIRPass
+ MLIRTransforms
+ iree::compiler::Dialect::HAL::Conversion
+ iree::compiler::Dialect::HAL::Conversion::StreamToHAL::Utils
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::HAL::IR::HALDialect
+ iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::IR
+ iree::compiler::Modules::IO::Parameters::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.cpp
new file mode 100644
index 0000000..2105945
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.cpp
@@ -0,0 +1,186 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.h"
+
+#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+struct ParameterLoadOpPattern
+ : public OpConversionPattern<IREE::Stream::ParameterLoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::Stream::ParameterLoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = loadOp.getLoc();
+ auto [device, queueAffinity] =
+ lookupDeviceAndQueueAffinityFor(loadOp, rewriter);
+
+ // Gather wait/signal fence, which are optional.
+ Value waitFence =
+ getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
+ Value signalFence = getOrCreateSignalFence(
+ loc, device, loadOp.getResultTimepoint(), rewriter);
+
+ // Derive the allocation requirements.
+ auto resourceType =
+ llvm::cast<IREE::Stream::ResourceType>(loadOp.getResult().getType());
+ auto memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
+ auto bufferUsage = IREE::HAL::BufferUsageBitfield::None;
+ if (failed(deriveAllowedResourceBufferBits(loc, resourceType, memoryTypes,
+ bufferUsage))) {
+ return failure();
+ }
+
+ // Queue operation, which acts like an allocation.
+ Value result = rewriter.create<IREE::IO::Parameters::LoadOp>(
+ loc, rewriter.getType<IREE::HAL::BufferType>(), device, queueAffinity,
+ waitFence, signalFence, adaptor.getSourceScopeAttr(),
+ adaptor.getSourceKeyAttr(), adaptor.getSourceOffset(), memoryTypes,
+ bufferUsage, adaptor.getResultSize());
+
+ rewriter.replaceOp(loadOp, {result, signalFence});
+ return success();
+ }
+};
+
+struct ParameterReadOpPattern
+ : public OpConversionPattern<IREE::Stream::ParameterReadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::Stream::ParameterReadOp readOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = readOp.getLoc();
+ auto [device, queueAffinity] =
+ lookupDeviceAndQueueAffinityFor(readOp, rewriter);
+
+ // Gather wait/signal fence, which are optional.
+ Value waitFence =
+ getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
+ Value signalFence = getOrCreateSignalFence(
+ loc, device, readOp.getResultTimepoint(), rewriter);
+
+ // Queue operation.
+ rewriter.create<IREE::IO::Parameters::ReadOp>(
+ loc, device, queueAffinity, waitFence, signalFence,
+ adaptor.getSourceScopeAttr(), adaptor.getSourceKeyAttr(),
+ adaptor.getSourceOffset(), adaptor.getTarget(),
+ adaptor.getTargetOffset(), adaptor.getTargetLength());
+
+ rewriter.replaceOp(readOp, {signalFence});
+ return success();
+ }
+};
+
+struct ParameterWriteOpPattern
+ : public OpConversionPattern<IREE::Stream::ParameterWriteOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::Stream::ParameterWriteOp writeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = writeOp.getLoc();
+ auto [device, queueAffinity] =
+ lookupDeviceAndQueueAffinityFor(writeOp, rewriter);
+
+ // Gather wait/signal fence, which are optional.
+ Value waitFence =
+ getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
+ Value signalFence = getOrCreateSignalFence(
+ loc, device, writeOp.getResultTimepoint(), rewriter);
+
+ // Queue operation.
+ rewriter.create<IREE::IO::Parameters::WriteOp>(
+ loc, device, queueAffinity, waitFence, signalFence, adaptor.getSource(),
+ adaptor.getSourceOffset(), adaptor.getTargetScopeAttr(),
+ adaptor.getTargetKeyAttr(), adaptor.getTargetOffset(),
+ adaptor.getSourceLength());
+
+ rewriter.replaceOp(writeOp, {signalFence});
+ return success();
+ }
+};
+
+struct ParameterGatherOpPattern
+ : public OpConversionPattern<IREE::Stream::ParameterGatherOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::Stream::ParameterGatherOp gatherOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = gatherOp.getLoc();
+ auto [device, queueAffinity] =
+ lookupDeviceAndQueueAffinityFor(gatherOp, rewriter);
+
+ // Gather wait/signal fence, which are optional.
+ Value waitFence =
+ getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
+ Value signalFence = getOrCreateSignalFence(
+ loc, device, gatherOp.getResultTimepoint(), rewriter);
+
+ // Queue operation.
+ rewriter.create<IREE::IO::Parameters::GatherOp>(
+ loc, device, queueAffinity, waitFence, signalFence,
+ adaptor.getSourceScopeAttr(), adaptor.getSourceKeysAttr(),
+ adaptor.getSourceOffsets(), adaptor.getTarget(),
+ adaptor.getTargetOffsets(), adaptor.getTargetLengths());
+
+ rewriter.replaceOp(gatherOp, {signalFence});
+ return success();
+ }
+};
+
+struct ParameterScatterOpPattern
+ : public OpConversionPattern<IREE::Stream::ParameterScatterOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::Stream::ParameterScatterOp scatterOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = scatterOp.getLoc();
+ auto [device, queueAffinity] =
+ lookupDeviceAndQueueAffinityFor(scatterOp, rewriter);
+
+ // Gather wait/signal fence, which are optional.
+ Value waitFence =
+ getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
+ Value signalFence = getOrCreateSignalFence(
+ loc, device, scatterOp.getResultTimepoint(), rewriter);
+
+ // Queue operation.
+ rewriter.create<IREE::IO::Parameters::ScatterOp>(
+ loc, device, queueAffinity, waitFence, signalFence, adaptor.getSource(),
+ adaptor.getSourceOffsets(), adaptor.getSourceLengths(),
+ adaptor.getTargetScopeAttr(), adaptor.getTargetKeysAttr(),
+ adaptor.getTargetOffsets());
+
+ rewriter.replaceOp(scatterOp, {signalFence});
+ return success();
+ }
+};
+
+} // namespace
+
+void populateStreamToIOParametersPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ patterns.insert<ParameterLoadOpPattern, ParameterReadOpPattern,
+ ParameterWriteOpPattern, ParameterGatherOpPattern,
+ ParameterScatterOpPattern>(typeConverter, context);
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.h b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.h
new file mode 100644
index 0000000..03c5f7b
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.h
@@ -0,0 +1,22 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_MODULES_IO_PARAMETERS_CONVERSION_STREAMTOPARAMS_PATTERNS_H_
+#define IREE_COMPILER_MODULES_IO_PARAMETERS_CONVERSION_STREAMTOPARAMS_PATTERNS_H_
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+void populateStreamToIOParametersPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_MODULES_IO_PARAMETERS_CONVERSION_STREAMTOPARAMS_PATTERNS_H_
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel
new file mode 100644
index 0000000..06483fc
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel
@@ -0,0 +1,28 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "parameter_ops.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/CMakeLists.txt
new file mode 100644
index 0000000..71e1a12
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel#
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "parameter_ops.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
new file mode 100644
index 0000000..7e1a810
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
@@ -0,0 +1,175 @@
+// RUN: iree-opt --split-input-file --iree-hal-conversion --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: @parameterLoad
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.fence)
+func.func @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
+ // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
+ // CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
+ // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
+ // CHECK-SAME: source("scope"::"key")[%c50_i64]
+ // CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
+ // CHECK-SAME: : !hal.buffer
+ %result, %result_timepoint = stream.parameter.load await(%wait) => "scope"::"key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ // CHECK: return %[[BUFFER]], %[[SIGNAL]]
+ return %result, %result_timepoint : !stream.resource<constant>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterLoadNoScope
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.fence)
+func.func @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
+ // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
+ // CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
+ // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
+ // CHECK-SAME: source("key")[%c50_i64]
+ // CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
+ // CHECK-SAME: : !hal.buffer
+ %result, %result_timepoint = stream.parameter.load await(%wait) => "key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ // CHECK: return %[[BUFFER]], %[[SIGNAL]]
+ return %result, %result_timepoint : !stream.resource<constant>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterRead
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
+func.func @parameterRead(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
+ // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
+ // CHECK: io_parameters.read<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
+ // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
+ // CHECK-SAME: source("scope"::"key")[%c50_i64]
+ // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c100]
+ // CHECK-SAME: length(%c200)
+ %timepoint = stream.parameter.read await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint
+ // CHECK: return %[[SIGNAL]]
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterWrite
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence
+func.func @parameterWrite(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
+ // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
+ // CHECK: io_parameters.write<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
+ // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
+ // CHECK-SAME: source(%[[SOURCE]] : !hal.buffer)[%c100]
+ // CHECK-SAME: target("scope"::"key")[%c50_i64]
+ // CHECK-SAME: length(%c200)
+ %timepoint = stream.parameter.write await(%wait) => %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint
+ // CHECK: return %[[SIGNAL]]
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterGather
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
+func.func @parameterGather(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
+ %c52_i64 = arith.constant 52 : i64
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c102 = arith.constant 102 : index
+ %c200 = arith.constant 200 : index
+ %c201 = arith.constant 201 : index
+ %c202 = arith.constant 202 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
+ // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
+ // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
+ // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
+ // CHECK-NEXT: "scope"::"key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer,
+ // CHECK-NEXT: "scope"::"key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer,
+ // CHECK-NEXT: "scope"::"key2"[%c52_i64] -> %[[TARGET]][%c102 for %c202] : !hal.buffer
+ %timepoint = stream.parameter.gather await(%wait) => {
+ "scope"::"key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300},
+ "scope"::"key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300},
+ "scope"::"key2"[%c52_i64] -> %target[%c102 for %c202] : !stream.resource<transient>{%c300}
+ } => !stream.timepoint
+ // CHECK: return %[[SIGNAL]]
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterGatherNoScope
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
+func.func @parameterGatherNoScope(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c200 = arith.constant 200 : index
+ %c201 = arith.constant 201 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
+ // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
+ // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
+ // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
+ // CHECK-NEXT: "key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer,
+ // CHECK-NEXT: "key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer
+ %timepoint = stream.parameter.gather await(%wait) => {
+ "key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300},
+ "key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300}
+ } => !stream.timepoint
+ // CHECK: return %[[SIGNAL]]
+ return %timepoint : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @parameterScatter
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence
+func.func @parameterScatter(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint {
+ %c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
+ %c52_i64 = arith.constant 52 : i64
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c102 = arith.constant 102 : index
+ %c200 = arith.constant 200 : index
+ %c201 = arith.constant 201 : index
+ %c202 = arith.constant 202 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
+ // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
+ // CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
+ // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
+ // CHECK-NEXT: %[[SOURCE]][%c100 for %c200] : !hal.buffer -> "scope"::"key0"[%c50_i64],
+ // CHECK-NEXT: %[[SOURCE]][%c101 for %c201] : !hal.buffer -> "scope"::"key1"[%c51_i64],
+ // CHECK-NEXT: %[[SOURCE]][%c102 for %c202] : !hal.buffer -> "scope"::"key2"[%c52_i64]
+ // CHECK-NEXT: }
+ %timepoint = stream.parameter.scatter await(%wait) => {
+ %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key0"[%c50_i64],
+ %source[%c101 for %c201] : !stream.resource<transient>{%c300} -> "scope"::"key1"[%c51_i64],
+ %source[%c102 for %c202] : !stream.resource<transient>{%c300} -> "scope"::"key2"[%c52_i64]
+ } => !stream.timepoint
+ // CHECK: return %[[SIGNAL]]
+ return %timepoint : !stream.timepoint
+}
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel
new file mode 100644
index 0000000..6cbaf0b
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel
@@ -0,0 +1,114 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library", "iree_tablegen_doc", "iree_td_library")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["IOParametersOps.td"])
+
+iree_td_library(
+ name = "td_files",
+ srcs = enforce_glob(
+ [
+ "IOParametersBase.td",
+ "IOParametersOps.td",
+ ],
+ include = ["*.td"],
+ ),
+ deps = [
+ "//compiler/src/iree/compiler/Dialect/HAL/IR:td_files",
+ "//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
+ "@llvm-project//mlir:FuncTdFiles",
+ "@llvm-project//mlir:OpBaseTdFiles",
+ ],
+)
+
+iree_compiler_cc_library(
+ name = "IR",
+ srcs = [
+ "IOParametersOps.cpp",
+ ],
+ hdrs = [
+ "IOParametersOps.h",
+ "IOParametersOps.h.inc",
+ ],
+ textual_hdrs = [
+ "IOParametersOps.cpp.inc",
+ ],
+ deps = [
+ ":IOParametersOpsGen",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//compiler/src/iree/compiler/Dialect/VM/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:SideEffectInterfaces",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:TranslateLib",
+ ],
+)
+
+iree_compiler_cc_library(
+ name = "IOParametersDialect",
+ srcs = ["IOParametersDialect.cpp"],
+ hdrs = ["IOParametersDialect.h"],
+ deps = [
+ ":IR",
+ "//compiler/src/iree/compiler/Dialect/HAL/Conversion",
+ "//compiler/src/iree/compiler/Dialect/VM/Conversion",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters:io_parameters_imports",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+)
+
+iree_gentbl_cc_library(
+ name = "IOParametersOpsGen",
+ tbl_outs = [
+ (
+ ["--gen-op-decls"],
+ "IOParametersOps.h.inc",
+ ),
+ (
+ ["--gen-op-defs"],
+ "IOParametersOps.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "IOParametersOps.td",
+ deps = [":td_files"],
+)
+
+iree_tablegen_doc(
+ name = "IOParametersDialectDocGen",
+ tbl_outs = [
+ (
+ [
+ "--dialect=io_parameters",
+ "--gen-dialect-doc",
+ ],
+ "IOParametersDialect.md",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "IOParametersOps.td",
+ deps = [":td_files"],
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/CMakeLists.txt
new file mode 100644
index 0000000..e4f000f
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/CMakeLists.txt
@@ -0,0 +1,81 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ IR
+ HDRS
+ "IOParametersOps.h"
+ "IOParametersOps.h.inc"
+ TEXTUAL_HDRS
+ "IOParametersOps.cpp.inc"
+ SRCS
+ "IOParametersOps.cpp"
+ DEPS
+ ::IOParametersOpsGen
+ LLVMSupport
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRIR
+ MLIRSideEffectInterfaces
+ MLIRSupport
+ MLIRTransformUtils
+ MLIRTranslateLib
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::Util::IR
+ iree::compiler::Dialect::VM::IR
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ IOParametersDialect
+ HDRS
+ "IOParametersDialect.h"
+ SRCS
+ "IOParametersDialect.cpp"
+ DEPS
+ ::IR
+ LLVMSupport
+ MLIRFuncDialect
+ MLIRIR
+ MLIRParser
+ MLIRSupport
+ MLIRTransformUtils
+ iree::compiler::Dialect::HAL::Conversion
+ iree::compiler::Dialect::VM::Conversion
+ iree::compiler::Modules::IO::Parameters::Conversion::ParamsToVM
+ iree::compiler::Modules::IO::Parameters::Conversion::StreamToParams
+ iree::compiler::Modules::IO::Parameters::io_parameters_imports
+ PUBLIC
+)
+
+iree_tablegen_library(
+ NAME
+ IOParametersOpsGen
+ TD_FILE
+ "IOParametersOps.td"
+ OUTS
+ --gen-op-decls IOParametersOps.h.inc
+ --gen-op-defs IOParametersOps.cpp.inc
+)
+
+iree_tablegen_doc(
+ NAME
+ IOParametersDialectDocGen
+ TD_FILE
+ "IOParametersOps.td"
+ OUTS
+ --dialect=io_parameters --gen-dialect-doc IOParametersDialect.md
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersBase.td b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersBase.td
new file mode 100644
index 0000000..6a3d0f6
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersBase.td
@@ -0,0 +1,71 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_MODULES_IO_PARAMETERS_BASE
+#define IREE_DIALECT_MODULES_IO_PARAMETERS_BASE
+
+include "iree/compiler/Dialect/Util/IR/UtilBase.td"
+
+//===----------------------------------------------------------------------===//
+// External parameter resource management
+//===----------------------------------------------------------------------===//
+
+def IOParameters_Dialect : Dialect {
+ let name = "io_parameters";
+ let cppNamespace = "::mlir::iree_compiler::IREE::IO::Parameters";
+
+ let summary = [{
+ External parameter resource management APIs.
+ }];
+ let description = [{
+ Parameters are externalized storage for resources that are
+ asynchronously accessible and device-aware. Parameters can be read or
+ written on the same device timelines as the operations that consume or
+ produce them and with locality pinning to ensure memory doesn't need to
+ move. Parameters are referenced by a scope and a key, with the scope
+ being optional but strongly recommended as a way to distinguish sets of
+ parameters that may exist when multiple model parts are compiled together
+ and would otherwise collide.
+
+ Parameters are provided by a few operations implementing a virtual
+ interface and can support shared parameters (same storage used
+ in multiple contexts, or outliving a single instantiation in a context),
+ in-memory caches, memory-mapped files (including directly using the
+ mapped memory for execution when devices support it), `iree_hal_file_t`
+ usage for device-supported I/O, and parameter subsetting for things
+ like runtime sharding.
+
+ Alongside read(+load) and write operations, gather and scatter allow for
+ batching of large numbers of reads and writes into/from single buffers.
+ For parameter providers that can batch operations this allows for a
+ handful (~1-4) of calls out to perform many more operations (~thousands).
+ Modeling the gather/scatter also gives us a point where we could extract
+ the mapping and use it to repack files/defrag memory in the future.
+
+ See `io_parameters.imports.mlir` for the full list of exported functions.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// IOParameters enums
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// IOParameters types
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// IOParameters op traits
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Base IOParameters op classes
+//===----------------------------------------------------------------------===//
+
+class IOParameters_Op<string mnemonic, list<Trait> traits = []> :
+ Op<IOParameters_Dialect, mnemonic, !listconcat(traits, [])> {}
+
+#endif // IREE_DIALECT_MODULES_IO_PARAMETERS_BASE
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.cpp
new file mode 100644
index 0000000..4533eb8
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.cpp
@@ -0,0 +1,99 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h"
+
+#include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
+#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
+#include "iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.h"
+#include "iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.h"
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.h"
+#include "iree/compiler/Modules/IO/Parameters/io_parameters.imports.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+namespace {
+
+// Used to control inlining behavior.
+struct IOParametersInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ bool isLegalToInline(Operation *call, Operation *callable,
+ bool wouldBeCloned) const final {
+ // Sure!
+ return true;
+ }
+ bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+ IRMapping &valueMapping) const final {
+ // Sure!
+ return true;
+ }
+
+ bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
+ IRMapping &valueMapping) const final {
+ // Sure!
+ return true;
+ }
+};
+
+class StreamToIOParametersConversionInterface
+ : public HALConversionDialectInterface {
+public:
+ using HALConversionDialectInterface::HALConversionDialectInterface;
+ void setupConversionTarget(ConversionTarget &target,
+ RewritePatternSet &patterns,
+ TypeConverter &typeConverter) const override {
+ populateStreamToIOParametersPatterns(getDialect()->getContext(), target,
+ typeConverter, patterns);
+ }
+};
+
+class IOParametersToVMConversionInterface
+ : public VMConversionDialectInterface {
+public:
+ using VMConversionDialectInterface::VMConversionDialectInterface;
+
+ OwningOpRef<mlir::ModuleOp> parseVMImportModule() const override {
+ return mlir::parseSourceString<mlir::ModuleOp>(
+ StringRef(iree_io_parameters_imports_create()->data,
+ iree_io_parameters_imports_create()->size),
+ getDialect()->getContext());
+ }
+
+ void
+ populateVMConversionPatterns(SymbolTable &importSymbols,
+ RewritePatternSet &patterns,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter) const override {
+ conversionTarget
+ .addIllegalDialect<IREE::IO::Parameters::IOParametersDialect>();
+ populateIOParametersToVMPatterns(getDialect()->getContext(),
+ conversionTarget, typeConverter,
+ importSymbols, patterns);
+ }
+};
+
+} // namespace
+
+IOParametersDialect::IOParametersDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context,
+ TypeID::get<IOParametersDialect>()) {
+ addInterfaces<IOParametersInlinerInterface>();
+ addInterfaces<StreamToIOParametersConversionInterface>();
+ addInterfaces<IOParametersToVMConversionInterface>();
+
+#define GET_OP_LIST
+ addOperations<
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp.inc"
+ >();
+}
+
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h
new file mode 100644
index 0000000..40f2d42
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h
@@ -0,0 +1,23 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_MODULES_IO_PARAMETERS_IR_IOPARAMETERSDIALECT_H_
+#define IREE_COMPILER_MODULES_IO_PARAMETERS_IR_IOPARAMETERSDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+class IOParametersDialect : public Dialect {
+public:
+ explicit IOParametersDialect(MLIRContext *context);
+ static StringRef getDialectNamespace() { return "io_parameters"; }
+};
+
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
+
+#endif // IREE_COMPILER_MODULES_IO_PARAMETERS_IR_IOPARAMETERSDIALECT_H_
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp
new file mode 100644
index 0000000..70eefc4
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp
@@ -0,0 +1,391 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.h"
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/SMLoc.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir::iree_compiler::IREE::IO::Parameters {
+
+//===----------------------------------------------------------------------===//
+// custom<ParameterReference>($scope, $key)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseParameterReference(OpAsmParser &parser,
+ StringAttr &scopeAttr,
+ StringAttr &keyAttr) {
+ auto builder = parser.getBuilder();
+ StringAttr firstAttr;
+ if (failed(parser.parseCustomAttributeWithFallback(firstAttr,
+ builder.getNoneType()))) {
+ return failure();
+ }
+ if (failed(parser.parseOptionalColon())) {
+ keyAttr = firstAttr;
+ return success();
+ }
+ scopeAttr = firstAttr;
+ if (failed(parser.parseColon()) ||
+ failed(parser.parseCustomAttributeWithFallback(keyAttr,
+ builder.getNoneType()))) {
+ return failure();
+ }
+ return success();
+}
+
+static void printParameterReference(OpAsmPrinter &p, Operation *op,
+ StringAttr scopeAttr, StringAttr keyAttr) {
+ if (scopeAttr) {
+ p << "\"" << scopeAttr.getValue() << "\"";
+ p << "::";
+ }
+ p << "\"" << keyAttr.getValue() << "\"";
+}
+
+//===----------------------------------------------------------------------===//
+// io_parameters.load
+//===----------------------------------------------------------------------===//
+
+void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ // TODO(benvanik): fold hal.buffer.subspan on the result into parameters.
+}
+
+//===----------------------------------------------------------------------===//
+// io_parameters.read
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct FoldReadOpTargetBufferSubspan : public OpRewritePattern<ReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ReadOp op,
+ PatternRewriter &rewriter) const override {
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ bool needsUpdate = false;
+ auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset());
+ auto newTargetBuffer = op.getTargetBuffer();
+ auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset());
+ if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>(
+ newTargetBuffer.getDefiningOp())) {
+ newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subspanOp.getLoc(), subspanOp.getSourceOffset(), newSourceOffset);
+ newTargetBuffer = subspanOp.getSourceBuffer();
+ newTargetOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subspanOp.getLoc(), subspanOp.getSourceOffset(), newTargetOffset);
+ needsUpdate = true;
+ }
+ rewriter.restoreInsertionPoint(ip);
+ if (!needsUpdate)
+ return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceOffsetMutable().assign(newSourceOffset);
+ op.getTargetBufferMutable().assign(newTargetBuffer);
+ op.getTargetOffsetMutable().assign(newTargetOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void ReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ // Folds hal.buffer.subspan ops on the target buffer into the op offsets.
+ results.insert<FoldReadOpTargetBufferSubspan>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// io_parameters.write
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct FoldWriteOpSourceBufferSubspan : public OpRewritePattern<WriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(WriteOp op,
+ PatternRewriter &rewriter) const override {
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ bool needsUpdate = false;
+ auto newSourceBuffer = op.getSourceBuffer();
+ auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset());
+ auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset());
+ if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>(
+ newSourceBuffer.getDefiningOp())) {
+ newSourceBuffer = subspanOp.getSourceBuffer();
+ newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subspanOp.getLoc(), subspanOp.getSourceOffset(), newSourceOffset);
+ newTargetOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subspanOp.getLoc(),
+ rewriter.createOrFold<mlir::arith::IndexCastOp>(
+ subspanOp.getLoc(), rewriter.getI64Type(),
+ subspanOp.getSourceOffset()),
+ newTargetOffset);
+ needsUpdate = true;
+ }
+ rewriter.restoreInsertionPoint(ip);
+ if (!needsUpdate)
+ return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceBufferMutable().assign(newSourceBuffer);
+ op.getSourceOffsetMutable().assign(newSourceOffset);
+ op.getTargetOffsetMutable().assign(newTargetOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void WriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldWriteOpSourceBufferSubspan>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ParameterGatherOperations>(
+// $source_scope, $source_keys, $source_offsets,
+// $target_buffer, type($target_buffer), $target_offsets, $target_lengths)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseParameterGatherOperations(
+ OpAsmParser &parser, StringAttr &sourceScopeAttr, ArrayAttr &sourceKeysAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &sourceOffsets,
+ OpAsmParser::UnresolvedOperand &targetBuffer, Type &targetType,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &targetOffsets,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &targetLengths) {
+ auto builder = parser.getBuilder();
+ SmallVector<Attribute> sourceKeyAttrs;
+ do {
+ StringAttr rowSourceScopeAttr;
+ StringAttr sourceKeyAttr;
+ OpAsmParser::UnresolvedOperand sourceOffset;
+ OpAsmParser::UnresolvedOperand targetOffset;
+ OpAsmParser::UnresolvedOperand targetLength;
+ OpAsmParser::UnresolvedOperand rowTargetBuffer;
+ Type rowTargetType;
+ if (failed(parseParameterReference(parser, rowSourceScopeAttr,
+ sourceKeyAttr)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(sourceOffset)) ||
+ failed(parser.parseRSquare()) || failed(parser.parseArrow()) ||
+ failed(parser.parseOperand(rowTargetBuffer)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(targetOffset)) ||
+ failed(parser.parseKeyword("for")) ||
+ failed(parser.parseOperand(targetLength)) ||
+ failed(parser.parseRSquare()) ||
+ failed(parser.parseColonType(rowTargetType))) {
+ return failure();
+ }
+ if (!targetType) {
+ sourceScopeAttr = rowSourceScopeAttr;
+ targetBuffer = rowTargetBuffer;
+ targetType = rowTargetType;
+ } else if (rowSourceScopeAttr != sourceScopeAttr ||
+ rowTargetBuffer.name != targetBuffer.name ||
+ rowTargetType != targetType) {
+ return parser.emitError(
+ parser.getCurrentLocation(),
+ "each operation must use the same scope and target resource");
+ }
+ sourceKeyAttrs.push_back(sourceKeyAttr);
+ sourceOffsets.push_back(sourceOffset);
+ targetOffsets.push_back(targetOffset);
+ targetLengths.push_back(targetLength);
+ } while (succeeded(parser.parseOptionalComma()));
+ sourceKeysAttr = builder.getArrayAttr(sourceKeyAttrs);
+ return success();
+}
+
+static void printParameterGatherOperations(
+ OpAsmPrinter &p, Operation *op, StringAttr sourceScopeAttr,
+ ArrayAttr sourceKeysAttr, ValueRange sourceOffsets, Value targetBuffer,
+ Type targetType, ValueRange targetOffsets, ValueRange targetLengths) {
+ p.increaseIndent();
+ p.printNewline();
+ llvm::interleave(
+ llvm::zip_equal(sourceKeysAttr.getAsRange<StringAttr>(), sourceOffsets,
+ targetOffsets, targetLengths),
+ [&](std::tuple<StringAttr, Value, Value, Value> it) {
+ auto [sourceKeyAttr, sourceOffset, targetOffset, targetLength] = it;
+ printParameterReference(p, op, sourceScopeAttr, sourceKeyAttr);
+ p << "[";
+ p.printOperand(sourceOffset);
+ p << "] -> ";
+ p.printOperand(targetBuffer);
+ p << "[";
+ p.printOperand(targetOffset);
+ p << " for ";
+ p.printOperand(targetLength);
+ p << "] : ";
+ p.printType(targetType);
+ },
+ [&]() {
+ p << ',';
+ p.printNewline();
+ });
+ p.decreaseIndent();
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
+// io_parameters.gather
+//===----------------------------------------------------------------------===//
+
+LogicalResult GatherOp::verify() {
+ GatherOp op = *this;
+ size_t expectedCount = op.getSourceKeys().size();
+ if (op.getSourceOffsets().size() != expectedCount ||
+ op.getTargetOffsets().size() != expectedCount ||
+ op.getTargetLengths().size() != expectedCount) {
+ return op.emitOpError() << "requires that the source keys, source offsets, "
+ "target offsets, and target lengths are all 1:1";
+ }
+ return success();
+}
+
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ // TODO(benvanik): find a good way of folding in subspans; tricky because if
+ // buffers differ across entries then we can't reassign.
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ParameterScatterOperations>(
+// $source_buffer, type($source_buffer), $source_offsets, $source_lengths,
+// $target_scope, $target_keys, $target_offsets)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseParameterScatterOperations(
+ OpAsmParser &parser, OpAsmParser::UnresolvedOperand &sourceBuffer,
+ Type &sourceType,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &sourceOffsets,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &sourceLengths,
+ StringAttr &targetScopeAttr, ArrayAttr &targetKeysAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &targetOffsets) {
+ auto builder = parser.getBuilder();
+ SmallVector<Attribute> targetKeyAttrs;
+ do {
+ OpAsmParser::UnresolvedOperand sourceOffset;
+ OpAsmParser::UnresolvedOperand sourceLength;
+ OpAsmParser::UnresolvedOperand rowSourceBuffer;
+ Type rowSourceType;
+ StringAttr rowTargetScopeAttr;
+ StringAttr targetKeyAttr;
+ OpAsmParser::UnresolvedOperand targetOffset;
+ if (failed(parser.parseOperand(rowSourceBuffer)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(sourceOffset)) ||
+ failed(parser.parseKeyword("for")) ||
+ failed(parser.parseOperand(sourceLength)) ||
+ failed(parser.parseRSquare()) ||
+ failed(parser.parseColonType(rowSourceType)) ||
+ failed(parser.parseArrow()) ||
+ failed(parseParameterReference(parser, rowTargetScopeAttr,
+ targetKeyAttr)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(targetOffset)) ||
+ failed(parser.parseRSquare())) {
+ return failure();
+ }
+ if (!sourceType) {
+ sourceBuffer = rowSourceBuffer;
+ sourceType = rowSourceType;
+ targetScopeAttr = rowTargetScopeAttr;
+ } else if (rowSourceBuffer.name != sourceBuffer.name ||
+ rowSourceType != sourceType ||
+ rowTargetScopeAttr != targetScopeAttr) {
+ return parser.emitError(
+ parser.getCurrentLocation(),
+ "each operation must use the same source resource and scope");
+ }
+ sourceOffsets.push_back(sourceOffset);
+ sourceLengths.push_back(sourceLength);
+ targetKeyAttrs.push_back(targetKeyAttr);
+ targetOffsets.push_back(targetOffset);
+ } while (succeeded(parser.parseOptionalComma()));
+ targetKeysAttr = builder.getArrayAttr(targetKeyAttrs);
+ return success();
+}
+
+static void printParameterScatterOperations(OpAsmPrinter &p, Operation *op,
+ Value sourceBuffer, Type sourceType,
+ ValueRange sourceOffsets,
+ ValueRange sourceLengths,
+ StringAttr targetScopeAttr,
+ ArrayAttr targetKeysAttr,
+ ValueRange targetOffsets) {
+ p.increaseIndent();
+ p.printNewline();
+ llvm::interleave(
+ llvm::zip_equal(sourceOffsets, sourceLengths,
+ targetKeysAttr.getAsRange<StringAttr>(), targetOffsets),
+ [&](std::tuple<Value, Value, StringAttr, Value> it) {
+ auto [sourceOffset, sourceLength, targetKeyAttr, targetOffset] = it;
+ p.printOperand(sourceBuffer);
+ p << "[";
+ p.printOperand(sourceOffset);
+ p << " for ";
+ p.printOperand(sourceLength);
+ p << "] : ";
+ p.printType(sourceType);
+ p << " -> ";
+ printParameterReference(p, op, targetScopeAttr, targetKeyAttr);
+ p << "[";
+ p.printOperand(targetOffset);
+ p << "]";
+ },
+ [&]() {
+ p << ',';
+ p.printNewline();
+ });
+ p.decreaseIndent();
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
+// io_parameters.scatter
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScatterOp::verify() {
+ ScatterOp op = *this;
+ size_t expectedCount = op.getTargetKeys().size();
+ if (op.getSourceOffsets().size() != expectedCount ||
+ op.getSourceLengths().size() != expectedCount ||
+ op.getTargetOffsets().size() != expectedCount) {
+ return op.emitOpError() << "requires that the source offsets, source "
+ "lengths, target keys, and target offsets are all 1:1";
+ }
+ return success();
+}
+
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ // TODO(benvanik): find a good way of folding in subspans; tricky because if
+ // buffers differ across entries then we can't reassign.
+}
+
+} // namespace mlir::iree_compiler::IREE::IO::Parameters
+
+//===----------------------------------------------------------------------===//
+// TableGen definitions (intentionally last)
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.h b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.h
new file mode 100644
index 0000000..118abfd
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.h
@@ -0,0 +1,26 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_MODULES_IO_PARAMETERS_IR_IOPARAMETERSOPS_H_
+#define IREE_COMPILER_MODULES_IO_PARAMETERS_IR_IOPARAMETERSOPS_H_
+
+#include <cstdint>
+
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTraits.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.h.inc" // IWYU pragma: keep
+
+#endif // IREE_COMPILER_MODULES_IO_PARAMETERS_IR_IOPARAMETERSOPS_H_
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td
new file mode 100644
index 0000000..e4ce51b
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td
@@ -0,0 +1,237 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_MODULES_IO_PARAMETERS_OPS
+#define IREE_DIALECT_MODULES_IO_PARAMETERS_OPS
+
+include "iree/compiler/Dialect/HAL/IR/HALBase.td"
+include "iree/compiler/Dialect/Util/IR/UtilAttrs.td"
+include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
+include "iree/compiler/Modules/IO/Parameters/IR/IOParametersBase.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class IOParameters_PureOp<string mnemonic, list<Trait> traits = []> :
+ IOParameters_Op<mnemonic, !listconcat(traits, [Pure])>;
+
+//===----------------------------------------------------------------------===//
+// Parameter I/O
+//===----------------------------------------------------------------------===//
+
+def OpGroupParameterOps : OpDocGroup {
+ let summary = "Parameter I/O ops";
+ let description = "Ops for parameter I/O.";
+}
+
+let opDocGroup = OpGroupParameterOps in {
+
+def IOParameters_LoadOp : IOParameters_Op<"load", [
+ Util_SizeAwareOp,
+]> {
+ let summary = [{reads a parameter from a parameter scope}];
+ let description = [{
+ Asynchronously reads a parameter from an external parameter provider and
+ returns the resulting buffer. Depending on the parameter and buffer types
+ this may alias existing cached storage, be directly mapped to the parameter
+ origin, or result in a copy as if an allocate + read had been used.
+ }];
+
+ let arguments = (ins
+ HAL_Device:$device,
+ HAL_DeviceQueueAffinity:$queue_affinity,
+ HAL_Fence:$wait_fence,
+ HAL_Fence:$signal_fence,
+ OptionalAttr<StrAttr>:$source_scope,
+ StrAttr:$source_key,
+ I64:$source_offset,
+ HAL_MemoryTypeBitfieldAttr:$memory_types,
+ HAL_BufferUsageBitfieldAttr:$buffer_usage,
+ HAL_DeviceSize:$length
+ );
+ let results = (outs
+ HAL_Buffer:$result
+ );
+
+ let assemblyFormat = [{
+ `<` $device `:` type($device) `>`
+ `affinity` `(` $queue_affinity `)`
+ `wait` `(` $wait_fence `)`
+ `signal` `(` $signal_fence `)`
+ `source` `(` custom<ParameterReference>($source_scope, $source_key) `)`
+ `` `[` $source_offset `]`
+ `type` `(` $memory_types `)`
+ `usage` `(` $buffer_usage `)`
+ `:` custom<SizeAwareType>(type($result), $length)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return getLength(); }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def IOParameters_ReadOp : IOParameters_Op<"read", []> {
+ let summary = [{reads a parameter from a parameter scope}];
+ let description = [{
+ Asynchronously reads a parameter from an external parameter provider into
+ the provided target buffer range.
+ }];
+
+ let arguments = (ins
+ HAL_Device:$device,
+ HAL_DeviceQueueAffinity:$queue_affinity,
+ HAL_Fence:$wait_fence,
+ HAL_Fence:$signal_fence,
+ OptionalAttr<StrAttr>:$source_scope,
+ StrAttr:$source_key,
+ I64:$source_offset,
+ HAL_Buffer:$target_buffer,
+ HAL_DeviceSize:$target_offset,
+ HAL_DeviceSize:$length
+ );
+
+ let assemblyFormat = [{
+ `<` $device `:` type($device) `>`
+ `affinity` `(` $queue_affinity `)`
+ `wait` `(` $wait_fence `)`
+ `signal` `(` $signal_fence `)`
+ `source` `(` custom<ParameterReference>($source_scope, $source_key) `)`
+ `` `[` $source_offset `]`
+ `target` `(` $target_buffer `:` type($target_buffer) `)`
+ `` `[` $target_offset `]`
+ `length` `(` $length `)`
+ attr-dict-with-keyword
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def IOParameters_WriteOp : IOParameters_Op<"write", []> {
+ let summary = [{writes a parameter into a parameter scope}];
+ let description = [{
+ Asynchronously writes a parameter to an external parameter provider from
+ the provided source buffer range.
+ }];
+
+ let arguments = (ins
+ HAL_Device:$device,
+ HAL_DeviceQueueAffinity:$queue_affinity,
+ HAL_Fence:$wait_fence,
+ HAL_Fence:$signal_fence,
+ HAL_Buffer:$source_buffer,
+ HAL_DeviceSize:$source_offset,
+ OptionalAttr<StrAttr>:$target_scope,
+ StrAttr:$target_key,
+ I64:$target_offset,
+ HAL_DeviceSize:$length
+ );
+
+ let assemblyFormat = [{
+ `<` $device `:` type($device) `>`
+ `affinity` `(` $queue_affinity `)`
+ `wait` `(` $wait_fence `)`
+ `signal` `(` $signal_fence `)`
+ `source` `(` $source_buffer `:` type($source_buffer) `)`
+ `` `[` $source_offset `]`
+ `target` `(` custom<ParameterReference>($target_scope, $target_key) `)`
+ `` `[` $target_offset `]`
+ `length` `(` $length `)`
+ attr-dict-with-keyword
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def IOParameters_GatherOp : IOParameters_Op<"gather", [
+ AttrSizedOperandSegments,
+]> {
+ let summary = [{gathers multiple parameters from a parameter scope}];
+ let description = [{
+ Asynchronously gathers one or more parameters into a single target buffer.
+ This is equivalent to one read per parameter but allows implementations that
+ can batch operations to do so without additional overhead.
+ }];
+
+ let arguments = (ins
+ HAL_Device:$device,
+ HAL_DeviceQueueAffinity:$queue_affinity,
+ HAL_Fence:$wait_fence,
+ HAL_Fence:$signal_fence,
+ OptionalAttr<StrAttr>:$source_scope,
+ StrArrayAttr:$source_keys,
+ Variadic<I64>:$source_offsets,
+ HAL_Buffer:$target_buffer,
+ Variadic<HAL_DeviceSize>:$target_offsets,
+ Variadic<HAL_DeviceSize>:$target_lengths
+ );
+
+ let assemblyFormat = [{
+ `<` $device `:` type($device) `>`
+ `affinity` `(` $queue_affinity `)`
+ `wait` `(` $wait_fence `)`
+ `signal` `(` $signal_fence `)`
+ `{`
+ custom<ParameterGatherOperations>(
+ $source_scope, $source_keys, $source_offsets,
+ $target_buffer, type($target_buffer), $target_offsets, $target_lengths)
+ `}`
+ attr-dict-with-keyword
+ }];
+
+ let hasVerifier = 1;
+
+ let hasCanonicalizer = 1;
+}
+
+def IOParameters_ScatterOp : IOParameters_Op<"scatter", [
+ AttrSizedOperandSegments,
+]> {
+ let summary = [{scatters multiple parameters to a parameter scope}];
+ let description = [{
+ Asynchronously scatters one or more ranges of a single source buffer into
+ one or more parameters. This is equivalent to one write per parameter but
+ allows implementations that can batch operations to do so without
+ additional overhead.
+ }];
+
+ let arguments = (ins
+ HAL_Device:$device,
+ HAL_DeviceQueueAffinity:$queue_affinity,
+ HAL_Fence:$wait_fence,
+ HAL_Fence:$signal_fence,
+ HAL_Buffer:$source_buffer,
+ Variadic<HAL_DeviceSize>:$source_offsets,
+ Variadic<HAL_DeviceSize>:$source_lengths,
+ OptionalAttr<StrAttr>:$target_scope,
+ StrArrayAttr:$target_keys,
+ Variadic<I64>:$target_offsets
+ );
+
+ let assemblyFormat = [{
+ `<` $device `:` type($device) `>`
+ `affinity` `(` $queue_affinity `)`
+ `wait` `(` $wait_fence `)`
+ `signal` `(` $signal_fence `)`
+ `{`
+ custom<ParameterScatterOperations>(
+ $source_buffer, type($source_buffer), $source_offsets, $source_lengths,
+ $target_scope, $target_keys, $target_offsets)
+ `}`
+ attr-dict-with-keyword
+ }];
+
+ let hasVerifier = 1;
+
+ let hasCanonicalizer = 1;
+}
+
+} // OpGroupParameterOps
+
+#endif // IREE_DIALECT_MODULES_IO_PARAMETERS_OPS
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel
new file mode 100644
index 0000000..06483fc
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel
@@ -0,0 +1,28 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "parameter_ops.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/CMakeLists.txt
new file mode 100644
index 0000000..29c1e20
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "parameter_ops.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/parameter_ops.mlir
new file mode 100644
index 0000000..1ab028d
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/parameter_ops.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-opt --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @parameterLoad
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[SIGNAL:.+]]: !hal.fence)
+func.func @parameterLoad(%device: !hal.device, %wait: !hal.fence, %signal: !hal.fence) {
+ // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
+ %affinity = arith.constant -1 : i64
+ // CHECK-DAG: %[[OFFSET:.+]] = arith.constant 0
+ %offset = arith.constant 0 : i64
+ // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 128
+ %length = arith.constant 128 : index
+ // CHECK: = io_parameters.load<%[[DEVICE]] : !hal.device>
+ // CHECK-SAME: affinity(%[[AFFINITY]])
+ // CHECK-SAME: wait(%[[WAIT]])
+ // CHECK-SAME: signal(%[[SIGNAL]])
+ // CHECK-SAME: source("scope"::"w0")[%[[OFFSET]]]
+ // CHECK-SAME: type("DeviceVisible|DeviceLocal")
+ // CHECK-SAME: usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
+ // CHECK-SAME: : !hal.buffer{%[[LENGTH]]}
+ %0 = io_parameters.load<%device : !hal.device>
+ affinity(%affinity)
+ wait(%wait)
+ signal(%signal)
+ source("scope"::"w0")[%offset]
+ type("DeviceVisible|DeviceLocal")
+ usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
+ : !hal.buffer{%length}
+ return
+}
+
+// -----
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/io_parameters.imports.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/io_parameters.imports.mlir
new file mode 100644
index 0000000..13c57b3
--- /dev/null
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/io_parameters.imports.mlir
@@ -0,0 +1,67 @@
+vm.module @io_parameters {
+
+vm.import private @load(
+ %device : !vm.ref<!hal.device>,
+ %queue_affinity : i64,
+ %wait_fence : !vm.ref<!hal.fence>,
+ %signal_fence : !vm.ref<!hal.fence>,
+ %source_scope : !vm.buffer,
+ %source_key : !vm.buffer,
+ %source_offset : i64,
+ %target_queue_affinity : i64,
+ %target_memory_type : i32,
+ %target_buffer_usage : i32,
+ %length : i64
+) -> !vm.ref<!hal.buffer>
+
+vm.import private @read(
+ %device : !vm.ref<!hal.device>,
+ %queue_affinity : i64,
+ %wait_fence : !vm.ref<!hal.fence>,
+ %signal_fence : !vm.ref<!hal.fence>,
+ %source_scope : !vm.buffer,
+ %source_key : !vm.buffer,
+ %source_offset : i64,
+ %target_buffer : !vm.ref<!hal.buffer>,
+ %target_offset : i64,
+ %length : i64
+)
+
+vm.import private @write(
+ %device : !vm.ref<!hal.device>,
+ %queue_affinity : i64,
+ %wait_fence : !vm.ref<!hal.fence>,
+ %signal_fence : !vm.ref<!hal.fence>,
+ %target_scope : !vm.buffer,
+ %target_key : !vm.buffer,
+ %target_offset : i64,
+ %source_buffer : !vm.ref<!hal.buffer>,
+ %source_offset : i64,
+ %length : i64
+)
+
+vm.import private @gather(
+ %device : !vm.ref<!hal.device>,
+ %queue_affinity : i64,
+ %wait_fence : !vm.ref<!hal.fence>,
+ %signal_fence : !vm.ref<!hal.fence>,
+ %source_scope : !vm.buffer,
+ %target_buffer : !vm.ref<!hal.buffer>,
+ %key_table : !vm.buffer,
+ %key_data : !vm.buffer,
+ %spans : !vm.buffer
+)
+
+vm.import private @scatter(
+ %device : !vm.ref<!hal.device>,
+ %queue_affinity : i64,
+ %wait_fence : !vm.ref<!hal.fence>,
+ %signal_fence : !vm.ref<!hal.fence>,
+ %source_buffer : !vm.ref<!hal.buffer>,
+ %target_scope : !vm.buffer,
+ %key_table : !vm.buffer,
+ %key_data : !vm.buffer,
+ %spans : !vm.buffer
+)
+
+} // module
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 5394fc5..5d9ec18 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -79,6 +79,7 @@
"//compiler/src/iree/compiler/Modules/HAL/Inline/Transforms",
"//compiler/src/iree/compiler/Modules/HAL/Loader/IR:HALLoaderDialect",
"//compiler/src/iree/compiler/Modules/HAL/Loader/Transforms",
+ "//compiler/src/iree/compiler/Modules/IO/Parameters/IR:IOParametersDialect",
"//compiler/src/iree/compiler/Pipelines",
"//compiler/src/iree/compiler/Preprocessing:Passes",
"//llvm-external-projects/iree-dialects:IREEInputDialect",
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt
index 69aac54..6553405 100644
--- a/compiler/src/iree/compiler/Tools/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -107,6 +107,7 @@
iree::compiler::Modules::HAL::Inline::Transforms
iree::compiler::Modules::HAL::Loader::IR::HALLoaderDialect
iree::compiler::Modules::HAL::Loader::Transforms
+ iree::compiler::Modules::IO::Parameters::IR::IOParametersDialect
iree::compiler::Pipelines
iree::compiler::Preprocessing::Passes
PUBLIC
diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
index 8cf0cda..7d19e14 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
@@ -28,6 +28,7 @@
#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "iree/compiler/Modules/HAL/Inline/IR/HALInlineDialect.h"
#include "iree/compiler/Modules/HAL/Loader/IR/HALLoaderDialect.h"
+#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h"
#include "mlir/IR/Dialect.h"
namespace mlir {
@@ -41,13 +42,14 @@
IREE::HAL::HALDialect,
IREE::HAL::Inline::HALInlineDialect,
IREE::HAL::Loader::HALLoaderDialect,
+ IREE::IO::Parameters::IOParametersDialect,
+ IREE::Input::IREEInputDialect,
IREE::LinalgExt::IREELinalgExtDialect,
IREE::Stream::StreamDialect,
IREE::Util::UtilDialect,
IREE::VM::VMDialect,
IREE::VMVX::VMVXDialect,
- IREE::Vulkan::VulkanDialect,
- IREE::Input::IREEInputDialect>();
+ IREE::Vulkan::VulkanDialect>();
// clang-format on
// External models.
diff --git a/compiler/src/iree/compiler/Utils/IndexSet.h b/compiler/src/iree/compiler/Utils/IndexSet.h
index ca9b894..22f8ca7 100644
--- a/compiler/src/iree/compiler/Utils/IndexSet.h
+++ b/compiler/src/iree/compiler/Utils/IndexSet.h
@@ -30,6 +30,7 @@
memoizedIndices[value] = memoizedValue;
return memoizedValue;
}
+ Value get(APInt value) { return get(value.getSExtValue()); }
void populate(ValueRange values) {
for (auto value : values) {
diff --git a/docs/website/mkdocs.yml b/docs/website/mkdocs.yml
index 7a746a5..35d05dd 100644
--- a/docs/website/mkdocs.yml
+++ b/docs/website/mkdocs.yml
@@ -152,8 +152,9 @@
- Check: "reference/mlir-dialects/Check.md"
- Flow: "reference/mlir-dialects/Flow.md"
- HAL: "reference/mlir-dialects/HAL.md"
- - HALInline: "reference/mlir-dialects/HALInline.md"
- - HALLoader: "reference/mlir-dialects/HALLoader.md"
+ - HAL/Inline: "reference/mlir-dialects/HALInline.md"
+ - HAL/Loader: "reference/mlir-dialects/HALLoader.md"
+ - IO/Parameters: "reference/mlir-dialects/IOParameters.md"
- Stream: "reference/mlir-dialects/Stream.md"
- Util: "reference/mlir-dialects/Util.md"
- VM: "reference/mlir-dialects/VM.md"
diff --git a/runtime/src/iree/io/BUILD.bazel b/runtime/src/iree/io/BUILD.bazel
index af3b9bf..868e3bf 100644
--- a/runtime/src/iree/io/BUILD.bazel
+++ b/runtime/src/iree/io/BUILD.bazel
@@ -25,3 +25,65 @@
"//runtime/src/iree/base/internal",
],
)
+
+iree_runtime_cc_library(
+ name = "parameter_index",
+ srcs = [
+ "parameter_index.c",
+ ],
+ hdrs = [
+ "parameter_index.h",
+ ],
+ deps = [
+ ":file_handle",
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal",
+ "//runtime/src/iree/base/internal:synchronization",
+ ],
+)
+
+iree_runtime_cc_library(
+ name = "parameter_index_provider",
+ srcs = [
+ "parameter_index_provider.c",
+ ],
+ hdrs = [
+ "parameter_index_provider.h",
+ ],
+ deps = [
+ ":parameter_index",
+ ":parameter_provider",
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/hal",
+ "//runtime/src/iree/hal/utils:file_cache",
+ ],
+)
+
+iree_runtime_cc_library(
+ name = "parameter_provider",
+ srcs = [
+ "parameter_provider.c",
+ ],
+ hdrs = [
+ "parameter_provider.h",
+ ],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/hal",
+ ],
+)
+
+iree_runtime_cc_library(
+ name = "scope_map",
+ srcs = [
+ "scope_map.c",
+ ],
+ hdrs = [
+ "scope_map.h",
+ ],
+ deps = [
+ ":parameter_index",
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal",
+ ],
+)
diff --git a/runtime/src/iree/io/CMakeLists.txt b/runtime/src/iree/io/CMakeLists.txt
index 95d6baf..cd74143 100644
--- a/runtime/src/iree/io/CMakeLists.txt
+++ b/runtime/src/iree/io/CMakeLists.txt
@@ -23,4 +23,62 @@
PUBLIC
)
+iree_cc_library(
+ NAME
+ parameter_index
+ HDRS
+ "parameter_index.h"
+ SRCS
+ "parameter_index.c"
+ DEPS
+ ::file_handle
+ iree::base
+ iree::base::internal
+ iree::base::internal::synchronization
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ parameter_index_provider
+ HDRS
+ "parameter_index_provider.h"
+ SRCS
+ "parameter_index_provider.c"
+ DEPS
+ ::parameter_index
+ ::parameter_provider
+ iree::base
+ iree::hal
+ iree::hal::utils::file_cache
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ parameter_provider
+ HDRS
+ "parameter_provider.h"
+ SRCS
+ "parameter_provider.c"
+ DEPS
+ iree::base
+ iree::hal
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ scope_map
+ HDRS
+ "scope_map.h"
+ SRCS
+ "scope_map.c"
+ DEPS
+ ::parameter_index
+ iree::base
+ iree::base::internal
+ PUBLIC
+)
+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/io/formats/BUILD.bazel b/runtime/src/iree/io/formats/BUILD.bazel
new file mode 100644
index 0000000..522ca5d
--- /dev/null
+++ b/runtime/src/iree/io/formats/BUILD.bazel
@@ -0,0 +1,11 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/runtime/src/iree/io/formats/CMakeLists.txt b/runtime/src/iree/io/formats/CMakeLists.txt
new file mode 100644
index 0000000..df50aab
--- /dev/null
+++ b/runtime/src/iree/io/formats/CMakeLists.txt
@@ -0,0 +1,13 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# runtime/src/iree/io/formats/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/io/formats/gguf/BUILD.bazel b/runtime/src/iree/io/formats/gguf/BUILD.bazel
new file mode 100644
index 0000000..879a34c
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/BUILD.bazel
@@ -0,0 +1,41 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_runtime_cc_library(
+ name = "gguf",
+ srcs = [
+ "gguf_format.c",
+ ],
+ hdrs = [
+ "gguf_format.h",
+ ],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/io:file_handle",
+ "//runtime/src/iree/io:parameter_index",
+ ],
+)
+
+iree_runtime_cc_test(
+ name = "gguf_format_test",
+ srcs = ["gguf_format_test.cc"],
+ tags = ["requires-filesystem"],
+ deps = [
+ ":gguf",
+ "//runtime/src/iree/base/internal:file_io",
+ "//runtime/src/iree/io/formats/gguf/testdata:gguf_files",
+ "//runtime/src/iree/testing:gtest",
+ "//runtime/src/iree/testing:gtest_main",
+ ],
+)
diff --git a/runtime/src/iree/io/formats/gguf/CMakeLists.txt b/runtime/src/iree/io/formats/gguf/CMakeLists.txt
new file mode 100644
index 0000000..c727aac
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/CMakeLists.txt
@@ -0,0 +1,42 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# runtime/src/iree/io/formats/gguf/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ gguf
+ HDRS
+ "gguf_format.h"
+ SRCS
+ "gguf_format.c"
+ DEPS
+ iree::base
+ iree::io::file_handle
+ iree::io::parameter_index
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ gguf_format_test
+ SRCS
+ "gguf_format_test.cc"
+ DEPS
+ ::gguf
+ iree::base::internal::file_io
+ iree::io::formats::gguf::testdata::gguf_files
+ iree::testing::gtest
+ iree::testing::gtest_main
+ LABELS
+ "requires-filesystem"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/io/formats/gguf/gguf_format.c b/runtime/src/iree/io/formats/gguf/gguf_format.c
new file mode 100644
index 0000000..d8f659f
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/gguf_format.c
@@ -0,0 +1,641 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/formats/gguf/gguf_format.h"
+
+#include <ctype.h>
+
+// File format:
+// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
+//
+// References:
+// https://github.com/ggerganov/ggml/blob/master/src/ggml.c
+// https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/gguf.py
+//
+// Unfortunately they don't encode tensor sizes so we need to carry a ton of
+// logic/tables for calculating it. We could avoid this by just saying every
+// tensor is the remaining length in the file but that'd be really confusing.
+//
+// Things that would improve the format:
+// - alignment of all fields in the header so that they can be directly accessed
+// (since strings and such are stored inline and have byte-length alignment
+// a single string early in the header ensures all other fields in all other
+// header data are unaligned, preventing us from directly overlaying structs)
+// - all variable-length header data should be in tables as with most other
+// formats designed for easy parsing/high performance I/O (ELF, etc) - strings
+// and dimensions as well as embedded metadata arrays should be offsets to
+// the data in the file instead of inlined
+// - tensor total size should be stored to allow for quickly partitioning files
+// without relying on the various data type/block structures (some of which
+// are only conditionally available and will change frequently)
+// - header total size (offset to tensor_data) should be stored to allow
+// skipping the header or quickly checking for truncated files (currently need
+// to correctly parse the entire header in order to find the tensor_data)
+// - metadata arrays should have a total length to make it easy to skip them;
+// since they can contain strings it's not possible given an array type and
+// element count to know how long it is without parsing all elements and
+// arrays of strings are even worse - fixed size header fields with externally
+// referenced variable-length data would make this easy
+//
+// The structs below are here for reference; some we use with slight
+// modifications for readability (easy to compare against reference code) but
+// others we just parse field-by-field due to the aforementioned nested
+// variable-sized craziness.
+
+#define GGUF_MAGIC 0x46554747
+#define GGUF_VERSION 3
+#define GGUF_DEFAULT_ALIGNMENT 32
+
+enum ggml_type_e {
+ GGML_TYPE_F32 = 0,
+ GGML_TYPE_F16 = 1,
+ GGML_TYPE_Q4_0 = 2,
+ GGML_TYPE_Q4_1 = 3,
+ GGML_TYPE_Q5_0 = 6,
+ GGML_TYPE_Q5_1 = 7,
+ GGML_TYPE_Q8_0 = 8,
+ GGML_TYPE_Q8_1 = 9,
+ GGML_TYPE_Q2_K = 10,
+ GGML_TYPE_Q3_K = 11,
+ GGML_TYPE_Q4_K = 12,
+ GGML_TYPE_Q5_K = 13,
+ GGML_TYPE_Q6_K = 14,
+ GGML_TYPE_Q8_K = 15,
+ GGML_TYPE_I8 = 16,
+ GGML_TYPE_I16 = 17,
+ GGML_TYPE_I32 = 18,
+ GGML_TYPE_COUNT,
+};
+typedef uint32_t ggml_type_t;
+
+#define QK4_0 32
+typedef struct {
+ uint16_t d;
+ uint8_t qs[QK4_0 / 2];
+} block_q4_0;
+#define QK4_1 32
+typedef struct {
+ uint16_t d;
+ uint16_t m;
+ uint8_t qs[QK4_1 / 2];
+} block_q4_1;
+#define QK5_0 32
+typedef struct {
+ uint16_t d;
+ uint8_t qh[4];
+ uint8_t qs[QK5_0 / 2];
+} block_q5_0;
+#define QK5_1 32
+typedef struct {
+ uint16_t d;
+ uint16_t m;
+ uint8_t qh[4];
+ uint8_t qs[QK5_1 / 2];
+} block_q5_1;
+#define QK8_0 32
+typedef struct {
+ uint16_t d;
+ int8_t qs[QK8_0];
+} block_q8_0;
+#define QK8_1 32
+typedef struct {
+ float d;
+ float s;
+ int8_t qs[QK8_1];
+} block_q8_1;
+
+typedef struct {
+ int blck_size;
+ size_t type_size;
+} ggml_type_traits_t;
+static const ggml_type_traits_t ggml_type_traits[GGML_TYPE_COUNT] = {
+ [GGML_TYPE_I8] =
+ {
+ .blck_size = 1,
+ .type_size = sizeof(int8_t),
+ },
+ [GGML_TYPE_I16] =
+ {
+ .blck_size = 1,
+ .type_size = sizeof(int16_t),
+ },
+ [GGML_TYPE_I32] =
+ {
+ .blck_size = 1,
+ .type_size = sizeof(int32_t),
+ },
+ [GGML_TYPE_F32] =
+ {
+ .blck_size = 1,
+ .type_size = sizeof(float),
+ },
+ [GGML_TYPE_F16] =
+ {
+ .blck_size = 1,
+ .type_size = sizeof(uint16_t),
+ },
+ [GGML_TYPE_Q4_0] =
+ {
+ .blck_size = QK4_0,
+ .type_size = sizeof(block_q4_0),
+ },
+ [GGML_TYPE_Q4_1] =
+ {
+ .blck_size = QK4_1,
+ .type_size = sizeof(block_q4_1),
+ },
+ [GGML_TYPE_Q5_0] =
+ {
+ .blck_size = QK5_0,
+ .type_size = sizeof(block_q5_0),
+ },
+ [GGML_TYPE_Q5_1] =
+ {
+ .blck_size = QK5_1,
+ .type_size = sizeof(block_q5_1),
+ },
+ [GGML_TYPE_Q8_0] =
+ {
+ .blck_size = QK8_0,
+ .type_size = sizeof(block_q8_0),
+ },
+ [GGML_TYPE_Q8_1] =
+ {
+ .blck_size = QK8_1,
+ .type_size = sizeof(block_q8_1),
+ },
+};
+
+enum gguf_metadata_value_type_e {
+ GGUF_METADATA_VALUE_TYPE_UINT8 = 0,
+ GGUF_METADATA_VALUE_TYPE_INT8 = 1,
+ GGUF_METADATA_VALUE_TYPE_UINT16 = 2,
+ GGUF_METADATA_VALUE_TYPE_INT16 = 3,
+ GGUF_METADATA_VALUE_TYPE_UINT32 = 4,
+ GGUF_METADATA_VALUE_TYPE_INT32 = 5,
+ GGUF_METADATA_VALUE_TYPE_FLOAT32 = 6,
+ GGUF_METADATA_VALUE_TYPE_BOOL = 7,
+ GGUF_METADATA_VALUE_TYPE_STRING = 8,
+ GGUF_METADATA_VALUE_TYPE_ARRAY = 9,
+ GGUF_METADATA_VALUE_TYPE_UINT64 = 10,
+ GGUF_METADATA_VALUE_TYPE_INT64 = 11,
+ GGUF_METADATA_VALUE_TYPE_FLOAT64 = 12,
+};
+typedef uint32_t gguf_metadata_value_type_t;
+static const iree_host_size_t gguf_metadata_value_type_sizes[] = {
+ [GGUF_METADATA_VALUE_TYPE_UINT8] = sizeof(uint8_t),
+ [GGUF_METADATA_VALUE_TYPE_INT8] = sizeof(int8_t),
+ [GGUF_METADATA_VALUE_TYPE_UINT16] = sizeof(uint16_t),
+ [GGUF_METADATA_VALUE_TYPE_INT16] = sizeof(int16_t),
+ [GGUF_METADATA_VALUE_TYPE_UINT32] = sizeof(uint32_t),
+ [GGUF_METADATA_VALUE_TYPE_INT32] = sizeof(int32_t),
+ [GGUF_METADATA_VALUE_TYPE_FLOAT32] = sizeof(float),
+ [GGUF_METADATA_VALUE_TYPE_BOOL] = sizeof(bool),
+ [GGUF_METADATA_VALUE_TYPE_STRING] = 0,
+ [GGUF_METADATA_VALUE_TYPE_ARRAY] = 0,
+ [GGUF_METADATA_VALUE_TYPE_UINT64] = sizeof(uint64_t),
+ [GGUF_METADATA_VALUE_TYPE_INT64] = sizeof(int64_t),
+ [GGUF_METADATA_VALUE_TYPE_FLOAT64] = sizeof(double),
+};
+
+// typedef struct {
+// uint64_t len;
+// char string[len];
+// } gguf_string_t;
+
+// NOTE: storage of string/value has interior variable length data (ew).
+// union gguf_metadata_value_t {
+// uint8_t uint8;
+// int8_t int8;
+// uint16_t uint16;
+// int16_t int16;
+// uint32_t uint32;
+// int32_t int32;
+// float float32;
+// uint64_t uint64;
+// int64_t int64;
+// double float64;
+// bool bool_;
+// gguf_string_t string;
+// struct {
+// gguf_metadata_value_type_t type;
+// uint64_t len;
+// gguf_metadata_value_t array[len];
+// } array;
+// };
+typedef union {
+ uint8_t uint8;
+ int8_t int8;
+ uint16_t uint16;
+ int16_t int16;
+ uint32_t uint32;
+ int32_t int32;
+ float float32;
+ uint64_t uint64;
+ int64_t int64;
+ double float64;
+ bool bool_;
+ iree_string_view_t string;
+ struct {
+ gguf_metadata_value_type_t type;
+ uint64_t len;
+ // Arrays ignored for now - we just skip them.
+ // gguf_metadata_value_t array[/*len*/];
+ } array;
+} gguf_metadata_value_t;
+
+// NOTE: value.string/value.array has interior variable length data (ew).
+// struct gguf_metadata_kv_t {
+// gguf_string_t key;
+// gguf_metadata_value_type_t value_type;
+// gguf_metadata_value_t value;
+// };
+typedef struct {
+ iree_string_view_t key;
+ gguf_metadata_value_type_t value_type;
+ gguf_metadata_value_t value;
+} gguf_metadata_kv_t;
+
+// NOTE: metadata_kv has interior variable length data (ew).
+// struct gguf_header_t {
+// uint32_t magic;
+// uint32_t version;
+// uint64_t tensor_count;
+// uint64_t metadata_kv_count;
+// gguf_metadata_kv_t metadata_kv[metadata_kv_count];
+// };
+
+// NOTE: storage has interior variable length data (ew).
+// struct gguf_tensor_info_t {
+// gguf_string_t name;
+// uint32_t n_dimensions;
+// uint64_t dimensions[n_dimensions];
+// ggml_type_t type;
+// uint64_t offset;
+// };
+typedef struct {
+ iree_string_view_t name;
+ uint32_t n_dimensions;
+ const uint64_t* dimensions; // n_dimensions
+ ggml_type_t type;
+ uint64_t offset;
+} gguf_tensor_info_t;
+
+// struct gguf_file_t {
+// gguf_header_t header;
+// gguf_tensor_info_t tensor_infos[header.tensor_count];
+// uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)];
+// uint8_t tensor_data[];
+// };
+
+typedef struct iree_io_gguf_parser_t {
+ // Handle of the file being parsed.
+ iree_io_file_handle_t* file_handle;
+ // Index being appended to during parsing.
+ iree_io_parameter_index_t* index;
+ // Default value GGUF_DEFAULT_ALIGNMENT or the general.alignment kv value.
+ // The file tensor_data must be aligned to this value as must every tensor
+ // contained within it.
+ uint32_t alignment;
+ // Offset of the tensor_data file field. A 0 value (file offset 0) indicates
+ // the tensor_data offset has not been calculated yet. GGUF is built from
+ // nested variable-length structs, so it cannot be scanned in a single pass.
+ // The offset is relative to the origin of the file.
+ uint64_t tensor_data_offset;
+ // Total available tensor data capacity. A 0 value indicates the
+ // tensor_data_size (and the required tensor_data_offset) have not been
+ // calculated yet.
+ uint64_t tensor_data_size;
+} iree_io_gguf_parser_t;
+
+static inline uint64_t iree_align_uint64(uint64_t value, uint64_t alignment) {
+ return (value + (alignment - 1)) & ~(alignment - 1);
+}
+
+static uint64_t iree_io_gguf_calculate_storage_size(
+ const gguf_tensor_info_t* tensor_info) {
+ uint64_t element_count = 1;
+ for (uint32_t i = 0; i < tensor_info->n_dimensions; ++i) {
+ element_count *= tensor_info->dimensions[i];
+ }
+ const ggml_type_traits_t type_traits = ggml_type_traits[tensor_info->type];
+ return (element_count * type_traits.type_size) / type_traits.blck_size;
+}
+
+static iree_status_t iree_io_gguf_parse_value(iree_const_byte_span_t* contents,
+ iree_host_size_t length,
+ void* out_value) {
+ if (contents->data_length < length) {
+ return iree_make_status(
+ IREE_STATUS_OUT_OF_RANGE,
+ "file buffer underrun parsing %" PRIhsz " byte value", length);
+ }
+ memcpy(out_value, contents->data, length);
+ contents->data += length;
+ contents->data_length -= length;
+ return iree_ok_status();
+}
+static iree_status_t iree_io_gguf_parse_uint32(iree_const_byte_span_t* contents,
+ uint32_t* out_value) {
+ return iree_io_gguf_parse_value(contents, sizeof(*out_value), out_value);
+}
+static iree_status_t iree_io_gguf_parse_uint64(iree_const_byte_span_t* contents,
+ uint64_t* out_value) {
+ return iree_io_gguf_parse_value(contents, sizeof(*out_value), out_value);
+}
+
+static iree_status_t iree_io_gguf_parse_array(iree_const_byte_span_t* contents,
+ uint64_t element_count,
+ uint64_t element_size,
+ const uint8_t** out_base_ptr) {
+ uint64_t total_length = element_count * element_size;
+ if (total_length > IREE_HOST_SIZE_MAX) {
+ return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+ "attempting to load a 64-bit file on a 32-bit arch "
+ "(out of bounds array length)");
+ }
+ if (contents->data_length < (iree_host_size_t)total_length) {
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "file buffer underrun parsing array");
+ }
+ *out_base_ptr = contents->data;
+ contents->data += (iree_host_size_t)total_length;
+ contents->data_length -= (iree_host_size_t)total_length;
+ return iree_ok_status();
+}
+
+static iree_status_t iree_io_gguf_parse_string(iree_const_byte_span_t* contents,
+ iree_string_view_t* out_value) {
+ uint64_t length = 0;
+ IREE_RETURN_IF_ERROR(iree_io_gguf_parse_uint64(contents, &length));
+ if (length > IREE_HOST_SIZE_MAX) {
+ return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+ "attempting to load a 64-bit file on a 32-bit arch "
+ "(out of bounds string length)");
+ }
+ out_value->size = (iree_host_size_t)length;
+ return iree_io_gguf_parse_array(contents, length, sizeof(char),
+ (const uint8_t**)&out_value->data);
+}
+
+static iree_status_t iree_io_gguf_skip_metadata_array(
+ iree_const_byte_span_t* contents, gguf_metadata_value_type_t value_type,
+ uint64_t count) {
+ switch (value_type) {
+ default:
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "unsupported metadata value type %u", value_type);
+ case GGUF_METADATA_VALUE_TYPE_UINT8:
+ case GGUF_METADATA_VALUE_TYPE_INT8:
+ case GGUF_METADATA_VALUE_TYPE_UINT16:
+ case GGUF_METADATA_VALUE_TYPE_INT16:
+ case GGUF_METADATA_VALUE_TYPE_UINT32:
+ case GGUF_METADATA_VALUE_TYPE_INT32:
+ case GGUF_METADATA_VALUE_TYPE_FLOAT32:
+ case GGUF_METADATA_VALUE_TYPE_BOOL:
+ case GGUF_METADATA_VALUE_TYPE_UINT64:
+ case GGUF_METADATA_VALUE_TYPE_INT64:
+ case GGUF_METADATA_VALUE_TYPE_FLOAT64: {
+ const uint8_t* values = NULL;
+ return iree_io_gguf_parse_array(
+ contents, count, gguf_metadata_value_type_sizes[value_type], &values);
+ }
+ case GGUF_METADATA_VALUE_TYPE_STRING: {
+ iree_string_view_t value = iree_string_view_empty();
+ for (uint64_t i = 0; i < count; ++i) {
+ IREE_RETURN_IF_ERROR(iree_io_gguf_parse_string(contents, &value));
+ }
+ return iree_ok_status();
+ }
+ case GGUF_METADATA_VALUE_TYPE_ARRAY:
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "nested arrays not supported in gguf");
+ }
+}
+
+static iree_status_t iree_io_gguf_parse_metadata_value(
+ iree_const_byte_span_t* contents, gguf_metadata_value_type_t value_type,
+ gguf_metadata_value_t* out_value) {
+ switch (value_type) {
+ default:
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "unsupported metadata value type %u", value_type);
+ case GGUF_METADATA_VALUE_TYPE_UINT8:
+ case GGUF_METADATA_VALUE_TYPE_INT8:
+ case GGUF_METADATA_VALUE_TYPE_UINT16:
+ case GGUF_METADATA_VALUE_TYPE_INT16:
+ case GGUF_METADATA_VALUE_TYPE_UINT32:
+ case GGUF_METADATA_VALUE_TYPE_INT32:
+ case GGUF_METADATA_VALUE_TYPE_FLOAT32:
+ case GGUF_METADATA_VALUE_TYPE_BOOL:
+ case GGUF_METADATA_VALUE_TYPE_UINT64:
+ case GGUF_METADATA_VALUE_TYPE_INT64:
+ case GGUF_METADATA_VALUE_TYPE_FLOAT64:
+ return iree_io_gguf_parse_value(
+ contents, gguf_metadata_value_type_sizes[value_type], out_value);
+ case GGUF_METADATA_VALUE_TYPE_STRING:
+ return iree_io_gguf_parse_string(contents, &out_value->string);
+ case GGUF_METADATA_VALUE_TYPE_ARRAY: {
+ IREE_RETURN_IF_ERROR(
+ iree_io_gguf_parse_uint32(contents, &out_value->array.type));
+ IREE_RETURN_IF_ERROR(
+ iree_io_gguf_parse_uint64(contents, &out_value->array.len));
+ // We don't support arrays right now because they require allocation due
+ // to the variable length nature of things. We do still have to calculate
+ // the total size which is annoying due to the nested variable-length
+ // strings.
+ return iree_io_gguf_skip_metadata_array(contents, out_value->array.type,
+ out_value->array.len);
+ }
+ }
+}
+
+typedef iree_status_t(IREE_API_PTR* iree_io_gguf_metadata_kv_callback_fn_t)(
+ void* user_data, const gguf_metadata_kv_t* kv);
+static iree_status_t iree_io_gguf_enumerate_metadata_kv(
+ iree_const_byte_span_t* contents, uint64_t count,
+ iree_io_gguf_metadata_kv_callback_fn_t callback, void* user_data) {
+ for (uint64_t i = 0; i < count; ++i) {
+ gguf_metadata_kv_t kv = {0};
+ IREE_RETURN_IF_ERROR(iree_io_gguf_parse_string(contents, &kv.key));
+ IREE_RETURN_IF_ERROR(iree_io_gguf_parse_uint32(contents, &kv.value_type));
+ IREE_RETURN_IF_ERROR(
+ iree_io_gguf_parse_metadata_value(contents, kv.value_type, &kv.value));
+ IREE_RETURN_IF_ERROR(callback(user_data, &kv));
+ }
+ return iree_ok_status();
+}
+
+typedef iree_status_t(IREE_API_PTR* iree_io_gguf_tensor_info_callback_fn_t)(
+ void* user_data, const gguf_tensor_info_t* tensor_info);
+static iree_status_t iree_io_gguf_enumerate_tensor_info(
+ iree_const_byte_span_t* contents, uint64_t count,
+ iree_io_gguf_tensor_info_callback_fn_t callback, void* user_data) {
+ for (uint64_t i = 0; i < count; ++i) {
+ gguf_tensor_info_t tensor_info = {0};
+ IREE_RETURN_IF_ERROR(
+ iree_io_gguf_parse_string(contents, &tensor_info.name));
+ IREE_RETURN_IF_ERROR(
+ iree_io_gguf_parse_uint32(contents, &tensor_info.n_dimensions));
+ IREE_RETURN_IF_ERROR(iree_io_gguf_parse_array(
+ contents, tensor_info.n_dimensions, sizeof(tensor_info.dimensions[0]),
+ (const uint8_t**)&tensor_info.dimensions));
+ IREE_RETURN_IF_ERROR(
+ iree_io_gguf_parse_uint32(contents, &tensor_info.type));
+ IREE_RETURN_IF_ERROR(
+ iree_io_gguf_parse_uint64(contents, &tensor_info.offset));
+ if (callback) {
+ IREE_RETURN_IF_ERROR(callback(user_data, &tensor_info));
+ }
+ }
+ return iree_ok_status();
+}
+
+static iree_status_t iree_io_gguf_parse_metadata(void* user_data,
+ const gguf_metadata_kv_t* kv) {
+ iree_io_gguf_parser_t* parser = (iree_io_gguf_parser_t*)user_data;
+ if (iree_string_view_equal(kv->key, IREE_SV("general.alignment"))) {
+ if (kv->value_type != GGUF_METADATA_VALUE_TYPE_UINT32) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "general.alignment metadata value must be uint32");
+ }
+ parser->alignment = kv->value.uint32;
+ }
+ return iree_ok_status();
+}
+
+static iree_status_t iree_io_gguf_append_tensor_info(
+ void* user_data, const gguf_tensor_info_t* tensor_info) {
+ iree_io_gguf_parser_t* parser = (iree_io_gguf_parser_t*)user_data;
+
+ // Unfortunately (I've said that a lot here?) the total size of the tensor is
+ // not stored and as such we need to calculate it based on the metadata we
+ // have. If they just included the size we wouldn't have to care about the
+ // data type or tensor dimensions, or handle the ever-growing list of
+ // hard-coded format types.
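+  // The type comes straight from the file: reject values outside the traits
+  // table and types with no traits entry (zero block size) before using them.
+  if (tensor_info->type >= GGML_TYPE_COUNT ||
+      ggml_type_traits[tensor_info->type].blck_size == 0) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "unsupported ggml tensor type %u",
+                            tensor_info->type);
+  }
+  uint64_t storage_size = iree_io_gguf_calculate_storage_size(tensor_info);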
+
+ // Verify the range is within tensor data bounds.
+ uint64_t begin = tensor_info->offset;
+ uint64_t end = begin + storage_size;
+ if (begin > end || end > parser->tensor_data_size) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "entry has data offsets outside of the "
+ "available data (begin=%" PRIu64 ", end=%" PRIu64
+ ", available=%" PRIu64 ")",
+ begin, end, parser->tensor_data_size);
+ }
+
+ // Add entry to the index.
+ iree_io_parameter_index_entry_t entry = {
+ .key = tensor_info->name,
+ .metadata = iree_const_byte_span_empty(),
+ .file_handle = parser->file_handle,
+ .offset = parser->tensor_data_offset + begin,
+ .length = storage_size,
+ };
+ return iree_io_parameter_index_add(parser->index, &entry);
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parse_gguf_index_from_memory(
+ iree_io_file_handle_t* file_handle, iree_const_byte_span_t file_contents,
+ iree_io_parameter_index_t* index) {
+ // Read the header enough to check for file validity and version.
+ // Unfortunately the format has a variable-length header (vs being
+ // table-based) and that means we have to actually parse the header fully
+ // (including all nested variable-length elements) in order to even know if
+ // the whole header is present or where data lives. Yuck.
+ iree_const_byte_span_t contents = file_contents;
+ uint32_t magic = 0;
+ IREE_RETURN_IF_ERROR(iree_io_gguf_parse_uint32(&contents, &magic));
+ if (magic != GGUF_MAGIC) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "GGUF file magic missing or invalid %08X; expected %08X", magic,
+ GGUF_MAGIC);
+ }
+ uint32_t version = 0;
+ IREE_RETURN_IF_ERROR(iree_io_gguf_parse_uint32(&contents, &version));
+ if (version != GGUF_VERSION) {
+ return iree_make_status(
+ IREE_STATUS_UNIMPLEMENTED,
+ "GGUF format version %u is unsupported; expected version %u", version,
+ GGUF_VERSION);
+ }
+ uint64_t tensor_count = 0;
+ IREE_RETURN_IF_ERROR(iree_io_gguf_parse_uint64(&contents, &tensor_count));
+ uint64_t metadata_kv_count = 0;
+ IREE_RETURN_IF_ERROR(
+ iree_io_gguf_parse_uint64(&contents, &metadata_kv_count));
+
+ // If there are no tensors then no-op the parse. Probably not what the user
+ // wanted but it's legal.
+ if (tensor_count == 0) return iree_ok_status();
+
+ iree_io_gguf_parser_t parser = {
+ .file_handle = file_handle,
+ .index = index,
+ .alignment = GGUF_DEFAULT_ALIGNMENT, // may be overridden
+ .tensor_data_offset = 0, // to be calculated
+ .tensor_data_size = 0, // to be calculated
+ };
+
+ // Scope data to the remainder of the file and enumerate all metadata pairs.
+ // Upon return the contents will start immediately after the header and at
+ // the start of tensor info.
+ IREE_RETURN_IF_ERROR(iree_io_gguf_enumerate_metadata_kv(
+ &contents, metadata_kv_count, iree_io_gguf_parse_metadata, &parser));
+
+ // Scan forward through the tensor info to find where the tensor data base
+ // offset is in the file. Unfortunately GGUF was designed without this offset
+ // and because tensor info is variable length we cannot determine absolute
+ // file offsets without doing two scans.
+ iree_const_byte_span_t tensor_info_contents = contents;
+ IREE_RETURN_IF_ERROR(iree_io_gguf_enumerate_tensor_info(
+ &tensor_info_contents, tensor_count, NULL, &parser));
+
+ // Calculate where the tensor data begins in the file. This respects the
+ // default alignment or the general.alignment specified by the file.
+ parser.tensor_data_offset = iree_align_uint64(
+ (uint64_t)(tensor_info_contents.data - file_contents.data),
+ parser.alignment);
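+  // Aligning up may move the offset past the end of a truncated file; fail
+  // here instead of letting the subtraction below wrap around.
+  if (parser.tensor_data_offset > file_contents.data_length) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "file buffer underrun locating tensor data");
+  }
+  parser.tensor_data_size =
+      file_contents.data_length - parser.tensor_data_offset;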
+
+ // Scan forward through the tensor info now that we know the tensor data
+ // offset and add the tensor entries.
+ IREE_RETURN_IF_ERROR(iree_io_gguf_enumerate_tensor_info(
+ &contents, tensor_count, iree_io_gguf_append_tensor_info, &parser));
+
+ return iree_ok_status();
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parse_gguf_index(
+ iree_io_file_handle_t* file_handle, iree_io_parameter_index_t* index) {
+ IREE_ASSERT_ARGUMENT(index);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Today we only support memory files.
+ // TODO(benvanik): support iree_io_stream_t wrapping for parsing the index.
+ if (iree_io_file_handle_type(file_handle) !=
+ IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "non-memory gguf files not yet supported");
+ }
+ iree_byte_span_t host_allocation =
+ iree_io_file_handle_primitive(file_handle).value.host_allocation;
+
+ iree_status_t status = iree_io_parse_gguf_index_from_memory(
+ file_handle,
+ iree_make_const_byte_span(host_allocation.data,
+ host_allocation.data_length),
+ index);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
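The storage size calculation in gguf_format.c multiplies the tensor dimensions together and scales by the block traits of the ggml type. A minimal standalone sketch of that calculation with explicit validation follows. All names prefixed with `sketch_` are hypothetical and not part of IREE or ggml, the traits table is trimmed to three types, and rejecting element counts that are not a whole number of blocks is an assumption of the sketch rather than something the parser does.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical, trimmed-down traits table: elements per block and bytes per
// block. A zero block size marks a type this sketch does not know about.
typedef struct {
  uint32_t block_size;
  uint32_t type_size;
} sketch_type_traits_t;

enum {
  SKETCH_TYPE_F32 = 0,
  SKETCH_TYPE_F16 = 1,
  SKETCH_TYPE_Q4_0 = 2,
  SKETCH_TYPE_COUNT = 19,
};

static const sketch_type_traits_t sketch_type_traits[SKETCH_TYPE_COUNT] = {
    [SKETCH_TYPE_F32] = {1, 4},
    [SKETCH_TYPE_F16] = {1, 2},
    // 32 elements per block: a 2-byte scale plus 16 bytes of packed nibbles.
    [SKETCH_TYPE_Q4_0] = {32, 18},
};

// Computes the byte size of a tensor. Returns false on an unknown type, on an
// element count that is not a whole number of blocks, or on 64-bit overflow.
static bool sketch_storage_size(uint32_t type, const uint64_t* dims,
                                uint32_t rank, uint64_t* out_size) {
  *out_size = 0;
  if (type >= SKETCH_TYPE_COUNT) return false;
  const sketch_type_traits_t traits = sketch_type_traits[type];
  if (traits.block_size == 0) return false;
  uint64_t element_count = 1;
  for (uint32_t i = 0; i < rank; ++i) {
    if (dims[i] != 0 && element_count > UINT64_MAX / dims[i]) return false;
    element_count *= dims[i];
  }
  if (element_count % traits.block_size != 0) return false;
  const uint64_t block_count = element_count / traits.block_size;
  if (block_count > UINT64_MAX / traits.type_size) return false;
  *out_size = block_count * traits.type_size;
  return true;
}
```

Dividing by the block size before multiplying by the block byte size keeps the intermediate value small, and the boolean result gives the caller a place to report an unsupported type instead of indexing past the table.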
diff --git a/runtime/src/iree/io/formats/gguf/gguf_format.h b/runtime/src/iree/io/formats/gguf/gguf_format.h
new file mode 100644
index 0000000..0e49ad6
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/gguf_format.h
@@ -0,0 +1,29 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_IO_FORMATS_GGUF_GGUF_FORMAT_H_
+#define IREE_IO_FORMATS_GGUF_GGUF_FORMAT_H_
+
+#include "iree/base/api.h"
+#include "iree/io/file_handle.h"
+#include "iree/io/parameter_index.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Parses a .gguf file and merges its contained resources into |index|.
+//
+// Specification:
+// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
+IREE_API_EXPORT iree_status_t iree_io_parse_gguf_index(
+ iree_io_file_handle_t* file_handle, iree_io_parameter_index_t* index);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_IO_FORMATS_GGUF_GGUF_FORMAT_H_
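The header above exposes a single entry point; the parser behind it walks the file with a cursor that checks the remaining length before every read. A minimal standalone sketch of that pattern, applied to the four fixed-size fields at the start of a GGUF file, is shown here. The `sketch_` names are hypothetical, and the sketch assumes a little-endian host because a raw memcpy performs no byte swapping.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Hypothetical read cursor: a pointer plus the number of bytes left.
typedef struct {
  const uint8_t* data;
  size_t length;
} sketch_span_t;

// Copies |length| bytes out of the span and advances it. Fails without
// consuming anything if fewer than |length| bytes remain.
static bool sketch_read(sketch_span_t* span, void* out, size_t length) {
  if (span->length < length) return false;
  memcpy(out, span->data, length);
  span->data += length;
  span->length -= length;
  return true;
}

typedef struct {
  uint32_t magic;
  uint32_t version;
  uint64_t tensor_count;
  uint64_t metadata_kv_count;
} sketch_gguf_header_t;

// Reads the fixed-size header fields in file order and checks the magic.
// Everything after these fields is variable length and needs a full parse.
static bool sketch_read_header(sketch_span_t* span,
                               sketch_gguf_header_t* out) {
  return sketch_read(span, &out->magic, sizeof(out->magic)) &&
         out->magic == 0x46554747u /* "GGUF" */ &&
         sketch_read(span, &out->version, sizeof(out->version)) &&
         sketch_read(span, &out->tensor_count, sizeof(out->tensor_count)) &&
         sketch_read(span, &out->metadata_kv_count,
                     sizeof(out->metadata_kv_count));
}
```

Reading field by field rather than overlaying a struct avoids any dependence on struct padding or pointer alignment, which matters once variable-length strings leave later fields unaligned.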
diff --git a/runtime/src/iree/io/formats/gguf/gguf_format_test.cc b/runtime/src/iree/io/formats/gguf/gguf_format_test.cc
new file mode 100644
index 0000000..368d4b6
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/gguf_format_test.cc
@@ -0,0 +1,105 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/formats/gguf/gguf_format.h"
+
+#include "iree/base/internal/file_io.h"
+#include "iree/io/formats/gguf/testdata/gguf_files.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree {
+namespace {
+
+static iree_io_file_handle_t* OpenTestFile(const char* name) {
+ const struct iree_file_toc_t* file_toc = iree_io_gguf_files_create();
+ for (size_t i = 0; i < iree_io_gguf_files_size(); ++i) {
+ if (strcmp(file_toc[i].name, name) == 0) {
+ iree_io_file_handle_t* file_handle = NULL;
+ IREE_CHECK_OK(iree_io_file_handle_wrap_host_allocation(
+ IREE_IO_FILE_ACCESS_READ,
+ iree_make_byte_span((void*)file_toc[i].data, file_toc[i].size),
+ iree_io_file_handle_release_callback_null(), iree_allocator_system(),
+ &file_handle));
+ return file_handle;
+ }
+ }
+ IREE_CHECK_OK(iree_make_status(
+ IREE_STATUS_NOT_FOUND,
+ "test file `%s` not found embedded into test binary", name));
+ return NULL;
+}
+
+TEST(GgufFormatTest, Empty) {
+ iree_io_parameter_index_t* index = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_create(iree_allocator_system(), &index));
+
+ iree_io_file_handle_t* file_handle = OpenTestFile("empty.gguf");
+ IREE_ASSERT_OK(iree_io_parse_gguf_index(file_handle, index));
+ iree_io_file_handle_release(file_handle);
+
+ iree_io_parameter_index_release(index);
+}
+
+TEST(GgufFormatTest, SingleTensor) {
+ iree_io_parameter_index_t* index = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_create(iree_allocator_system(), &index));
+
+ iree_io_file_handle_t* file_handle = OpenTestFile("single.gguf");
+ IREE_ASSERT_OK(iree_io_parse_gguf_index(file_handle, index));
+ iree_io_file_handle_release(file_handle);
+
+ const iree_io_parameter_index_entry_t* entry0 = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_lookup(index, IREE_SV("tensor0"), &entry0));
+ EXPECT_TRUE(iree_string_view_equal(IREE_SV("tensor0"), entry0->key));
+ EXPECT_TRUE(iree_const_byte_span_is_empty(entry0->metadata));
+ EXPECT_EQ(entry0->offset, 384);
+ EXPECT_EQ(entry0->length, 16);
+
+ iree_io_parameter_index_release(index);
+}
+
+TEST(GgufFormatTest, MultipleTensors) {
+ iree_io_parameter_index_t* index = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_create(iree_allocator_system(), &index));
+
+ iree_io_file_handle_t* file_handle = OpenTestFile("multiple.gguf");
+ IREE_ASSERT_OK(iree_io_parse_gguf_index(file_handle, index));
+ iree_io_file_handle_release(file_handle);
+
+ const iree_io_parameter_index_entry_t* entry0 = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_lookup(index, IREE_SV("tensor0"), &entry0));
+ EXPECT_TRUE(iree_string_view_equal(IREE_SV("tensor0"), entry0->key));
+ EXPECT_TRUE(iree_const_byte_span_is_empty(entry0->metadata));
+ EXPECT_EQ(entry0->offset, 448);
+ EXPECT_EQ(entry0->length, 16);
+
+ const iree_io_parameter_index_entry_t* entry1 = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_lookup(index, IREE_SV("tensor1"), &entry1));
+ EXPECT_TRUE(iree_string_view_equal(IREE_SV("tensor1"), entry1->key));
+ EXPECT_TRUE(iree_const_byte_span_is_empty(entry1->metadata));
+ EXPECT_EQ(entry1->offset, 512);
+ EXPECT_EQ(entry1->length, 8);
+
+ const iree_io_parameter_index_entry_t* entry2 = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_lookup(index, IREE_SV("tensor2"), &entry2));
+ EXPECT_TRUE(iree_string_view_equal(IREE_SV("tensor2"), entry2->key));
+ EXPECT_TRUE(iree_const_byte_span_is_empty(entry2->metadata));
+ EXPECT_EQ(entry2->offset, 576);
+ EXPECT_EQ(entry2->length, 48);
+
+ iree_io_parameter_index_release(index);
+}
+
+} // namespace
+} // namespace iree
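The offsets expected by MultipleTensors (448, 512, 576) are consistent with a 64-byte alignment, a tensor data base of 448, and tensor-relative offsets of 0, 64 and 128, which is what a writer would produce if it pads each tensor to the alignment. A small sketch of that arithmetic is given here; the `sketch_` names are hypothetical and the header end of 400 used in the usage note is an illustrative value, not one read from the test file.

```c
#include <stdint.h>

// Rounds |value| up to the next multiple of |alignment|. The bit mask only
// works when |alignment| is a nonzero power of two.
static uint64_t sketch_align(uint64_t value, uint64_t alignment) {
  return (value + (alignment - 1)) & ~(alignment - 1);
}

// Absolute file offset of a tensor, given the unaligned end of the header
// and the tensor's offset relative to the start of the tensor data.
static uint64_t sketch_absolute_offset(uint64_t header_end,
                                       uint64_t alignment,
                                       uint64_t relative_offset) {
  return sketch_align(header_end, alignment) + relative_offset;
}
```

With a header ending at byte 400 and an alignment of 64, `sketch_absolute_offset(400, 64, 0)` is 448, and relative offsets 64 and 128 give 512 and 576. The 16-byte and 8-byte tensors each occupy a full 64-byte slot under this assumption, which is why consecutive expected offsets differ by 64 rather than by the tensor lengths.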
diff --git a/runtime/src/iree/io/formats/gguf/testdata/BUILD.bazel b/runtime/src/iree/io/formats/gguf/testdata/BUILD.bazel
new file mode 100644
index 0000000..ce3c4cc
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/testdata/BUILD.bazel
@@ -0,0 +1,27 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/embed_data:build_defs.bzl", "c_embed_data")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+c_embed_data(
+ name = "gguf_files",
+ testonly = True,
+ srcs = [
+ "empty.gguf",
+ "multiple.gguf",
+ "single.gguf",
+ ],
+ c_file_output = "gguf_files.c",
+ flatten = True,
+ h_file_output = "gguf_files.h",
+ identifier = "iree_io_gguf_files",
+)
diff --git a/runtime/src/iree/io/formats/gguf/testdata/CMakeLists.txt b/runtime/src/iree/io/formats/gguf/testdata/CMakeLists.txt
new file mode 100644
index 0000000..e2c6f7b
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/testdata/CMakeLists.txt
@@ -0,0 +1,31 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# runtime/src/iree/io/formats/gguf/testdata/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_c_embed_data(
+ NAME
+ gguf_files
+ SRCS
+ "empty.gguf"
+ "multiple.gguf"
+ "single.gguf"
+ C_FILE_OUTPUT
+ "gguf_files.c"
+ H_FILE_OUTPUT
+ "gguf_files.h"
+ IDENTIFIER
+ "iree_io_gguf_files"
+ TESTONLY
+ FLATTEN
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/io/formats/gguf/testdata/empty.gguf b/runtime/src/iree/io/formats/gguf/testdata/empty.gguf
new file mode 100644
index 0000000..7723f7d
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/testdata/empty.gguf
Binary files differ
diff --git a/runtime/src/iree/io/formats/gguf/testdata/generate_gguf_files.py b/runtime/src/iree/io/formats/gguf/testdata/generate_gguf_files.py
new file mode 100644
index 0000000..31430a1
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/testdata/generate_gguf_files.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/gguf.py
+#
+# To regenerate:
+# $ pip install gguf
+# $ cd runtime/src/iree/io/formats/gguf/testdata/
+# $ python generate_gguf_files.py
+
+import numpy as np
+from gguf import GGUFWriter
+
+
+def save_file(tensors, path):
+    writer = GGUFWriter(path, "generic")
+
+    writer.add_architecture()
+    writer.add_custom_alignment(64)
+
+    writer.add_uint32("metadata_uint32", 42)
+    writer.add_string("metadata_str", "hello")
+    writer.add_array("metadata_strs", ["a", "b", "c"])
+
+    for key, value in tensors.items():
+        writer.add_tensor(key, value)
+
+    writer.write_header_to_file()
+    writer.write_kv_data_to_file()
+    writer.write_tensors_to_file()
+
+    writer.close()
+
+
+# no tensors
+save_file({}, "empty.gguf")
+
+# single tensor
+save_file({"tensor0": np.ones((2, 2), dtype=np.float32)}, "single.gguf")
+
+# multiple tensors
+save_file(
+ {
+ "tensor0": np.ones((2, 2), dtype=np.float32),
+ "tensor1": np.ones((1, 2), dtype=np.float32),
+ "tensor2": np.ones((4, 3), dtype=np.float32),
+ },
+ "multiple.gguf",
+)
diff --git a/runtime/src/iree/io/formats/gguf/testdata/multiple.gguf b/runtime/src/iree/io/formats/gguf/testdata/multiple.gguf
new file mode 100644
index 0000000..53232d3
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/testdata/multiple.gguf
Binary files differ
diff --git a/runtime/src/iree/io/formats/gguf/testdata/single.gguf b/runtime/src/iree/io/formats/gguf/testdata/single.gguf
new file mode 100644
index 0000000..006bef8
--- /dev/null
+++ b/runtime/src/iree/io/formats/gguf/testdata/single.gguf
Binary files differ
diff --git a/runtime/src/iree/io/formats/safetensors/BUILD.bazel b/runtime/src/iree/io/formats/safetensors/BUILD.bazel
new file mode 100644
index 0000000..a0d6a5e
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/BUILD.bazel
@@ -0,0 +1,41 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_runtime_cc_library(
+ name = "safetensors",
+ srcs = [
+ "safetensors_format.c",
+ ],
+ hdrs = [
+ "safetensors_format.h",
+ ],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/io:file_handle",
+ "//runtime/src/iree/io:parameter_index",
+ ],
+)
+
+iree_runtime_cc_test(
+ name = "safetensors_format_test",
+ srcs = ["safetensors_format_test.cc"],
+ tags = ["requires-filesystem"],
+ deps = [
+ ":safetensors",
+ "//runtime/src/iree/base/internal:file_io",
+ "//runtime/src/iree/io/formats/safetensors/testdata:safetensors_files",
+ "//runtime/src/iree/testing:gtest",
+ "//runtime/src/iree/testing:gtest_main",
+ ],
+)
diff --git a/runtime/src/iree/io/formats/safetensors/CMakeLists.txt b/runtime/src/iree/io/formats/safetensors/CMakeLists.txt
new file mode 100644
index 0000000..f52ac03
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/CMakeLists.txt
@@ -0,0 +1,42 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# runtime/src/iree/io/formats/safetensors/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ safetensors
+ HDRS
+ "safetensors_format.h"
+ SRCS
+ "safetensors_format.c"
+ DEPS
+ iree::base
+ iree::io::file_handle
+ iree::io::parameter_index
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ safetensors_format_test
+ SRCS
+ "safetensors_format_test.cc"
+ DEPS
+ ::safetensors
+ iree::base::internal::file_io
+ iree::io::formats::safetensors::testdata::safetensors_files
+ iree::testing::gtest
+ iree::testing::gtest_main
+ LABELS
+ "requires-filesystem"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/io/formats/safetensors/safetensors_format.c b/runtime/src/iree/io/formats/safetensors/safetensors_format.c
new file mode 100644
index 0000000..09ba861
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/safetensors_format.c
@@ -0,0 +1,520 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/formats/safetensors/safetensors_format.h"
+
+#include <ctype.h>
+
+// File format:
+// - uint64_t header_length;
+// - uint8_t header_json[header_length];
+// - uint8_t remaining_data[];
+//
+// JSON:
+// {
+// "TENSOR_NAME": {
+// "dtype": "F16",
+// "shape": [1, 16, 256],
+// "data_offsets": [BEGIN, END]
+// },
+// "NEXT_TENSOR_NAME": {...}
+// }
+//
+// The BEGIN offset is relative to the end of the header, not the file.
+// The END offset is oddly BEGIN+length such that length=END-BEGIN.
+//
+// Note that for such a fixed file format JSON is overkill and we don't want to
+// pull in a JSON parser just to get the data offsets. We parse the strings the
+// old-fashioned way while wishing they were just numbers and bail if anything
+// looks suspect.
+//
+// Here's a real JSON blob from a test file (formatted to add whitespace):
+// <<8 byte header with a count of all bytes including trailing whitespace>>
+// {
+// "attention": {
+// "dtype": "F32",
+// "shape": [
+// 2,
+// 3
+// ],
+// "data_offsets": [
+// 0,
+// 24
+// ]
+// },
+// "embedding": {
+// "dtype": "F32",
+// "shape": [
+// 2,
+// 2
+// ],
+// "data_offsets": [
+// 24,
+// 40
+// ]
+// }
+// }
+// <<trailing whitespace>>
+// <<40 bytes of data>>
+//
+// Basic JSON spec (per json.org):
+// value:
+// object
+// array
+// string
+// number
+// "true"
+// "false"
+// "null"
+// object:
+// '{' ws '}'
+// '{' members '}'
+// members:
+// member
+// member ',' members
+// member:
+// ws string ws ':' element
+// array:
+// '[' ws ']'
+// '[' elements ']'
+// elements:
+// element
+// element ',' elements
+// element:
+// ws value ws
+// string:
+// '"' characters '"'
+// characters:
+// ""
+// character characters
+// character:
+// 0x0020 . 0x10FFFF - '"' - '\'
+// '\' escape
+// escape:
+// '"' '\' '/' 'b' 'f' 'n' 'r' 't'
+// 'u' hex hex hex hex
+// hex:
+// digit
+// 'A' . 'F'
+// 'a' . 'f'
+// number:
+// integer fraction exponent
+// ws:
+// ""
+// 0x0020 ws
+// 0x000A ws
+// 0x000D ws
+// 0x0009 ws
+
+static iree_status_t iree_json_consume_value(iree_string_view_t* str,
+ iree_string_view_t* out_value);
+
+// Consumes |keyword| from |str| and returns it, updating |str| to point
+// immediately after it.
+static iree_status_t iree_json_consume_keyword(iree_string_view_t* str,
+ iree_string_view_t keyword,
+ iree_string_view_t* out_value) {
+ if (iree_string_view_consume_prefix(str, keyword)) {
+ *out_value = keyword;
+ return iree_ok_status();
+ }
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "invalid keyword, expected '%.*s'", (int)keyword.size,
+ keyword.data);
+}
+
+// Consumes a number from |str| and returns its text unparsed, updating |str|
+// to point immediately after the last character that could compose the number.
+// Assumes the input starts with `number` in the spec.
+static iree_status_t iree_json_consume_number(iree_string_view_t* str,
+ iree_string_view_t* out_value) {
+ // TODO(benvanik): support real numbers - for now we only handle integers
+ // because we are lazy. We scan for digits 0-9 until any non-digit is hit and
+ // then call it good.
+ iree_host_size_t break_pos = str->size;
+ for (iree_host_size_t i = 0; i < str->size; ++i) {
+ if (!isdigit((unsigned char)str->data[i])) {
+ break_pos = i;
+ break;
+ }
+ }
+ if (break_pos == 0) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "invalid number");
+ }
+ *out_value = iree_string_view_substr(*str, 0, break_pos);
+ *str = iree_string_view_remove_prefix(*str, break_pos);
+ return iree_ok_status();
+}
+
+// Consumes a string from |str| and returns it unquoted, updating |str| to
+// point immediately after the trailing double quote of the string.
+// Assumes the input starts with `string` in the spec.
+static iree_status_t iree_json_consume_string(iree_string_view_t* str,
+ iree_string_view_t* out_value) {
+ *out_value = iree_string_view_empty();
+ if (!iree_string_view_starts_with(*str, IREE_SV("\""))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "missing string \" prefix");
+ }
+ iree_host_size_t start = 1;
+ iree_host_size_t end = 0;
+ for (iree_host_size_t i = start; i < str->size; ++i) {
+ char c = str->data[i];
+ if (c == '\"') {
+ // Unescaped quote is end of string.
+ end = i;
+ break;
+ } else if (c == '\\') {
+ if (i + 1 >= str->size) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "escape code with no contents");
+ }
+ // Escaped sequence - usually 1 but may be 4 for \uNNNN.
+ switch (str->data[++i]) {
+ case '\"':
+ case '\\':
+ case '/':
+ case 'b':
+ case 'f':
+ case 'n':
+ case 'r':
+ case 't':
+ break; // ok
+ case 'u':
+ // 'u' hex hex hex hex
+ if (i + 4 >= str->size) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "truncated unicode escape code");
+ }
+ i += 4;
+ break;
+ default:
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "unrecognized string escape code %c", str->data[i]);
+ }
+ }
+ }
+ *out_value = iree_string_view_substr(*str, start, end - start);
+ *str = iree_string_view_substr(*str, end + 1, IREE_HOST_SIZE_MAX);
+ return iree_ok_status();
+}
+
+// Consumes an object and all its descendants from |str| and returns it with
+// braces, updating |str| to point immediately after the trailing `}`.
+// Assumes the input starts with `object` in the spec.
+static iree_status_t iree_json_consume_object(iree_string_view_t* str,
+ iree_string_view_t* out_value) {
+ *out_value = iree_string_view_empty();
+ const char* start = str->data;
+ if (!iree_string_view_consume_prefix(str, IREE_SV("{"))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "missing object {");
+ }
+ *str = iree_string_view_trim(*str);
+ while (!iree_string_view_is_empty(*str)) {
+ // Check for end of object.
+ if (iree_string_view_consume_prefix(str, IREE_SV("}"))) break;
+ // Try to parse key string.
+ iree_string_view_t key = iree_string_view_empty();
+ IREE_RETURN_IF_ERROR(iree_json_consume_string(str, &key));
+ *str = iree_string_view_trim(*str);
+ // Expect : separator.
+ if (!iree_string_view_consume_prefix(str, IREE_SV(":"))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "missing object member separator");
+ }
+ // Scan ahead to get the value span.
+ iree_string_view_t value = iree_string_view_empty();
+ IREE_RETURN_IF_ERROR(iree_json_consume_value(str, &value));
+ // If there's a comma then we expect another value.
+ if (!iree_string_view_consume_prefix(str, IREE_SV(","))) break;
+ *str = iree_string_view_trim(*str);
+ }
+ if (!iree_string_view_consume_prefix(str, IREE_SV("}"))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "missing object }");
+ }
+ const char* end = str->data;
+ *out_value = iree_make_string_view(start, end - start);
+ return iree_ok_status();
+}
+
+// Consumes an array and all its descendants from |str| and returns it with
+// brackets, updating |str| to point immediately after the trailing `]`.
+// Assumes the input starts with `array` in the spec.
+static iree_status_t iree_json_consume_array(iree_string_view_t* str,
+ iree_string_view_t* out_value) {
+ *out_value = iree_string_view_empty();
+ const char* start = str->data;
+ if (!iree_string_view_consume_prefix(str, IREE_SV("["))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "missing array [");
+ }
+ *str = iree_string_view_trim(*str);
+ while (!iree_string_view_is_empty(*str)) {
+ // Check for end of array.
+ if (iree_string_view_consume_prefix(str, IREE_SV("]"))) break;
+ // Get the array element.
+ iree_string_view_t value = iree_string_view_empty();
+ IREE_RETURN_IF_ERROR(iree_json_consume_value(str, &value));
+ // If there's a comma then we expect another value.
+ if (!iree_string_view_consume_prefix(str, IREE_SV(","))) break;
+ *str = iree_string_view_trim(*str);
+ }
+ if (!iree_string_view_consume_prefix(str, IREE_SV("]"))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "missing array ]");
+ }
+ const char* end = str->data;
+ *out_value = iree_make_string_view(start, end - start);
+ return iree_ok_status();
+}
+
+// Consumes a value from |str| and returns it, updating |str| to point
+// immediately after it.
+// Assumes the input starts with `value` in the spec.
+static iree_status_t iree_json_consume_value(iree_string_view_t* str,
+ iree_string_view_t* out_value) {
+ *out_value = iree_string_view_empty();
+ *str = iree_string_view_trim(*str);
+ if (str->size == 0) return iree_ok_status();
+ switch (str->data[0]) {
+ case '"':
+ return iree_json_consume_string(str, out_value);
+ case '{':
+ return iree_json_consume_object(str, out_value);
+ case '[':
+ return iree_json_consume_array(str, out_value);
+ case 't':
+ return iree_json_consume_keyword(str, IREE_SV("true"), out_value);
+ case 'f':
+ return iree_json_consume_keyword(str, IREE_SV("false"), out_value);
+ case 'n':
+ return iree_json_consume_keyword(str, IREE_SV("null"), out_value);
+ default:
+ return iree_json_consume_number(str, out_value);
+ }
+}
+
+typedef iree_status_t(IREE_API_PTR* iree_json_object_enumerator_fn_t)(
+ void* user_data, iree_string_view_t key, iree_string_view_t value);
+
+// Enumerates all key-value pairs in the given object |str|.
+// Assumes that the input matches `object` in the spec (`{` and `}` at edges).
+// |enumerator| can return IREE_STATUS_CANCELLED to skip all further entries.
+static iree_status_t iree_json_enumerate_object(
+ iree_string_view_t str, iree_json_object_enumerator_fn_t enumerator,
+ void* user_data) {
+ if (!iree_string_view_consume_prefix(&str, IREE_SV("{"))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "missing object {");
+ }
+ str = iree_string_view_trim(str);
+ while (!iree_string_view_is_empty(str)) {
+ // Check for end of object.
+ if (iree_string_view_consume_prefix(&str, IREE_SV("}"))) break;
+ // Try to parse key string.
+ iree_string_view_t key = iree_string_view_empty();
+ IREE_RETURN_IF_ERROR(iree_json_consume_string(&str, &key));
+ str = iree_string_view_trim(str);
+ // Expect : separator.
+ if (!iree_string_view_consume_prefix(&str, IREE_SV(":"))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "missing object member separator");
+ }
+ // Scan ahead to get the value span.
+ iree_string_view_t value = iree_string_view_empty();
+ IREE_RETURN_IF_ERROR(iree_json_consume_value(&str, &value));
+ // Emit the key-value pair.
+ iree_status_t status = enumerator(user_data, key, value);
+ if (iree_status_is_cancelled(status)) {
+ iree_status_ignore(status);
+ break;
+ }
+ IREE_RETURN_IF_ERROR(status);
+ // If there's a comma then we expect another value.
+ if (!iree_string_view_consume_prefix(&str, IREE_SV(","))) break;
+ str = iree_string_view_trim(str);
+ }
+ return iree_ok_status();
+}
+
+typedef struct iree_json_lookup_object_value_state_t {
+ iree_string_view_t key;
+ iree_string_view_t* value;
+} iree_json_lookup_object_value_state_t;
+static iree_status_t iree_json_lookup_object_value_enumerator(
+ void* user_data, iree_string_view_t key, iree_string_view_t value) {
+ iree_json_lookup_object_value_state_t* state =
+ (iree_json_lookup_object_value_state_t*)user_data;
+ if (iree_string_view_equal(key, state->key)) {
+ *state->value = value;
+ return iree_status_from_code(IREE_STATUS_CANCELLED);
+ }
+ return iree_ok_status();
+}
+
+// Finds a directly nested |key| in the JSON |object_str| and returns its value.
+// Example:
+// {"foo": {"bar": true}, "taco": 51}
+// iree_json_lookup_object_value("foo") -> `{"bar": true}`
+// iree_json_lookup_object_value("taco") -> `51`
+static iree_status_t iree_json_lookup_object_value(
+ iree_string_view_t object_str, iree_string_view_t key,
+ iree_string_view_t* out_value) {
+ *out_value = iree_string_view_empty();
+ iree_json_lookup_object_value_state_t state = {
+ .key = key,
+ .value = out_value,
+ };
+ return iree_json_enumerate_object(
+ object_str, iree_json_lookup_object_value_enumerator, &state);
+}
+
+// Parses a `[begin, end]` JSON array.
+static bool iree_io_parse_json_data_offsets(iree_string_view_t data_offsets_str,
+ uint64_t* out_begin,
+ uint64_t* out_end) {
+ if (!iree_string_view_consume_prefix(&data_offsets_str, IREE_SV("[")) ||
+ !iree_string_view_consume_suffix(&data_offsets_str, IREE_SV("]"))) {
+ return false;
+ }
+ iree_string_view_t begin_str = iree_string_view_empty();
+ iree_string_view_t end_str = iree_string_view_empty();
+ if (iree_string_view_split(iree_string_view_trim(data_offsets_str), ',',
+ &begin_str, &end_str) == -1) {
+ return false;
+ }
+ return iree_string_view_atoi_uint64(iree_string_view_trim(begin_str),
+ out_begin) &&
+ iree_string_view_atoi_uint64(iree_string_view_trim(end_str), out_end);
+}
+
+typedef struct iree_io_enumerate_safetensors_entry_state_t {
+ iree_io_file_handle_t* file_handle;
+ uint64_t base_offset;
+ uint64_t data_size;
+ iree_io_parameter_index_t* index;
+} iree_io_enumerate_safetensors_entry_state_t;
+
+// Enumerates the outer safetensors header JSON object and emits entries to the
+// |index|. |key| will be the tensor name (what we call a parameter key) and
+// |value| will be the entry object we'll need to extract info from.
+//
+// Each entry in the dictionary looks something like this. Note that the order
+// of the fields is undefined and there may be some we ignore:
+// "TENSOR_NAME": {
+// "dtype": "F16",
+// "shape": [1, 16, 256],
+// "data_offsets": [BEGIN, END]
+// }, <-- optional (omitted at end)
+static iree_status_t iree_io_enumerate_safetensors_entries(
+ void* user_data, iree_string_view_t key, iree_string_view_t value) {
+ iree_io_enumerate_safetensors_entry_state_t* entry_state =
+ (iree_io_enumerate_safetensors_entry_state_t*)user_data;
+
+ // Skip the special "__metadata__" entry; its contents are unused for now.
+ if (iree_string_view_equal(key, IREE_SV("__metadata__"))) {
+ return iree_ok_status();
+ }
+
+ // Lookup the data offsets array.
+ iree_string_view_t data_offsets_str = iree_string_view_empty();
+ IREE_RETURN_IF_ERROR(iree_json_lookup_object_value(
+ value, IREE_SV("data_offsets"), &data_offsets_str));
+
+ // Extract the data offsets from the array and verify they are in range.
+ uint64_t begin = 0;
+ uint64_t end = 0;
+ if (!iree_io_parse_json_data_offsets(data_offsets_str, &begin, &end)) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "failed to parse entry data offsets `%.*s`",
+ (int)data_offsets_str.size, data_offsets_str.data);
+ }
+ if (begin > end || end > entry_state->data_size) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "entry has data offsets outside of the "
+ "available data (begin=%" PRIu64 ", end=%" PRIu64
+ ", available=%" PRIu64 ")",
+ begin, end, entry_state->data_size);
+ }
+
+ // Add entry to the index.
+ iree_io_parameter_index_entry_t entry = {
+ .key = key,
+ .metadata = iree_const_byte_span_empty(),
+ .file_handle = entry_state->file_handle,
+ .offset = entry_state->base_offset + begin,
+ .length = end - begin,
+ };
+ return iree_io_parameter_index_add(entry_state->index, &entry);
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parse_safetensors_index_from_memory(
+ iree_io_file_handle_t* file_handle, iree_const_byte_span_t file_contents,
+ iree_io_parameter_index_t* index) {
+ // Reads the header JSON blob out of the file contents and calculates the base
+ // offset that all data ranges are relative to. Verifies that the header and
+ // base offset are in range but each entry data range still needs to be
+ // verified.
+ uint64_t remaining_bytes = file_contents.data_length;
+ uint64_t header_length = 0;
+ if (remaining_bytes < sizeof(header_length)) {
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "insufficient capacity for safetensors header "
+ "length (need at least %" PRIhsz
+ " bytes but have %" PRIu64 ")",
+ sizeof(header_length), remaining_bytes);
+ }
+ header_length =
+ iree_unaligned_load_le_u64((const uint64_t*)file_contents.data);
+ remaining_bytes -= sizeof(header_length);
+ if (remaining_bytes < header_length) {
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "insufficient capacity for safetensors header "
+ "contents (declared as %" PRIu64
+ " but only %" PRIu64 " bytes available)",
+ header_length, remaining_bytes);
+ }
+ const iree_string_view_t header_json = iree_make_string_view(
+ (const char*)file_contents.data + sizeof(header_length),
+ (iree_host_size_t)header_length);
+ const uint64_t base_offset = sizeof(header_length) + header_length;
+ remaining_bytes -= header_length;
+
+ // Parses a safetensors |header_json| blob and emits entries to |index|.
+ // Each entry is bounds checked against the |data_size| of the file (bytes
+ // excluding the header, relative to |base_offset|).
+ iree_io_enumerate_safetensors_entry_state_t enumerate_state = {
+ .file_handle = file_handle,
+ .base_offset = base_offset,
+ .data_size = remaining_bytes,
+ .index = index,
+ };
+ return iree_json_enumerate_object(
+ header_json, iree_io_enumerate_safetensors_entries, &enumerate_state);
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parse_safetensors_index(
+ iree_io_file_handle_t* file_handle, iree_io_parameter_index_t* index) {
+ IREE_ASSERT_ARGUMENT(index);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Today we only support memory files.
+ // TODO(benvanik): support iree_io_stream_t wrapping for parsing the index.
+ if (iree_io_file_handle_type(file_handle) !=
+ IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "non-memory safetensors files not yet supported");
+ }
+ iree_byte_span_t host_allocation =
+ iree_io_file_handle_primitive(file_handle).value.host_allocation;
+
+ iree_status_t status = iree_io_parse_safetensors_index_from_memory(
+ file_handle,
+ iree_make_const_byte_span(host_allocation.data,
+ host_allocation.data_length),
+ index);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
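The index computation performed by iree_io_parse_safetensors_index_from_memory can be sketched outside of IREE as follows. This is a minimal Python sketch and not the C implementation: it assumes Python's standard json module in place of the hand-rolled scanner above, and the name parse_index is hypothetical. It mirrors the same steps: read the 8-byte little-endian header length, bounds check the header, then turn each entry's header-relative [BEGIN, END] pair into an absolute file offset and a length.

```python
import json
import struct


def parse_index(file_contents):
    """Returns {name: (absolute_offset, length)} for each tensor entry."""
    if len(file_contents) < 8:
        raise ValueError("insufficient capacity for header length")
    # The file starts with a little-endian uint64 holding the header size.
    (header_length,) = struct.unpack_from("<Q", file_contents, 0)
    remaining = len(file_contents) - 8
    if remaining < header_length:
        raise ValueError("insufficient capacity for header contents")
    header = json.loads(file_contents[8 : 8 + header_length])
    # All data offsets are relative to the end of the header, not the file.
    base_offset = 8 + header_length
    data_size = remaining - header_length
    index = {}
    for name, entry in header.items():
        if name == "__metadata__":
            continue  # Special entry that carries no tensor data.
        begin, end = entry["data_offsets"]
        if begin > end or end > data_size:
            raise ValueError(f"entry {name!r} has data offsets out of range")
        index[name] = (base_offset + begin, end - begin)
    return index
```

A truncated file whose last entry extends past the available data is rejected by the per-entry range check rather than by the header checks, which matches the split of responsibilities described in the C comments.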
diff --git a/runtime/src/iree/io/formats/safetensors/safetensors_format.h b/runtime/src/iree/io/formats/safetensors/safetensors_format.h
new file mode 100644
index 0000000..30a6638
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/safetensors_format.h
@@ -0,0 +1,41 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_IO_FORMATS_SAFETENSORS_SAFETENSORS_FORMAT_H_
+#define IREE_IO_FORMATS_SAFETENSORS_SAFETENSORS_FORMAT_H_
+
+#include "iree/base/api.h"
+#include "iree/io/file_handle.h"
+#include "iree/io/parameter_index.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Parses a .safetensors file and merges its contained resources into |index|.
+//
+// Documentation: https://github.com/huggingface/safetensors
+// This is a very basic archive file with some issues (no alignment, etc) but
+// at least doesn't require Python pickle decoding (just JSON). The major reason
+// to use this is if sourcing from a Hugging Face model that has its weights
+// already in the safetensors format.
+//
+// WARNING: this implementation has not been thoroughly tested or verified as
+// safe or correct. Use with caution only on trusted inputs. Tip: don't embed
+// other file formats within your file format and call it "safe" as it's only
+// going to be as safe as the implementations of the other file formats you
+// embed. In this case a full JSON parser is required and must be safe and we
+// don't take that dependency for a testing tool. Users wanting to productionize
+// this should implement their own safetensors parser or use the Rust one with
+// all the fun that entails.
+IREE_API_EXPORT iree_status_t iree_io_parse_safetensors_index(
+ iree_io_file_handle_t* file_handle, iree_io_parameter_index_t* index);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_IO_FORMATS_SAFETENSORS_SAFETENSORS_FORMAT_H_
diff --git a/runtime/src/iree/io/formats/safetensors/safetensors_format_test.cc b/runtime/src/iree/io/formats/safetensors/safetensors_format_test.cc
new file mode 100644
index 0000000..dc77f6d
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/safetensors_format_test.cc
@@ -0,0 +1,105 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/formats/safetensors/safetensors_format.h"
+
+#include "iree/base/internal/file_io.h"
+#include "iree/io/formats/safetensors/testdata/safetensors_files.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree {
+namespace {
+
+static iree_io_file_handle_t* OpenTestFile(const char* name) {
+ const struct iree_file_toc_t* file_toc = iree_io_safetensors_files_create();
+ for (size_t i = 0; i < iree_io_safetensors_files_size(); ++i) {
+ if (strcmp(file_toc[i].name, name) == 0) {
+ iree_io_file_handle_t* file_handle = NULL;
+ IREE_CHECK_OK(iree_io_file_handle_wrap_host_allocation(
+ IREE_IO_FILE_ACCESS_READ,
+ iree_make_byte_span((void*)file_toc[i].data, file_toc[i].size),
+ iree_io_file_handle_release_callback_null(), iree_allocator_system(),
+ &file_handle));
+ return file_handle;
+ }
+ }
+ IREE_CHECK_OK(iree_make_status(
+ IREE_STATUS_NOT_FOUND,
+ "test file `%s` not found embedded into test binary", name));
+ return NULL;
+}
+
+TEST(SafetensorsFormatTest, Empty) {
+ iree_io_parameter_index_t* index = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_create(iree_allocator_system(), &index));
+
+ iree_io_file_handle_t* file_handle = OpenTestFile("empty.safetensors");
+ IREE_ASSERT_OK(iree_io_parse_safetensors_index(file_handle, index));
+ iree_io_file_handle_release(file_handle);
+
+ iree_io_parameter_index_release(index);
+}
+
+TEST(SafetensorsFormatTest, SingleTensor) {
+ iree_io_parameter_index_t* index = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_create(iree_allocator_system(), &index));
+
+ iree_io_file_handle_t* file_handle = OpenTestFile("single.safetensors");
+ IREE_ASSERT_OK(iree_io_parse_safetensors_index(file_handle, index));
+ iree_io_file_handle_release(file_handle);
+
+ const iree_io_parameter_index_entry_t* entry0 = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_lookup(index, IREE_SV("tensor0"), &entry0));
+ EXPECT_TRUE(iree_string_view_equal(IREE_SV("tensor0"), entry0->key));
+ EXPECT_TRUE(iree_const_byte_span_is_empty(entry0->metadata));
+ EXPECT_EQ(entry0->offset, 72);
+ EXPECT_EQ(entry0->length, 16);
+
+ iree_io_parameter_index_release(index);
+}
+
+TEST(SafetensorsFormatTest, MultipleTensors) {
+ iree_io_parameter_index_t* index = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_create(iree_allocator_system(), &index));
+
+ iree_io_file_handle_t* file_handle = OpenTestFile("multiple.safetensors");
+ IREE_ASSERT_OK(iree_io_parse_safetensors_index(file_handle, index));
+ iree_io_file_handle_release(file_handle);
+
+ const iree_io_parameter_index_entry_t* entry0 = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_lookup(index, IREE_SV("tensor0"), &entry0));
+ EXPECT_TRUE(iree_string_view_equal(IREE_SV("tensor0"), entry0->key));
+ EXPECT_TRUE(iree_const_byte_span_is_empty(entry0->metadata));
+ EXPECT_EQ(entry0->offset, 200);
+ EXPECT_EQ(entry0->length, 16);
+
+ const iree_io_parameter_index_entry_t* entry1 = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_lookup(index, IREE_SV("tensor1"), &entry1));
+ EXPECT_TRUE(iree_string_view_equal(IREE_SV("tensor1"), entry1->key));
+ EXPECT_TRUE(iree_const_byte_span_is_empty(entry1->metadata));
+ EXPECT_EQ(entry1->offset, 216);
+ EXPECT_EQ(entry1->length, 8);
+
+ const iree_io_parameter_index_entry_t* entry2 = NULL;
+ IREE_ASSERT_OK(
+ iree_io_parameter_index_lookup(index, IREE_SV("tensor2"), &entry2));
+ EXPECT_TRUE(iree_string_view_equal(IREE_SV("tensor2"), entry2->key));
+ EXPECT_TRUE(iree_const_byte_span_is_empty(entry2->metadata));
+ EXPECT_EQ(entry2->offset, 224);
+ EXPECT_EQ(entry2->length, 48);
+
+ iree_io_parameter_index_release(index);
+}
+
+} // namespace
+} // namespace iree
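The expected offsets in the tests above (72 for single.safetensors and 200, 216, 224 for multiple.safetensors) follow from offset = 8 + header_length + BEGIN. The sketch below reproduces them under two assumptions that are inferred from the expected values rather than stated in this change: the writer emits compact JSON with tensors laid out in insertion order, and it pads the header with spaces to a multiple of 8 bytes. The function name expected_entries is hypothetical.

```python
import json


def expected_entries(tensor_sizes, dtype="F32"):
    """Maps {name: (shape, byte_length)} to {name: (file_offset, byte_length)}."""
    header = {}
    begin = 0
    for name, (shape, byte_length) in tensor_sizes.items():
        header[name] = {
            "dtype": dtype,
            "shape": list(shape),
            "data_offsets": [begin, begin + byte_length],
        }
        begin += byte_length
    header_json = json.dumps(header, separators=(",", ":"))
    # Assumption: the header is padded with spaces to an 8-byte multiple.
    header_length = len(header_json) + (-len(header_json)) % 8
    base_offset = 8 + header_length  # 8 bytes for the uint64 length prefix.
    return {
        name: (
            base_offset + entry["data_offsets"][0],
            entry["data_offsets"][1] - entry["data_offsets"][0],
        )
        for name, entry in header.items()
    }


# float32 tensors: 2x2 -> 16 bytes, 1x2 -> 8 bytes, 4x3 -> 48 bytes.
single = expected_entries({"tensor0": ((2, 2), 16)})
multiple = expected_entries({
    "tensor0": ((2, 2), 16),
    "tensor1": ((1, 2), 8),
    "tensor2": ((4, 3), 48),
})
print(single)    # {'tensor0': (72, 16)}
print(multiple)  # {'tensor0': (200, 16), 'tensor1': (216, 8), 'tensor2': (224, 48)}
```

For the single-tensor file the compact header is 63 bytes, padded to 64, so the data starts at 8 + 64 = 72. For the three-tensor file the header is 189 bytes, padded to 192, so the data starts at 200.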
diff --git a/runtime/src/iree/io/formats/safetensors/testdata/BUILD.bazel b/runtime/src/iree/io/formats/safetensors/testdata/BUILD.bazel
new file mode 100644
index 0000000..75e2fd6
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/testdata/BUILD.bazel
@@ -0,0 +1,27 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/embed_data:build_defs.bzl", "c_embed_data")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+c_embed_data(
+ name = "safetensors_files",
+ testonly = True,
+ srcs = [
+ "empty.safetensors",
+ "multiple.safetensors",
+ "single.safetensors",
+ ],
+ c_file_output = "safetensors_files.c",
+ flatten = True,
+ h_file_output = "safetensors_files.h",
+ identifier = "iree_io_safetensors_files",
+)
diff --git a/runtime/src/iree/io/formats/safetensors/testdata/CMakeLists.txt b/runtime/src/iree/io/formats/safetensors/testdata/CMakeLists.txt
new file mode 100644
index 0000000..4c3e0cf
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/testdata/CMakeLists.txt
@@ -0,0 +1,31 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# runtime/src/iree/io/formats/safetensors/testdata/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_c_embed_data(
+ NAME
+ safetensors_files
+ SRCS
+ "empty.safetensors"
+ "multiple.safetensors"
+ "single.safetensors"
+ C_FILE_OUTPUT
+ "safetensors_files.c"
+ H_FILE_OUTPUT
+ "safetensors_files.h"
+ IDENTIFIER
+ "iree_io_safetensors_files"
+ TESTONLY
+ FLATTEN
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/io/formats/safetensors/testdata/empty.safetensors b/runtime/src/iree/io/formats/safetensors/testdata/empty.safetensors
new file mode 100644
index 0000000..3969499
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/testdata/empty.safetensors
Binary files differ
diff --git a/runtime/src/iree/io/formats/safetensors/testdata/generate_safetensors_files.py b/runtime/src/iree/io/formats/safetensors/testdata/generate_safetensors_files.py
new file mode 100644
index 0000000..7d19a21
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/testdata/generate_safetensors_files.py
@@ -0,0 +1,31 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# https://huggingface.co/docs/safetensors/index
+#
+# To regenerate:
+# $ pip install torch safetensors
+# $ cd runtime/src/iree/io/formats/safetensors/testdata/
+# $ python generate_safetensors_files.py
+
+import torch
+from safetensors.torch import save_file
+
+# no tensors
+save_file({}, "empty.safetensors")
+
+# single tensor
+save_file({"tensor0": torch.zeros((2, 2))}, "single.safetensors")
+
+# multiple tensors
+save_file(
+ {
+ "tensor0": torch.zeros((2, 2)),
+ "tensor1": torch.zeros((1, 2)),
+ "tensor2": torch.zeros((4, 3)),
+ },
+ "multiple.safetensors",
+)
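After regenerating the test files it can be useful to check what the headers actually contain before updating the expected offsets in the C++ test. The following is a small stdlib-only sketch for that purpose; read_header and describe are hypothetical helper names and are not part of this change.

```python
import json
import struct


def read_header(path):
    """Returns (header_length, header_dict) for a .safetensors file."""
    with open(path, "rb") as f:
        # Little-endian uint64 header size followed by the JSON header.
        (header_length,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_length))
    return header_length, header


def describe(path):
    """Formats one line per tensor: name, dtype, shape, and data offsets."""
    _, header = read_header(path)
    lines = []
    for name, entry in header.items():
        if name == "__metadata__":
            continue  # Not a tensor entry.
        lines.append(
            f"{name} {entry['dtype']} {entry['shape']} {entry['data_offsets']}"
        )
    return lines
```

For example, printing each line of describe("single.safetensors") lists the one tensor in that file, and adding 8 plus the header length returned by read_header to the first data offset gives the absolute file offset the parser should report.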
diff --git a/runtime/src/iree/io/formats/safetensors/testdata/multiple.safetensors b/runtime/src/iree/io/formats/safetensors/testdata/multiple.safetensors
new file mode 100644
index 0000000..0d51b6b
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/testdata/multiple.safetensors
Binary files differ
diff --git a/runtime/src/iree/io/formats/safetensors/testdata/single.safetensors b/runtime/src/iree/io/formats/safetensors/testdata/single.safetensors
new file mode 100644
index 0000000..f28c492
--- /dev/null
+++ b/runtime/src/iree/io/formats/safetensors/testdata/single.safetensors
Binary files differ
diff --git a/runtime/src/iree/io/parameter_index.c b/runtime/src/iree/io/parameter_index.c
new file mode 100644
index 0000000..2b38ed2
--- /dev/null
+++ b/runtime/src/iree/io/parameter_index.c
@@ -0,0 +1,230 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/parameter_index.h"
+
+#include "iree/base/internal/atomics.h"
+#include "iree/base/internal/synchronization.h"
+
+struct iree_io_parameter_index_t {
+ iree_atomic_ref_count_t ref_count;
+ iree_allocator_t host_allocator;
+
+ // Guards mutation of the entries list.
+ // NOTE: this does not guard the entries themselves as we assume they are
+ // immutable (today).
+ iree_slim_mutex_t mutex;
+
+ // Total capacity of the entries list in elements.
+ iree_host_size_t entry_capacity;
+ // Currently used entry count in elements.
+ iree_host_size_t entry_count;
+ // Dense list of entries in the index. Grows as needed.
+ iree_io_parameter_index_entry_t** entries;
+};
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_create(
+ iree_allocator_t host_allocator, iree_io_parameter_index_t** out_index) {
+ IREE_ASSERT_ARGUMENT(out_index);
+ *out_index = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_io_parameter_index_t* index = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_allocator_malloc(host_allocator, sizeof(*index), (void**)&index));
+ iree_atomic_ref_count_init(&index->ref_count);
+ index->host_allocator = host_allocator;
+
+ iree_slim_mutex_initialize(&index->mutex);
+
+ // Grown on first use. We could allocate a bit of inline storage or take an
+ // optional initial capacity for callers that know.
+ index->entry_capacity = 0;
+ index->entry_count = 0;
+ index->entries = NULL;
+
+ *out_index = index;
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static void iree_io_parameter_index_destroy(iree_io_parameter_index_t* index) {
+ IREE_ASSERT_ARGUMENT(index);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_allocator_t host_allocator = index->host_allocator;
+
+ for (iree_host_size_t i = 0; i < index->entry_count; ++i) {
+ iree_io_parameter_index_entry_t* entry = index->entries[i];
+ iree_io_file_handle_release(entry->file_handle);
+ iree_allocator_free(host_allocator, entry);
+ }
+ if (index->entries) {
+ iree_allocator_free(host_allocator, index->entries);
+ }
+
+ iree_slim_mutex_deinitialize(&index->mutex);
+
+ iree_allocator_free(host_allocator, index);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+IREE_API_EXPORT void iree_io_parameter_index_retain(
+ iree_io_parameter_index_t* index) {
+ if (IREE_LIKELY(index)) {
+ iree_atomic_ref_count_inc(&index->ref_count);
+ }
+}
+
+IREE_API_EXPORT void iree_io_parameter_index_release(
+ iree_io_parameter_index_t* index) {
+ if (IREE_LIKELY(index) && iree_atomic_ref_count_dec(&index->ref_count) == 1) {
+ iree_io_parameter_index_destroy(index);
+ }
+}
+
+IREE_API_EXPORT iree_host_size_t
+iree_io_parameter_index_count(iree_io_parameter_index_t* index) {
+ IREE_ASSERT_ARGUMENT(index);
+ iree_slim_mutex_lock(&index->mutex);
+ iree_host_size_t count = index->entry_count;
+ iree_slim_mutex_unlock(&index->mutex);
+ return count;
+}
+
+static iree_status_t iree_io_parameter_index_reserve_unsafe(
+ iree_io_parameter_index_t* index, iree_host_size_t new_capacity) {
+ IREE_ASSERT_ARGUMENT(index);
+ if (new_capacity <= index->entry_capacity) return iree_ok_status();
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, new_capacity);
+
+ iree_io_parameter_index_entry_t** new_entries = index->entries;
+ iree_status_t status = iree_allocator_realloc(
+ index->host_allocator, new_capacity * sizeof(index->entries[0]),
+ (void**)&new_entries);
+ if (iree_status_is_ok(status)) {
+ index->entry_capacity = new_capacity;
+ index->entries = new_entries;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_reserve(
+ iree_io_parameter_index_t* index, iree_host_size_t new_capacity) {
+ IREE_ASSERT_ARGUMENT(index);
+ iree_slim_mutex_lock(&index->mutex);
+ iree_status_t status =
+ iree_io_parameter_index_reserve_unsafe(index, new_capacity);
+ iree_slim_mutex_unlock(&index->mutex);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t
+iree_io_parameter_index_add(iree_io_parameter_index_t* index,
+ const iree_io_parameter_index_entry_t* entry) {
+ IREE_ASSERT_ARGUMENT(index);
+ IREE_ASSERT_ARGUMENT(entry);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, entry->key.data, entry->key.size);
+ iree_slim_mutex_lock(&index->mutex);
+
+ // Grow the index if needed (double each time after some initial minimum).
+ iree_status_t status = iree_ok_status();
+ if (index->entry_count == index->entry_capacity) {
+ status = iree_io_parameter_index_reserve_unsafe(
+ index, iree_max(16, index->entry_capacity * 2));
+ }
+
+ // Clone the entry memory. We allocate it as a single slab and stash the
+ // pointers for easier access by callers. Entries themselves are never
+ // reallocated so the pointers are safe to embed.
+ iree_io_parameter_index_entry_t* cloned_entry = NULL;
+ if (iree_status_is_ok(status)) {
+ iree_host_size_t total_size =
+ sizeof(*cloned_entry) + entry->key.size + entry->metadata.data_length;
+ status = iree_allocator_malloc(index->host_allocator, total_size,
+ (void**)&cloned_entry);
+ }
+ if (iree_status_is_ok(status)) {
+ cloned_entry->key = iree_make_string_view(
+ (char*)cloned_entry + sizeof(*cloned_entry), entry->key.size);
+ cloned_entry->metadata =
+ iree_const_byte_span_is_empty(entry->metadata)
+ ? iree_const_byte_span_empty()
+ : iree_make_const_byte_span(
+ (uint8_t*)cloned_entry->key.data + cloned_entry->key.size,
+ entry->metadata.data_length);
+ cloned_entry->file_handle = entry->file_handle;
+ iree_io_file_handle_retain(cloned_entry->file_handle);
+ cloned_entry->offset = entry->offset;
+ cloned_entry->length = entry->length;
+ memcpy((void*)cloned_entry->key.data, entry->key.data, entry->key.size);
+ memcpy((void*)cloned_entry->metadata.data, entry->metadata.data,
+ entry->metadata.data_length);
+
+ // Append the entry to the file index.
+ index->entries[index->entry_count++] = cloned_entry;
+ }
+
+ iree_slim_mutex_unlock(&index->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_get(
+ iree_io_parameter_index_t* index, iree_host_size_t i,
+ const iree_io_parameter_index_entry_t** out_entry) {
+ IREE_ASSERT_ARGUMENT(index);
+ IREE_ASSERT_ARGUMENT(out_entry);
+ *out_entry = NULL;
+ iree_slim_mutex_lock(&index->mutex);
+
+ iree_status_t status = iree_ok_status();
+ if (i < index->entry_count) {
+ *out_entry = index->entries[i];
+ } else {
+ status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "entry %" PRIhsz " out of range (have %" PRIhsz
+ " entries in the index)",
+ i, index->entry_count);
+ }
+
+ iree_slim_mutex_unlock(&index->mutex);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_lookup(
+ iree_io_parameter_index_t* index, iree_string_view_t key,
+ const iree_io_parameter_index_entry_t** out_entry) {
+ IREE_ASSERT_ARGUMENT(index);
+ IREE_ASSERT_ARGUMENT(out_entry);
+ *out_entry = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, key.data, key.size);
+ iree_slim_mutex_lock(&index->mutex);
+
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < index->entry_count; ++i) {
+ const iree_io_parameter_index_entry_t* entry = index->entries[i];
+ if (iree_string_view_equal(key, entry->key)) {
+ *out_entry = entry;
+ break;
+ }
+ }
+ if (*out_entry == NULL) {
+ status = iree_make_status(IREE_STATUS_NOT_FOUND,
+ "no parameter found in index with key '%.*s'",
+ (int)key.size, key.data);
+ }
+
+ iree_slim_mutex_unlock(&index->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
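
Review note on the entry cloning in iree_io_parameter_index_add: the single-slab layout can be illustrated outside of IREE with a minimal sketch. Everything named `sketch_*` below is hypothetical and only mirrors the shape of the code above (header struct, then key bytes, then metadata bytes in one allocation); it is not the IREE API and it skips size-overflow checks for brevity.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

// Hypothetical stand-in for iree_io_parameter_index_entry_t (not the IREE type).
typedef struct sketch_entry_t {
  const char* key;          // points into the trailing slab storage
  size_t key_size;
  const uint8_t* metadata;  // points into the slab, or NULL when empty
  size_t metadata_size;
  uint64_t offset;
  uint64_t length;
} sketch_entry_t;

// Clones |entry| into one heap allocation laid out as:
//   [sketch_entry_t][key bytes][metadata bytes]
// Returns NULL on allocation failure. The clone is freed with a single free().
static sketch_entry_t* sketch_entry_clone(const sketch_entry_t* entry) {
  assert(entry);
  size_t total_size =
      sizeof(sketch_entry_t) + entry->key_size + entry->metadata_size;
  sketch_entry_t* clone = (sketch_entry_t*)malloc(total_size);
  if (!clone) return NULL;

  // Carve the trailing storage into the key and metadata regions.
  char* key_storage = (char*)clone + sizeof(*clone);
  uint8_t* metadata_storage = (uint8_t*)key_storage + entry->key_size;
  if (entry->key_size > 0) {
    memcpy(key_storage, entry->key, entry->key_size);
  }
  if (entry->metadata_size > 0) {
    memcpy(metadata_storage, entry->metadata, entry->metadata_size);
  }

  // Embed pointers to the copies; they stay valid as long as the slab does.
  clone->key = key_storage;
  clone->key_size = entry->key_size;
  clone->metadata = entry->metadata_size > 0 ? metadata_storage : NULL;
  clone->metadata_size = entry->metadata_size;
  clone->offset = entry->offset;
  clone->length = entry->length;
  return clone;
}
```

Because the key and metadata live in the same allocation as the struct, the caller's buffers can be released right after the clone, which matches the contract the header documents for iree_io_parameter_index_add.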
diff --git a/runtime/src/iree/io/parameter_index.h b/runtime/src/iree/io/parameter_index.h
new file mode 100644
index 0000000..24b467a
--- /dev/null
+++ b/runtime/src/iree/io/parameter_index.h
@@ -0,0 +1,88 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_IO_PARAMETER_INDEX_H_
+#define IREE_IO_PARAMETER_INDEX_H_
+
+#include "iree/base/api.h"
+#include "iree/io/file_handle.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// An entry in an in-memory parameter index.
+typedef struct iree_io_parameter_index_entry_t {
+ // Key used to reference this parameter.
+ iree_string_view_t key;
+ // Optional metadata.
+ iree_const_byte_span_t metadata;
+ // File handle backing this entry, retained.
+ iree_io_file_handle_t* file_handle;
+ // Offset of the entry in bytes relative to the base file offset.
+ uint64_t offset;
+ // Length of the entry in bytes.
+ uint64_t length;
+} iree_io_parameter_index_entry_t;
+
+// An in-memory file index mapping keys to byte ranges in referenced files.
+// A single index may contain entries from multiple files. Each parameter is
+// backed by a contiguous range in a single file.
+//
+// Thread-safe due to insert-only behavior. If we ever wanted to allow removal
+// from the index we would need to change callers to hold a mutex or design
+// a callback-based API to ensure that entries were live for as long as the
+// callers were using them.
+typedef struct iree_io_parameter_index_t iree_io_parameter_index_t;
+
+// Creates an empty file index.
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_create(
+ iree_allocator_t host_allocator, iree_io_parameter_index_t** out_index);
+
+// Retains the given |index| for the caller.
+IREE_API_EXPORT void iree_io_parameter_index_retain(
+ iree_io_parameter_index_t* index);
+
+// Releases the given |index| from the caller.
+IREE_API_EXPORT void iree_io_parameter_index_release(
+ iree_io_parameter_index_t* index);
+
+// Returns the number of entries in the index at the time the method is called.
+// New entries may be added by other threads between when the value is queried
+// and when the caller enumerates entries. Use this only for debugging.
+IREE_API_EXPORT iree_host_size_t
+iree_io_parameter_index_count(iree_io_parameter_index_t* index);
+
+// Reserves storage for at least |new_capacity| entries in the index.
+// Ignored if storage capacity is already sufficient.
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_reserve(
+ iree_io_parameter_index_t* index, iree_host_size_t new_capacity);
+
+// Adds a new entry to the file index.
+// The string key and optional metadata will be copied into the index and
+// need not remain valid after the call returns. Referenced file handles will
+// be retained for the lifetime of the index.
+IREE_API_EXPORT iree_status_t
+iree_io_parameter_index_add(iree_io_parameter_index_t* index,
+ const iree_io_parameter_index_entry_t* entry);
+
+// Returns the entry at index |i| in [0, iree_io_parameter_index_count).
+// The returned |out_entry| is valid for the lifetime of the index.
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_get(
+ iree_io_parameter_index_t* index, iree_host_size_t i,
+ const iree_io_parameter_index_entry_t** out_entry);
+
+// Performs a file entry lookup of |key| in the index and returns it.
+// The returned |out_entry| is valid for the lifetime of the index.
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_lookup(
+ iree_io_parameter_index_t* index, iree_string_view_t key,
+ const iree_io_parameter_index_entry_t** out_entry);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_IO_PARAMETER_INDEX_H_
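
Review note on the lifetime guarantee documented above ("The returned |out_entry| is valid for the lifetime of the index"): this follows from the index being insert-only and storing a growable table of pointers to individually allocated entries. The sketch below is a hypothetical, single-threaded model of that design (no mutex, fixed-size keys, names prefixed `sketch_` are not part of IREE). It shows that growing the table by doubling does not move entries that were already handed out.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#define SKETCH_KEY_CAPACITY 32

// Hypothetical entry; the real index stores richer entries.
typedef struct sketch_item_t {
  char key[SKETCH_KEY_CAPACITY];
} sketch_item_t;

// Insert-only index: a growable table of pointers to stable entries.
typedef struct sketch_index_t {
  sketch_item_t** items;
  size_t count;
  size_t capacity;
} sketch_index_t;

// Appends a copy of |key|. Returns 1 on success and 0 on failure.
static int sketch_index_add(sketch_index_t* index, const char* key) {
  assert(index && key);
  if (strlen(key) >= SKETCH_KEY_CAPACITY) return 0;
  if (index->count == index->capacity) {
    // Double each time after an initial minimum of 16 slots.
    size_t new_capacity = index->capacity ? index->capacity * 2 : 16;
    sketch_item_t** new_items = (sketch_item_t**)realloc(
        index->items, new_capacity * sizeof(*new_items));
    if (!new_items) return 0;
    index->items = new_items;
    index->capacity = new_capacity;
  }
  sketch_item_t* item = (sketch_item_t*)malloc(sizeof(*item));
  if (!item) return 0;
  strcpy(item->key, key);
  index->items[index->count++] = item;
  return 1;
}

// Linear scan; returns NULL when |key| is not present.
static const sketch_item_t* sketch_index_lookup(const sketch_index_t* index,
                                                const char* key) {
  for (size_t i = 0; i < index->count; ++i) {
    if (strcmp(index->items[i]->key, key) == 0) return index->items[i];
  }
  return NULL;
}

// Frees every entry and the pointer table.
static void sketch_index_free(sketch_index_t* index) {
  for (size_t i = 0; i < index->count; ++i) free(index->items[i]);
  free(index->items);
  memset(index, 0, sizeof(*index));
}
```

Only the pointer table is reallocated on growth, so a pointer returned by an earlier lookup remains usable after later adds. Supporting removal would break this property, which is the caveat the header comment calls out.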
diff --git a/runtime/src/iree/io/parameter_index_provider.c b/runtime/src/iree/io/parameter_index_provider.c
new file mode 100644
index 0000000..6f861bc
--- /dev/null
+++ b/runtime/src/iree/io/parameter_index_provider.c
@@ -0,0 +1,637 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/parameter_index_provider.h"
+
+#include "iree/hal/utils/file_cache.h"
+
+// Limit concurrent operations to avoid blowing the stack. This is arbitrary and
+// if we wanted to support more we could switch to using heap allocations or
+// a growable stack scratchpad.
+#define IREE_IO_PARAMETER_INDEX_PROVIDER_CONCURRENT_OPERATION_LIMIT 128
+
+typedef struct iree_io_parameter_index_provider_t {
+ iree_io_parameter_provider_t base;
+ iree_allocator_t host_allocator;
+ iree_host_size_t max_concurrent_operations;
+ iree_string_view_t scope;
+ iree_io_parameter_index_t* index;
+ iree_hal_file_cache_t* file_cache;
+} iree_io_parameter_index_provider_t;
+
+static const iree_io_parameter_provider_vtable_t
+ iree_io_parameter_index_provider_vtable;
+
+static iree_io_parameter_index_provider_t*
+iree_io_parameter_index_provider_cast(
+ iree_io_parameter_provider_t* IREE_RESTRICT base_provider) {
+ return (iree_io_parameter_index_provider_t*)base_provider;
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_provider_create(
+ iree_string_view_t scope, iree_io_parameter_index_t* index,
+ iree_host_size_t max_concurrent_operations, iree_allocator_t host_allocator,
+ iree_io_parameter_provider_t** out_provider) {
+ IREE_ASSERT_ARGUMENT(index);
+ IREE_ASSERT_ARGUMENT(out_provider);
+ *out_provider = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, scope.data, scope.size);
+
+ max_concurrent_operations =
+ iree_min(max_concurrent_operations,
+ IREE_IO_PARAMETER_INDEX_PROVIDER_CONCURRENT_OPERATION_LIMIT);
+
+ iree_io_parameter_index_provider_t* provider = NULL;
+ iree_host_size_t total_size = sizeof(*provider) + scope.size;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, total_size, (void**)&provider));
+ iree_atomic_ref_count_init(&provider->base.ref_count);
+ provider->base.vtable = &iree_io_parameter_index_provider_vtable;
+ provider->host_allocator = host_allocator;
+ provider->max_concurrent_operations = max_concurrent_operations;
+
+ provider->scope = iree_make_string_view(
+ (const char*)provider + sizeof(*provider), scope.size);
+ memcpy((void*)provider->scope.data, scope.data, scope.size);
+
+ provider->index = index;
+ iree_io_parameter_index_retain(index);
+
+ iree_status_t status =
+ iree_hal_file_cache_create(host_allocator, &provider->file_cache);
+
+ if (iree_status_is_ok(status)) {
+ *out_provider = (iree_io_parameter_provider_t*)provider;
+ } else {
+ iree_io_parameter_provider_release((iree_io_parameter_provider_t*)provider);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_io_parameter_index_provider_destroy(
+ iree_io_parameter_provider_t* IREE_RESTRICT base_provider) {
+ iree_io_parameter_index_provider_t* provider =
+ iree_io_parameter_index_provider_cast(base_provider);
+ iree_allocator_t host_allocator = provider->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_file_cache_release(provider->file_cache);
+ iree_io_parameter_index_release(provider->index);
+
+ iree_allocator_free(host_allocator, provider);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_io_parameter_index_provider_notify(
+ iree_io_parameter_provider_t* base_provider,
+ iree_io_parameter_provider_signal_t signal) {
+ iree_io_parameter_index_provider_t* provider =
+ iree_io_parameter_index_provider_cast(base_provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ switch (signal) {
+ case IREE_IO_PARAMETER_PROVIDER_SIGNAL_SUSPEND:
+ case IREE_IO_PARAMETER_PROVIDER_SIGNAL_LOW_MEMORY:
+ iree_hal_file_cache_trim(provider->file_cache);
+ break;
+ default:
+ break;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static bool iree_io_parameter_index_provider_query_support(
+ iree_io_parameter_provider_t* base_provider, iree_string_view_t scope) {
+ iree_io_parameter_index_provider_t* provider =
+ iree_io_parameter_index_provider_cast(base_provider);
+ return iree_string_view_equal(scope, provider->scope);
+}
+
+// Resolves a parameter with |key| for use on the given |device|.
+// Returns the entry containing the parameter metadata and a retained
+// HAL file that stores it (must be released by the caller).
+static iree_status_t iree_io_parameter_index_provider_resolve(
+ iree_io_parameter_index_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity, iree_string_view_t scope,
+ iree_string_view_t key, iree_hal_memory_access_t access,
+ const iree_io_parameter_index_entry_t** out_entry,
+ iree_hal_file_t** out_file) {
+ IREE_ASSERT_ARGUMENT(out_entry);
+ IREE_ASSERT_ARGUMENT(out_file);
+ *out_entry = NULL;
+ *out_file = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Look up the parameter in the index.
+ const iree_io_parameter_index_entry_t* entry = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_index_lookup(provider->index, key, &entry));
+
+ // Get (or import) the HAL file backing the entry.
+ // NOTE: file is retained!
+ iree_hal_file_t* file = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_file_cache_lookup(
+ provider->file_cache, device, queue_affinity, access,
+ entry->file_handle, IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE, &file));
+
+ *out_entry = entry;
+ *out_file = file;
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Validates that the range specified by [offset, offset+length) is in bounds.
+static iree_status_t iree_io_validate_parameter_range(
+ iree_hal_memory_access_t required_access,
+ const iree_io_parameter_index_entry_t* entry, uint64_t offset,
+ uint64_t length) {
+ iree_hal_memory_access_t allowed_access = IREE_HAL_MEMORY_ACCESS_NONE;
+ if (iree_all_bits_set(iree_io_file_handle_access(entry->file_handle),
+ IREE_IO_FILE_ACCESS_READ)) {
+ allowed_access |= IREE_HAL_MEMORY_ACCESS_READ;
+ }
+ if (iree_all_bits_set(iree_io_file_handle_access(entry->file_handle),
+ IREE_IO_FILE_ACCESS_WRITE)) {
+ allowed_access |=
+ IREE_HAL_MEMORY_ACCESS_WRITE | IREE_HAL_MEMORY_ACCESS_DISCARD;
+ }
+ if (!iree_all_bits_set(allowed_access, required_access)) {
+ // Report the allowed and required access bits when status messages are
+ // compiled in; otherwise return only the status code.
+#if IREE_STATUS_MODE
+ iree_bitfield_string_temp_t temp0, temp1;
+ iree_string_view_t allowed_memory_access_str =
+ iree_hal_memory_access_format(allowed_access, &temp0);
+ iree_string_view_t required_memory_access_str =
+ iree_hal_memory_access_format(required_access, &temp1);
+ return iree_make_status(
+ IREE_STATUS_PERMISSION_DENIED,
+ "parameter storage does not support the requested access "
+ "type; parameter allows %.*s, operation requires %.*s",
+ (int)allowed_memory_access_str.size, allowed_memory_access_str.data,
+ (int)required_memory_access_str.size, required_memory_access_str.data);
+#else
+ return iree_status_from_code(IREE_STATUS_PERMISSION_DENIED);
+#endif // IREE_STATUS_MODE
+ }
+
+ if (offset > entry->length || length > entry->length - offset) {
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "parameter range out of bounds (offset=%" PRIu64
+ ", length=%" PRIu64 ", size=%" PRIu64 ")",
+ offset, length, entry->length);
+ }
+
+ return iree_ok_status();
+}
+
+static void iree_io_file_handle_buffer_release(void* user_data,
+ iree_hal_buffer_t* buffer) {
+ iree_io_file_handle_release((iree_io_file_handle_t*)user_data);
+}
+
+static iree_status_t iree_io_parameter_index_provider_load(
+ iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_string_view_t source_key,
+ uint64_t source_offset, iree_hal_buffer_params_t target_params,
+ iree_device_size_t length,
+ iree_hal_buffer_t** IREE_RESTRICT out_target_buffer) {
+ iree_io_parameter_index_provider_t* provider =
+ iree_io_parameter_index_provider_cast(base_provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Look up the parameter metadata and get its backing file.
+ const iree_io_parameter_index_entry_t* source_entry = NULL;
+ iree_hal_file_t* source_file = NULL; // retained
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_index_provider_resolve(
+ provider, device, queue_affinity, source_scope, source_key,
+ target_params.access, &source_entry, &source_file));
+
+ // Validate the parameter range is in-bounds.
+ iree_status_t status = iree_io_validate_parameter_range(
+ target_params.access, source_entry, source_offset, length);
+
+ // Try first to reuse the file backing store directly as a buffer. This only
+ // works with specific file types and with specific target usage. The most
+ // common cases for this are when using parameters as staging sources (so host
+ // memory is ok) or on unified memory systems (where host memory is device
+ // memory) and the file was originally mapped. We could extend the conditions
+ // in which we use this with some better file handle helpers that allow us to
+ // map files that we already have open via other mechanisms (FILE, fd, etc).
+ iree_hal_buffer_t* target_buffer = NULL;
+ if (iree_status_is_ok(status) &&
+ iree_io_file_handle_type(source_entry->file_handle) ==
+ IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
+ iree_byte_span_t host_allocation =
+ iree_io_file_handle_primitive(source_entry->file_handle)
+ .value.host_allocation;
+ iree_hal_external_buffer_t external_buffer = {
+ .type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION,
+ .flags = IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE,
+ .size = host_allocation.data_length,
+ .handle =
+ {
+ .host_allocation =
+ {
+ .ptr = host_allocation.data,
+ },
+ },
+ };
+ iree_hal_buffer_release_callback_t release_callback = {
+ .fn = iree_io_file_handle_buffer_release,
+ .user_data = source_entry->file_handle,
+ };
+ iree_io_file_handle_retain(source_entry->file_handle);
+ iree_status_t import_status = iree_hal_allocator_import_buffer(
+ iree_hal_device_allocator(device), target_params, &external_buffer,
+ release_callback, &target_buffer);
+ if (iree_status_is_ok(import_status)) {
+ // Import succeeded - issue a barrier to preserve the async timeline.
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "import succeeded");
+ status = iree_hal_device_queue_barrier(
+ device, queue_affinity, wait_semaphore_list, signal_semaphore_list);
+ } else {
+ // Failed to import - that's ok as we'll just do the full allocate + read.
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "import failed");
+ import_status = iree_status_ignore(import_status);
+ iree_io_file_handle_release(source_entry->file_handle);
+ }
+ }
+
+ if (!target_buffer) {
+ // Temporary semaphore for chaining the allocation and read.
+ iree_hal_semaphore_t* temporary_semaphore = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_semaphore_create(device, 0ull, &temporary_semaphore);
+ }
+ uint64_t temporary_semaphore_value = 1ull;
+ const iree_hal_semaphore_list_t alloca_semaphore_list = {
+ .count = 1,
+ .semaphores = &temporary_semaphore,
+ .payload_values = &temporary_semaphore_value,
+ };
+
+ // Allocate the target buffer.
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_queue_alloca(
+ device, queue_affinity, wait_semaphore_list, alloca_semaphore_list,
+ IREE_HAL_ALLOCATOR_POOL_DEFAULT, target_params, length,
+ &target_buffer);
+ }
+
+ // Queue the file read into the target buffer.
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_queue_read(
+ device, queue_affinity, alloca_semaphore_list, signal_semaphore_list,
+ source_file, source_entry->offset + source_offset, target_buffer, 0,
+ length, 0);
+ }
+
+ iree_hal_semaphore_release(temporary_semaphore);
+ }
+
+ iree_hal_file_release(source_file);
+ if (iree_status_is_ok(status)) {
+ IREE_ASSERT_NE(target_buffer, NULL);
+ *out_target_buffer = target_buffer;
+ } else {
+ iree_hal_buffer_release(target_buffer);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static iree_status_t iree_io_parameter_index_provider_read(
+ iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_string_view_t source_key,
+ uint64_t source_offset, iree_hal_buffer_t* target_buffer,
+ iree_device_size_t target_offset, iree_device_size_t length) {
+ iree_io_parameter_index_provider_t* provider =
+ iree_io_parameter_index_provider_cast(base_provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Look up the parameter metadata and get its backing file.
+ const iree_io_parameter_index_entry_t* source_entry = NULL;
+ iree_hal_file_t* source_file = NULL; // retained
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_index_provider_resolve(
+ provider, device, queue_affinity, source_scope, source_key,
+ IREE_HAL_MEMORY_ACCESS_READ, &source_entry, &source_file));
+
+ // Validate the parameter range is in-bounds.
+ iree_status_t status = iree_io_validate_parameter_range(
+ IREE_HAL_MEMORY_ACCESS_READ, source_entry, source_offset, length);
+
+ // Queue the file read into the target buffer.
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_queue_read(
+ device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+ source_file, source_entry->offset + source_offset, target_buffer,
+ target_offset, length, 0);
+ }
+
+ iree_hal_file_release(source_file);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static iree_status_t iree_io_parameter_index_provider_write(
+ iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+ iree_string_view_t target_scope, iree_string_view_t target_key,
+ uint64_t target_offset, iree_device_size_t length) {
+ iree_io_parameter_index_provider_t* provider =
+ iree_io_parameter_index_provider_cast(base_provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Look up the parameter metadata and get its backing file.
+ const iree_io_parameter_index_entry_t* target_entry = NULL;
+ iree_hal_file_t* target_file = NULL; // retained
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_index_provider_resolve(
+ provider, device, queue_affinity, target_scope, target_key,
+ IREE_HAL_MEMORY_ACCESS_WRITE, &target_entry, &target_file));
+
+ // Validate the parameter range is in-bounds.
+ iree_status_t status = iree_io_validate_parameter_range(
+ IREE_HAL_MEMORY_ACCESS_WRITE, target_entry, target_offset, length);
+
+ // Queue the file write from the source buffer.
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_queue_write(
+ device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+ source_buffer, source_offset, target_file,
+ target_entry->offset + target_offset, length, 0);
+ }
+
+ iree_hal_file_release(target_file);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+typedef iree_status_t(IREE_API_PTR* iree_io_parameter_index_file_operation_t)(
+ iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_file_t* file, uint64_t file_offset, iree_hal_buffer_t* buffer,
+ iree_device_size_t buffer_offset, iree_device_size_t length,
+ uint32_t flags);
+
+// Returns the index of the smallest value in |values|.
+// Linear scan as the number of values is expected to be small.
+static iree_host_size_t iree_io_select_timeline_bucket(iree_host_size_t count,
+ const uint64_t* values) {
+ IREE_ASSERT_GT(count, 0);
+ uint64_t smallest_value = values[0];
+ iree_host_size_t smallest_index = 0;
+ for (iree_host_size_t i = 1; i < count; ++i) {
+ if (values[i] < smallest_value) {
+ smallest_value = values[i];
+ smallest_index = i;
+ }
+ }
+ return smallest_index;
+}
+
+static iree_status_t iree_io_parameter_index_provider_gather_scatter(
+ iree_io_parameter_index_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t scope, iree_hal_buffer_t* buffer, iree_host_size_t count,
+ iree_io_parameter_enumerator_t enumerator, iree_hal_memory_access_t access,
+ iree_io_parameter_index_file_operation_t operation) {
+ // Decide how many operations we'll keep in-flight at a time. Each concurrent
+ // stream of operations requires its own semaphore.
+ //
+ // NOTE: we expect count == 0 and count == 1 to have been handled by callers
+ // and assume that if we've hit this method we're doing something significant
+ // and it's worth it to do all this.
+ const iree_host_size_t concurrency =
+ iree_min(count, provider->max_concurrent_operations);
+
+ // Distribute operations over the timelines based on how much I/O each has
+ // queued. Naive placement risks pathologically bad latency when large and
+ // small operations are interleaved: all large operations may end up
+ // serialized on one timeline and all small ones on another.
+ // We distribute by tracking the total bytes outstanding on each timeline and
+ // always placing the next operation on the one with the fewest. This assumes
+ // that all I/O happens at roughly the same speed but if parameters come from
+ // different files on different devices that may not be the case. It's better
+ // than doing nothing, though.
+ uint64_t* timeline_bytes_outstanding =
+ (uint64_t*)iree_alloca(concurrency * sizeof(uint64_t));
+ memset(timeline_bytes_outstanding, 0,
+ concurrency * sizeof(*timeline_bytes_outstanding));
+
+ // Allocate one semaphore per concurrent timeline.
+ IREE_TRACE_ZONE_BEGIN_NAMED(
+ z_init, "iree_io_parameter_index_provider_semaphore_pool_initialize");
+ iree_hal_semaphore_t** timeline_semaphores =
+ (iree_hal_semaphore_t**)iree_alloca(concurrency *
+ sizeof(iree_hal_semaphore_t*));
+ uint64_t* timeline_values =
+ (uint64_t*)iree_alloca(concurrency * sizeof(uint64_t));
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < concurrency; ++i) {
+ timeline_values[i] = 0ull;
+ timeline_semaphores[i] = NULL;  // released unconditionally below
+ if (!iree_status_is_ok(status)) continue;
+ status = iree_hal_semaphore_create(device, 0ull, &timeline_semaphores[i]);
+ }
+ IREE_TRACE_ZONE_END(z_init);
+
+ if (iree_status_is_ok(status)) {
+ for (iree_host_size_t i = 0; i < count; ++i) {
+ IREE_TRACE_ZONE_BEGIN(z_entry);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, i);
+
+ // Fetch the next parameter to copy and its buffer range.
+ iree_string_view_t key;
+ iree_io_parameter_span_t span;
+ status = enumerator.fn(enumerator.user_data, i, &key, &span);
+
+ // Look up the parameter metadata and get its backing file.
+ const iree_io_parameter_index_entry_t* entry = NULL;
+ iree_hal_file_t* file = NULL; // retained
+ if (iree_status_is_ok(status)) {
+ IREE_TRACE_ZONE_APPEND_TEXT(z_entry, key.data, key.size);
+ status = iree_io_parameter_index_provider_resolve(
+ provider, device, queue_affinity, scope, key, access, &entry,
+ &file);
+ }
+
+ // Validate the parameter range is in-bounds.
+ if (iree_status_is_ok(status)) {
+ status = iree_io_validate_parameter_range(
+ access, entry, span.parameter_offset, span.length);
+ }
+
+ // Queue the file operation.
+ if (iree_status_is_ok(status)) {
+ // Operations are tracked on as many timelines as there is concurrency.
+ // We distribute operations onto timelines based on which has the fewest
+ // outstanding I/O bytes.
+ const iree_host_size_t timeline_index = iree_io_select_timeline_bucket(
+ concurrency, timeline_bytes_outstanding);
+ timeline_bytes_outstanding[timeline_index] += span.length;
+ iree_hal_semaphore_t* timeline_semaphore =
+ timeline_semaphores[timeline_index];
+ uint64_t previous_timeline_value = timeline_values[timeline_index];
+ uint64_t next_timeline_value = ++timeline_values[timeline_index];
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, (uint64_t)timeline_index);
+
+ // The first operation on each timeline waits on the provided wait
+ // semaphores. All others wait on the previous operation of their own
+ // internal concurrent timeline.
+ iree_hal_semaphore_list_t entry_wait_semaphore_list;
+ if (previous_timeline_value == 0) {
+ entry_wait_semaphore_list = wait_semaphore_list;
+ } else {
+ entry_wait_semaphore_list = (iree_hal_semaphore_list_t){
+ .count = 1,
+ .semaphores = &timeline_semaphore,
+ .payload_values = &previous_timeline_value,
+ };
+ }
+
+ // All operations signal their concurrency timelines and we'll put a
+ // barrier at the end so that we can join them all.
+ iree_hal_semaphore_list_t entry_signal_semaphore_list = {
+ .count = 1,
+ .semaphores = &timeline_semaphore,
+ .payload_values = &next_timeline_value,
+ };
+
+ // Perform the operation.
+ status = operation(device, queue_affinity, entry_wait_semaphore_list,
+ entry_signal_semaphore_list, file,
+ entry->offset + span.parameter_offset, buffer,
+ span.buffer_offset, span.length, 0);
+ }
+
+ iree_hal_file_release(file);
+
+ IREE_TRACE_ZONE_END(z_entry);
+ if (!iree_status_is_ok(status)) break;
+ }
+ }
+
+ // Join all concurrent timelines and continue the user-provided timeline.
+ if (iree_status_is_ok(status)) {
+ iree_hal_semaphore_list_t join_semaphore_list = {
+ .count = concurrency,
+ .semaphores = timeline_semaphores,
+ .payload_values = timeline_values,
+ };
+ status = iree_hal_device_queue_barrier(
+ device, queue_affinity, join_semaphore_list, signal_semaphore_list);
+ }
+
+ // Release temporary semaphores.
+ IREE_TRACE_ZONE_BEGIN_NAMED(
+ z_deinit, "iree_io_parameter_index_provider_semaphore_pool_deinitialize");
+ for (iree_host_size_t i = 0; i < concurrency; ++i) {
+ iree_hal_semaphore_release(timeline_semaphores[i]);
+ }
+ IREE_TRACE_ZONE_END(z_deinit);
+
+ return status;
+}
+
+static iree_status_t iree_io_parameter_index_file_read(
+ iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_file_t* file, uint64_t file_offset, iree_hal_buffer_t* buffer,
+ iree_device_size_t buffer_offset, iree_device_size_t length,
+ uint32_t flags) {
+ return iree_hal_device_queue_read(device, queue_affinity, wait_semaphore_list,
+ signal_semaphore_list, file, file_offset,
+ buffer, buffer_offset, length, flags);
+}
+
+static iree_status_t iree_io_parameter_index_provider_gather(
+ iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_hal_buffer_t* target_buffer,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator) {
+ iree_io_parameter_index_provider_t* provider =
+ iree_io_parameter_index_provider_cast(base_provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_status_t status = iree_io_parameter_index_provider_gather_scatter(
+ provider, device, queue_affinity, wait_semaphore_list,
+ signal_semaphore_list, source_scope, target_buffer, count, enumerator,
+ IREE_HAL_MEMORY_ACCESS_READ, iree_io_parameter_index_file_read);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static iree_status_t iree_io_parameter_index_file_write(
+ iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_file_t* file, uint64_t file_offset, iree_hal_buffer_t* buffer,
+ iree_device_size_t buffer_offset, iree_device_size_t length,
+ uint32_t flags) {
+ return iree_hal_device_queue_write(
+ device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+ buffer, buffer_offset, file, file_offset, length, flags);
+}
+
+static iree_status_t iree_io_parameter_index_provider_scatter(
+ iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_buffer_t* source_buffer, iree_string_view_t target_scope,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator) {
+ iree_io_parameter_index_provider_t* provider =
+ iree_io_parameter_index_provider_cast(base_provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_status_t status = iree_io_parameter_index_provider_gather_scatter(
+ provider, device, queue_affinity, wait_semaphore_list,
+ signal_semaphore_list, target_scope, source_buffer, count, enumerator,
+ IREE_HAL_MEMORY_ACCESS_WRITE, iree_io_parameter_index_file_write);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static const iree_io_parameter_provider_vtable_t
+ iree_io_parameter_index_provider_vtable = {
+ .destroy = iree_io_parameter_index_provider_destroy,
+ .notify = iree_io_parameter_index_provider_notify,
+ .query_support = iree_io_parameter_index_provider_query_support,
+ .load = iree_io_parameter_index_provider_load,
+ .read = iree_io_parameter_index_provider_read,
+ .write = iree_io_parameter_index_provider_write,
+ .gather = iree_io_parameter_index_provider_gather,
+ .scatter = iree_io_parameter_index_provider_scatter,
+};
diff --git a/runtime/src/iree/io/parameter_index_provider.h b/runtime/src/iree/io/parameter_index_provider.h
new file mode 100644
index 0000000..2864a55
--- /dev/null
+++ b/runtime/src/iree/io/parameter_index_provider.h
@@ -0,0 +1,42 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_IO_PARAMETER_INDEX_PROVIDER_H_
+#define IREE_IO_PARAMETER_INDEX_PROVIDER_H_
+
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/io/parameter_index.h"
+#include "iree/io/parameter_provider.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Reasonable default for the `max_concurrent_operations` parameter.
+#define IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS 16
+
+// Creates a parameter provider serving from the provided |index|.
+// As parameters are operated on, their files will be registered with the
+// devices they are used on and cached for future requests.
+//
+// |max_concurrent_operations| can be used to limit how many file operations as
+// part of a gather or scatter are allowed to be in-flight at a time. A lower
+// number can reduce system resource requirements during the operation (less
+// transient memory required, etc) while increasing latency (lower I/O
+// utilization).
+IREE_API_EXPORT iree_status_t iree_io_parameter_index_provider_create(
+ iree_string_view_t scope, iree_io_parameter_index_t* index,
+ iree_host_size_t max_concurrent_operations, iree_allocator_t host_allocator,
+ iree_io_parameter_provider_t** out_provider);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_IO_PARAMETER_INDEX_PROVIDER_H_
diff --git a/runtime/src/iree/io/parameter_provider.c b/runtime/src/iree/io/parameter_provider.c
new file mode 100644
index 0000000..064c3aa
--- /dev/null
+++ b/runtime/src/iree/io/parameter_provider.c
@@ -0,0 +1,187 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/parameter_provider.h"
+
+IREE_API_EXPORT void iree_io_parameter_provider_retain(
+ iree_io_parameter_provider_t* provider) {
+ if (IREE_LIKELY(provider)) {
+ iree_atomic_ref_count_inc(&provider->ref_count);
+ }
+}
+
+IREE_API_EXPORT void iree_io_parameter_provider_release(
+ iree_io_parameter_provider_t* provider) {
+ if (IREE_LIKELY(provider) &&
+ iree_atomic_ref_count_dec(&provider->ref_count) == 1) {
+ provider->vtable->destroy(provider);
+ }
+}
+
+IREE_API_EXPORT iree_status_t
+iree_io_parameter_provider_notify(iree_io_parameter_provider_t* provider,
+ iree_io_parameter_provider_signal_t signal) {
+ IREE_ASSERT_ARGUMENT(provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE({
+ switch (signal) {
+ case IREE_IO_PARAMETER_PROVIDER_SIGNAL_RESUME:
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "RESUME");
+ break;
+ case IREE_IO_PARAMETER_PROVIDER_SIGNAL_SUSPEND:
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "SUSPEND");
+ break;
+ case IREE_IO_PARAMETER_PROVIDER_SIGNAL_LOW_MEMORY:
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "LOW_MEMORY");
+ break;
+ default:
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "(unknown)");
+ break;
+ }
+ });
+ iree_status_t status = provider->vtable->notify(provider, signal);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT bool iree_io_parameter_provider_query_support(
+ iree_io_parameter_provider_t* provider, iree_string_view_t scope) {
+ IREE_ASSERT_ARGUMENT(provider);
+ return provider->vtable->query_support(provider, scope);
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_load(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_string_view_t source_key,
+ uint64_t source_offset, iree_hal_buffer_params_t target_params,
+ iree_device_size_t length,
+ iree_hal_buffer_t** IREE_RESTRICT out_target_buffer) {
+ IREE_ASSERT_ARGUMENT(provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_status_t status = provider->vtable->load(
+ provider, device, queue_affinity, wait_semaphore_list,
+ signal_semaphore_list, source_scope, source_key, source_offset,
+ target_params, length, out_target_buffer);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_read(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_string_view_t source_key,
+ uint64_t source_offset, iree_hal_buffer_t* target_buffer,
+ iree_device_size_t target_offset, iree_device_size_t length) {
+ IREE_ASSERT_ARGUMENT(provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_status_t status = provider->vtable->read(
+ provider, device, queue_affinity, wait_semaphore_list,
+ signal_semaphore_list, source_scope, source_key, source_offset,
+ target_buffer, target_offset, length);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_write(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+ iree_string_view_t target_scope, iree_string_view_t target_key,
+ uint64_t target_offset, iree_device_size_t length) {
+ IREE_ASSERT_ARGUMENT(provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_status_t status = provider->vtable->write(
+ provider, device, queue_affinity, wait_semaphore_list,
+ signal_semaphore_list, source_buffer, source_offset, target_scope,
+ target_key, target_offset, length);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_gather(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_hal_buffer_t* target_buffer,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator) {
+ IREE_ASSERT_ARGUMENT(provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
+ if (count == 0) {
+ // Preserve the timeline when there's no work to do.
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_device_queue_barrier(device, queue_affinity,
+ wait_semaphore_list,
+ signal_semaphore_list));
+ } else if (count == 1) {
+ // One span is just a read.
+ iree_string_view_t key = iree_string_view_empty();
+ iree_io_parameter_span_t span = {0};
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, enumerator.fn(enumerator.user_data, 0, &key, &span));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_provider_read(
+ provider, device, queue_affinity, wait_semaphore_list,
+ signal_semaphore_list, source_scope, key, span.parameter_offset,
+ target_buffer, span.buffer_offset, span.length));
+ } else {
+ // Full multi-span gather.
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, provider->vtable->gather(provider, device, queue_affinity,
+ wait_semaphore_list, signal_semaphore_list,
+ source_scope, target_buffer, count,
+ enumerator));
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_scatter(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_buffer_t* source_buffer, iree_string_view_t target_scope,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator) {
+ IREE_ASSERT_ARGUMENT(provider);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
+ if (count == 0) {
+ // Preserve the timeline when there's no work to do.
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_device_queue_barrier(device, queue_affinity,
+ wait_semaphore_list,
+ signal_semaphore_list));
+ } else if (count == 1) {
+ // One span is just a write.
+ iree_string_view_t key = iree_string_view_empty();
+ iree_io_parameter_span_t span = {0};
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, enumerator.fn(enumerator.user_data, 0, &key, &span));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_provider_write(
+ provider, device, queue_affinity, wait_semaphore_list,
+ signal_semaphore_list, source_buffer, span.buffer_offset,
+ target_scope, key, span.parameter_offset, span.length));
+ } else {
+ // Full multi-span scatter.
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, provider->vtable->scatter(provider, device, queue_affinity,
+ wait_semaphore_list,
+ signal_semaphore_list, source_buffer,
+ target_scope, count, enumerator));
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
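
The retain/release pair at the top of parameter_provider.c destroys the provider when the decrement returns 1. A minimal standalone sketch of that pattern follows, assuming (as the `== 1` comparison suggests) that `iree_atomic_ref_count_dec` returns the value held before the decrement. All `demo_*` names are hypothetical stand-ins built on C11 `<stdatomic.h>`; none of them are IREE APIs.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stddef.h>

typedef struct demo_provider_t demo_provider_t;

// Hypothetical one-entry vtable; the real vtable carries more operations.
typedef struct demo_provider_vtable_t {
  void (*destroy)(demo_provider_t* provider);
} demo_provider_vtable_t;

struct demo_provider_t {
  atomic_int ref_count;
  const demo_provider_vtable_t* vtable;
  bool* destroyed_flag;  // Demo-only: lets the caller observe destruction.
};

static void demo_provider_destroy(demo_provider_t* provider) {
  *provider->destroyed_flag = true;
}

static const demo_provider_vtable_t demo_provider_vtable = {
    .destroy = demo_provider_destroy,
};

// Initializes |provider| with a single reference owned by the caller.
void demo_provider_initialize(demo_provider_t* provider,
                              bool* destroyed_flag) {
  assert(provider && destroyed_flag);
  atomic_init(&provider->ref_count, 1);
  provider->vtable = &demo_provider_vtable;
  provider->destroyed_flag = destroyed_flag;
  *destroyed_flag = false;
}

// Adds a reference; NULL is tolerated like in the functions above.
void demo_provider_retain(demo_provider_t* provider) {
  if (provider) {
    atomic_fetch_add_explicit(&provider->ref_count, 1, memory_order_relaxed);
  }
}

// Drops a reference. atomic_fetch_sub returns the previous value, so a
// result of 1 means this call released the last reference.
void demo_provider_release(demo_provider_t* provider) {
  if (provider && atomic_fetch_sub_explicit(&provider->ref_count, 1,
                                            memory_order_acq_rel) == 1) {
    provider->vtable->destroy(provider);
  }
}
```

Comparing against the previous value rather than re-reading the counter keeps the "last one out" decision to a single atomic operation, so two threads releasing concurrently cannot both conclude they were last.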
diff --git a/runtime/src/iree/io/parameter_provider.h b/runtime/src/iree/io/parameter_provider.h
new file mode 100644
index 0000000..e398037
--- /dev/null
+++ b/runtime/src/iree/io/parameter_provider.h
@@ -0,0 +1,241 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_IO_PARAMETER_PROVIDER_H_
+#define IREE_IO_PARAMETER_PROVIDER_H_
+
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_io_parameter_provider_t
+//===----------------------------------------------------------------------===//
+
+// Indicates an event signaled from the hosting program.
+typedef enum iree_io_parameter_provider_signal_e {
+ // Program is resuming from a suspended state.
+ // Providers may reallocate memory for pools and caches.
+ IREE_IO_PARAMETER_PROVIDER_SIGNAL_RESUME = 0,
+ // Program is entering a suspended state.
+ // Providers should drop any transient memory that is possible to reallocate
+ // upon resume.
+ IREE_IO_PARAMETER_PROVIDER_SIGNAL_SUSPEND = 1,
+ // Program has received a low memory alert.
+ // Providers must aggressively drop all possible memory even if expensive to
+ // rematerialize it. On some platforms this is sent as a threat that if
+ // sufficient memory is not unwired/freed ASAP the process will be killed.
+ IREE_IO_PARAMETER_PROVIDER_SIGNAL_LOW_MEMORY = 2,
+} iree_io_parameter_provider_signal_t;
+
+typedef struct iree_io_parameter_span_t {
+ uint64_t parameter_offset;
+ iree_device_size_t buffer_offset;
+ iree_device_size_t length;
+} iree_io_parameter_span_t;
+
+// Interface for providers of parameter storage and caching.
+// Parameters are referenced by a scope (conceptually a file, group, or table)
+// and a scope-unique key.
+//
+// Each provider implementation can handle any number of scope types. Users are
+// expected to query for support with iree_io_parameter_provider_query_support
+// prior to performing operations.
+//
+// Thread-safe: a provider may be shared by several contexts simultaneously.
+// Behavior is currently undefined if multiple contexts attempt to read or write
+// the same parameters concurrently. Future revisions may require that providers
+// track pending operations per parameter and sequence them appropriately.
+typedef struct iree_io_parameter_provider_t iree_io_parameter_provider_t;
+
+// Retains the given |provider| for the caller.
+IREE_API_EXPORT void iree_io_parameter_provider_retain(
+ iree_io_parameter_provider_t* provider);
+
+// Releases the given |provider| from the caller.
+IREE_API_EXPORT void iree_io_parameter_provider_release(
+ iree_io_parameter_provider_t* provider);
+
+// Notifies the provider of an event from the hosting program.
+// Providers can ignore notifications at their peril.
+IREE_API_EXPORT iree_status_t
+iree_io_parameter_provider_notify(iree_io_parameter_provider_t* provider,
+ iree_io_parameter_provider_signal_t signal);
+
+// Returns true if the given |scope| is supported by |provider|.
+IREE_API_EXPORT bool iree_io_parameter_provider_query_support(
+ iree_io_parameter_provider_t* provider, iree_string_view_t scope);
+
+// Loads a parameter from |provider| for use on |device|.
+// |source_scope| and |source_key| define the parameter and |target_params|
+// defines how the buffer is to be allocated. If the parameter is smaller than
+// |length| any remaining space must be zero-filled.
+//
+// If the implementation is able to meet the expected |target_params| with an
+// existing buffer it may be returned without a new allocation. If access
+// permits, implementations may return mapped memory that may be shared by
+// other users within the same process or across processes.
+//
+// Implementations that have no optimized load/import path can implement this
+// with iree_hal_device_queue_alloca and iree_io_parameter_provider_read.
+//
+// Returns IREE_STATUS_NOT_FOUND if the parameter is not found.
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_load(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_string_view_t source_key,
+ uint64_t source_offset, iree_hal_buffer_params_t target_params,
+ iree_device_size_t length,
+ iree_hal_buffer_t** IREE_RESTRICT out_target_buffer);
+
+// Reads a parameter from |provider| for use on |device|.
+// |source_scope| and |source_key| define the parameter to be read into
+// |target_buffer| at |target_offset|. If the parameter is smaller than
+// |length| any remaining space must be zero-filled.
+//
+// Returns IREE_STATUS_NOT_FOUND if the parameter is not found.
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_read(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_string_view_t source_key,
+ uint64_t source_offset, iree_hal_buffer_t* target_buffer,
+ iree_device_size_t target_offset, iree_device_size_t length);
+
+// Writes a parameter to |provider| from |device|.
+// The parameter data is sourced from |source_buffer| at |source_offset| and
+// |target_scope| and |target_key| define which parameter is being written.
+//
+// Returns IREE_STATUS_NOT_FOUND if the parameter is not found.
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_write(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+ iree_string_view_t target_scope, iree_string_view_t target_key,
+ uint64_t target_offset, iree_device_size_t length);
+
+typedef iree_status_t(IREE_API_PTR* iree_io_parameter_enumerator_fn_t)(
+ void* user_data, iree_host_size_t i, iree_string_view_t* out_key,
+ iree_io_parameter_span_t* out_span);
+
+typedef struct iree_io_parameter_enumerator_t {
+ // Callback function pointer.
+ iree_io_parameter_enumerator_fn_t fn;
+ // User data passed to the callback function. Unowned.
+ void* user_data;
+} iree_io_parameter_enumerator_t;
+
+// Gathers zero or more spans from |provider| into the given |target_buffer|.
+// The |enumerator| defines the source keys in |source_scope| and the offset and
+// length in the |target_buffer| of each span. If a parameter is smaller than
+// the length specified by its span any remaining space must be zero-filled.
+// Multiple spans may reference the same source parameter but behavior is
+// undefined if multiple span target ranges overlap.
+//
+// Returns IREE_STATUS_NOT_FOUND if any parameter is not found.
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_gather(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_hal_buffer_t* target_buffer,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator);
+
+// Scatters zero or more spans to |provider| from the given |source_buffer|.
+// The |enumerator| defines the target keys in |target_scope| and the offset and
+// length in the |source_buffer| of each span to scatter. Multiple spans may
+// reference source ranges that overlap but behavior is undefined if multiple
+// spans share the same target parameter.
+//
+// Returns IREE_STATUS_NOT_FOUND if any parameter is not found.
+IREE_API_EXPORT iree_status_t iree_io_parameter_provider_scatter(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_buffer_t* source_buffer, iree_string_view_t target_scope,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator);
+
+//===----------------------------------------------------------------------===//
+// iree_io_parameter_provider_t implementation details
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_io_parameter_provider_vtable_t {
+ void(IREE_API_PTR* destroy)(
+ iree_io_parameter_provider_t* IREE_RESTRICT provider);
+
+ iree_status_t(IREE_API_PTR* notify)(
+ iree_io_parameter_provider_t* provider,
+ iree_io_parameter_provider_signal_t signal);
+
+ bool(IREE_API_PTR* query_support)(iree_io_parameter_provider_t* provider,
+ iree_string_view_t scope);
+
+ iree_status_t(IREE_API_PTR* load)(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_string_view_t source_key,
+ uint64_t source_offset, iree_hal_buffer_params_t target_params,
+ iree_device_size_t length,
+ iree_hal_buffer_t** IREE_RESTRICT out_target_buffer);
+
+ iree_status_t(IREE_API_PTR* read)(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_string_view_t source_key,
+ uint64_t source_offset, iree_hal_buffer_t* target_buffer,
+ iree_device_size_t target_offset, iree_device_size_t length);
+
+ iree_status_t(IREE_API_PTR* write)(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+ iree_string_view_t target_scope, iree_string_view_t target_key,
+ uint64_t target_offset, iree_device_size_t length);
+
+ iree_status_t(IREE_API_PTR* gather)(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_string_view_t source_scope, iree_hal_buffer_t* target_buffer,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator);
+
+ iree_status_t(IREE_API_PTR* scatter)(
+ iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_hal_buffer_t* source_buffer, iree_string_view_t target_scope,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator);
+} iree_io_parameter_provider_vtable_t;
+
+struct iree_io_parameter_provider_t {
+ iree_atomic_ref_count_t ref_count;
+ const iree_io_parameter_provider_vtable_t* vtable;
+};
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_IO_PARAMETER_PROVIDER_H_
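
The enumerator contract documented in parameter_provider.h (a callback yields a key and a span for each index, and short parameters are zero-filled) can be illustrated with a host-only sketch. Everything below is an assumption made for illustration: the `demo_*` types stand in for `iree_io_parameter_span_t` and `iree_io_parameter_enumerator_t`, parameters are plain in-memory byte arrays, and integer return codes replace `iree_status_t`. No devices, queues, or semaphores are modeled.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Stand-in for iree_io_parameter_span_t.
typedef struct demo_span_t {
  uint64_t parameter_offset;  // Start of the read within the parameter.
  size_t buffer_offset;       // Start of the write within the target buffer.
  size_t length;              // Bytes covered in the target buffer.
} demo_span_t;

// Stand-in for the enumerator callback; returns 0 on success.
typedef int (*demo_enumerator_fn_t)(void* user_data, size_t i,
                                    const char** out_key,
                                    demo_span_t* out_span);

typedef struct demo_enumerator_t {
  demo_enumerator_fn_t fn;
  void* user_data;  // Unowned.
} demo_enumerator_t;

// Toy in-memory parameter storage.
typedef struct demo_parameter_t {
  const char* key;
  const uint8_t* data;
  size_t size;
} demo_parameter_t;

// One gather request; an array of these backs the example enumerator.
typedef struct demo_request_t {
  const char* key;
  demo_span_t span;
} demo_request_t;

// Example enumerator that reads requests from a caller-owned array.
int demo_enumerate_requests(void* user_data, size_t i, const char** out_key,
                            demo_span_t* out_span) {
  const demo_request_t* requests = (const demo_request_t*)user_data;
  *out_key = requests[i].key;
  *out_span = requests[i].span;
  return 0;
}

// Copies |count| spans into |target|, zero-filling any tail that the
// parameter cannot supply. Returns 0 on success, 1 if a key is not found
// (the NOT_FOUND analogue), and 2 if a span exceeds the target buffer.
int demo_gather(const demo_parameter_t* parameters, size_t parameter_count,
                uint8_t* target, size_t target_size, size_t count,
                demo_enumerator_t enumerator) {
  for (size_t i = 0; i < count; ++i) {
    const char* key = NULL;
    demo_span_t span = {0};
    int rc = enumerator.fn(enumerator.user_data, i, &key, &span);
    if (rc != 0) return rc;
    if (!key) return 1;
    if (span.buffer_offset > target_size ||
        span.length > target_size - span.buffer_offset) {
      return 2;
    }
    const demo_parameter_t* parameter = NULL;
    for (size_t j = 0; j < parameter_count; ++j) {
      if (strcmp(parameters[j].key, key) == 0) {
        parameter = &parameters[j];
        break;
      }
    }
    if (!parameter) return 1;
    size_t available = span.parameter_offset < parameter->size
                           ? parameter->size - (size_t)span.parameter_offset
                           : 0;
    size_t copy_length = available < span.length ? available : span.length;
    uint8_t* dst = target + span.buffer_offset;
    if (copy_length > 0) {
      memcpy(dst, parameter->data + span.parameter_offset, copy_length);
    }
    memset(dst + copy_length, 0, span.length - copy_length);
  }
  return 0;
}
```

Passing a callback rather than an array of spans lets the caller keep its own representation of the request list (for example VM lists, as the module layer might) without copying it into a provider-specific layout first.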
diff --git a/runtime/src/iree/io/scope_map.c b/runtime/src/iree/io/scope_map.c
new file mode 100644
index 0000000..806397b
--- /dev/null
+++ b/runtime/src/iree/io/scope_map.c
@@ -0,0 +1,77 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/scope_map.h"
+
+IREE_API_EXPORT void iree_io_scope_map_initialize(
+ iree_allocator_t host_allocator, iree_io_scope_map_t* out_scope_map) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ out_scope_map->host_allocator = host_allocator;
+ out_scope_map->count = 0;
+ out_scope_map->capacity = 0;
+ out_scope_map->entries = NULL;
+ IREE_TRACE_ZONE_END(z0);
+}
+
+IREE_API_EXPORT void iree_io_scope_map_deinitialize(
+ iree_io_scope_map_t* scope_map) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_allocator_t host_allocator = scope_map->host_allocator;
+ for (iree_host_size_t i = 0; i < scope_map->count; ++i) {
+ iree_io_scope_map_entry_t* entry = scope_map->entries[i];
+ iree_io_parameter_index_release(entry->index);
+ iree_allocator_free(host_allocator, entry);
+ }
+ iree_allocator_free(host_allocator, scope_map->entries);
+ IREE_TRACE_ZONE_END(z0);
+}
+
+IREE_API_EXPORT iree_status_t iree_io_scope_map_lookup(
+ iree_io_scope_map_t* scope_map, iree_string_view_t scope,
+ iree_io_parameter_index_t** out_index) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, scope.data, scope.size);
+ for (iree_host_size_t i = 0; i < scope_map->count; ++i) {
+ iree_io_scope_map_entry_t* entry = scope_map->entries[i];
+ if (iree_string_view_equal(scope, entry->scope)) {
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "hit");
+ *out_index = entry->index;
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+ }
+ }
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "miss");
+
+ if (scope_map->count == scope_map->capacity) {
+ iree_host_size_t new_capacity = iree_max(8, scope_map->capacity * 2);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_realloc(
+ scope_map->host_allocator,
+ new_capacity * sizeof(iree_io_scope_map_entry_t*),
+ (void**)&scope_map->entries));
+ scope_map->capacity = new_capacity;
+ }
+
+ iree_io_scope_map_entry_t* entry = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(scope_map->host_allocator,
+ sizeof(*entry) + scope.size, (void**)&entry));
+ entry->scope =
+ iree_make_string_view((const char*)entry + sizeof(*entry), scope.size);
+ memcpy((char*)entry->scope.data, scope.data, scope.size);
+
+ iree_status_t status =
+ iree_io_parameter_index_create(scope_map->host_allocator, &entry->index);
+
+ if (iree_status_is_ok(status)) {
+ scope_map->entries[scope_map->count++] = entry;
+ *out_index = entry->index;
+ } else {
+ iree_allocator_free(scope_map->host_allocator, entry);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
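
The lookup in scope_map.c is a find-or-insert over a growable pointer array, with each entry and its scope string sharing one allocation. A minimal sketch of that layout using plain `malloc`/`realloc` follows. The `demo_*` names are hypothetical, an `int` stands in for the per-scope `iree_io_parameter_index_t*`, and NULL replaces the status return on allocation failure.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

typedef struct demo_entry_t {
  const char* scope;  // Points just past this struct; not NUL-terminated.
  size_t scope_size;
  int value;          // Stand-in for the per-scope parameter index.
} demo_entry_t;

typedef struct demo_map_t {
  size_t count;
  size_t capacity;
  demo_entry_t** entries;
} demo_map_t;

// Returns the entry for |scope|, creating it on first use.
// Returns NULL only if an allocation fails.
demo_entry_t* demo_map_lookup(demo_map_t* map, const char* scope,
                              size_t scope_size) {
  assert(map && scope);
  // Linear scan: scope counts are expected to be small.
  for (size_t i = 0; i < map->count; ++i) {
    demo_entry_t* entry = map->entries[i];
    if (entry->scope_size == scope_size &&
        memcmp(entry->scope, scope, scope_size) == 0) {
      return entry;
    }
  }
  // Grow the pointer array geometrically, starting at 8 slots.
  if (map->count == map->capacity) {
    size_t new_capacity = map->capacity ? map->capacity * 2 : 8;
    demo_entry_t** new_entries =
        realloc(map->entries, new_capacity * sizeof(*new_entries));
    if (!new_entries) return NULL;
    map->entries = new_entries;
    map->capacity = new_capacity;
  }
  // One allocation holds the entry followed by its scope bytes.
  demo_entry_t* entry = malloc(sizeof(*entry) + scope_size);
  if (!entry) return NULL;
  char* storage = (char*)entry + sizeof(*entry);
  memcpy(storage, scope, scope_size);
  entry->scope = storage;
  entry->scope_size = scope_size;
  entry->value = 0;
  map->entries[map->count++] = entry;
  return entry;
}

// Frees every entry and the pointer array, leaving |map| empty.
void demo_map_deinitialize(demo_map_t* map) {
  for (size_t i = 0; i < map->count; ++i) free(map->entries[i]);
  free(map->entries);
  memset(map, 0, sizeof(*map));
}
```

Storing pointers to entries rather than entries inline means that growing the array never moves an entry, so pointers handed out by earlier lookups stay valid until the map is deinitialized.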
diff --git a/runtime/src/iree/io/scope_map.h b/runtime/src/iree/io/scope_map.h
new file mode 100644
index 0000000..30e7377
--- /dev/null
+++ b/runtime/src/iree/io/scope_map.h
@@ -0,0 +1,45 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_IO_SCOPE_MAP_H_
+#define IREE_IO_SCOPE_MAP_H_
+
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/io/parameter_index.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef struct iree_io_scope_map_entry_t {
+ iree_string_view_t scope;
+ iree_io_parameter_index_t* index;
+} iree_io_scope_map_entry_t;
+
+typedef struct iree_io_scope_map_t {
+ iree_allocator_t host_allocator;
+ iree_host_size_t count;
+ iree_host_size_t capacity;
+ iree_io_scope_map_entry_t** entries;
+} iree_io_scope_map_t;
+
+IREE_API_EXPORT void iree_io_scope_map_initialize(
+ iree_allocator_t host_allocator, iree_io_scope_map_t* out_scope_map);
+
+IREE_API_EXPORT void iree_io_scope_map_deinitialize(
+ iree_io_scope_map_t* scope_map);
+
+IREE_API_EXPORT iree_status_t iree_io_scope_map_lookup(
+ iree_io_scope_map_t* scope_map, iree_string_view_t scope,
+ iree_io_parameter_index_t** out_index);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_IO_SCOPE_MAP_H_
diff --git a/runtime/src/iree/modules/io/BUILD.bazel b/runtime/src/iree/modules/io/BUILD.bazel
new file mode 100644
index 0000000..522ca5d
--- /dev/null
+++ b/runtime/src/iree/modules/io/BUILD.bazel
@@ -0,0 +1,11 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/runtime/src/iree/modules/io/CMakeLists.txt b/runtime/src/iree/modules/io/CMakeLists.txt
new file mode 100644
index 0000000..1724cbd
--- /dev/null
+++ b/runtime/src/iree/modules/io/CMakeLists.txt
@@ -0,0 +1,13 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# runtime/src/iree/modules/io/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/modules/io/parameters/BUILD.bazel b/runtime/src/iree/modules/io/parameters/BUILD.bazel
new file mode 100644
index 0000000..00cbd45
--- /dev/null
+++ b/runtime/src/iree/modules/io/parameters/BUILD.bazel
@@ -0,0 +1,33 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_runtime_cc_library(
+ name = "parameters",
+ srcs = [
+ "module.c",
+ ],
+ hdrs = [
+ "module.h",
+ ],
+ textual_hdrs = [
+ "exports.inl",
+ ],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/hal",
+ "//runtime/src/iree/io:parameter_provider",
+ "//runtime/src/iree/modules/hal:types",
+ "//runtime/src/iree/vm",
+ ],
+)
diff --git a/runtime/src/iree/modules/io/parameters/CMakeLists.txt b/runtime/src/iree/modules/io/parameters/CMakeLists.txt
new file mode 100644
index 0000000..52c0c10
--- /dev/null
+++ b/runtime/src/iree/modules/io/parameters/CMakeLists.txt
@@ -0,0 +1,31 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# runtime/src/iree/modules/io/parameters/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ parameters
+ HDRS
+ "module.h"
+ TEXTUAL_HDRS
+ "exports.inl"
+ SRCS
+ "module.c"
+ DEPS
+ iree::base
+ iree::hal
+ iree::io::parameter_provider
+ iree::modules::hal::types
+ iree::vm
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/modules/io/parameters/exports.inl b/runtime/src/iree/modules/io/parameters/exports.inl
new file mode 100644
index 0000000..47984c4
--- /dev/null
+++ b/runtime/src/iree/modules/io/parameters/exports.inl
@@ -0,0 +1,33 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+//
+// ██ ██ █████ ██████ ███ ██ ██ ███ ██ ██████
+// ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
+// ██ █ ██ ███████ ██████ ██ ██ ██ ██ ██ ██ ██ ██ ███
+// ██ ███ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
+// ███ ███ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
+//
+//===----------------------------------------------------------------------===//
+//
+// This file will be auto generated from io_parameters.imports.mlir in the
+// future; for now it's modified by hand but with strict alphabetical sorting
+// required. The order of these functions must be sorted ascending by name in a
+// way compatible with iree_string_view_compare.
+//
+// Users are meant to `#define EXPORT_FN` to be able to access the information.
+// #define EXPORT_FN(name, target_fn, arg_type, ret_type)
+
+// clang-format off
+
+EXPORT_FN("gather", iree_io_parameters_module_gather, rIrrrrrrr, v)
+EXPORT_FN("load", iree_io_parameters_module_load, rIrrrrIIiiI, r)
+EXPORT_FN("read", iree_io_parameters_module_read, rIrrrrIrII, v)
+EXPORT_FN("scatter", iree_io_parameters_module_scatter, rIrrrrrrr, v)
+EXPORT_FN("write", iree_io_parameters_module_write, rIrrrrIrII, v)
+
+// clang-format on
diff --git a/runtime/src/iree/modules/io/parameters/module.c b/runtime/src/iree/modules/io/parameters/module.c
new file mode 100644
index 0000000..697ffb6
--- /dev/null
+++ b/runtime/src/iree/modules/io/parameters/module.c
@@ -0,0 +1,544 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/modules/io/parameters/module.h"
+
+#include "iree/modules/hal/types.h"
+
+#define IREE_IO_PARAMETERS_MODULE_VERSION_0_0 0x00000000u
+#define IREE_IO_PARAMETERS_MODULE_VERSION_LATEST \
+ IREE_IO_PARAMETERS_MODULE_VERSION_0_0
+
+//===----------------------------------------------------------------------===//
+// Module type definitions
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_io_parameters_module_t {
+ iree_allocator_t host_allocator;
+ iree_host_size_t provider_count;
+ iree_io_parameter_provider_t* providers[];
+} iree_io_parameters_module_t;
+
+#define IREE_IO_PARAMETERS_MODULE_CAST(module) \
+ (iree_io_parameters_module_t*)((uint8_t*)(module) + \
+ iree_vm_native_module_size())
+
+typedef struct iree_io_parameters_module_state_t {
+ iree_allocator_t host_allocator;
+} iree_io_parameters_module_state_t;
+
+static void IREE_API_PTR iree_io_parameters_module_destroy(void* base_module) {
+ iree_io_parameters_module_t* module =
+ IREE_IO_PARAMETERS_MODULE_CAST(base_module);
+ for (iree_host_size_t i = 0; i < module->provider_count; ++i) {
+ iree_io_parameter_provider_release(module->providers[i]);
+ }
+ module->provider_count = 0;
+}
+
+static iree_status_t IREE_API_PTR iree_io_parameters_module_alloc_state(
+ void* self, iree_allocator_t host_allocator,
+ iree_vm_module_state_t** out_module_state) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_io_parameters_module_state_t* state = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_allocator_malloc(host_allocator, sizeof(*state), (void**)&state));
+ memset(state, 0, sizeof(*state));
+ state->host_allocator = host_allocator;
+
+ *out_module_state = (iree_vm_module_state_t*)state;
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static void IREE_API_PTR iree_io_parameters_module_free_state(
+ void* self, iree_vm_module_state_t* module_state) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_io_parameters_module_state_t* state =
+ (iree_io_parameters_module_state_t*)module_state;
+ iree_allocator_free(state->host_allocator, state);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t IREE_API_PTR iree_io_parameters_module_notify(
+ void* self, iree_vm_module_state_t* module_state, iree_vm_signal_t signal) {
+ iree_io_parameters_module_t* module = IREE_IO_PARAMETERS_MODULE_CAST(self);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_io_parameter_provider_signal_t provider_signal;
+ switch (signal) {
+ case IREE_VM_SIGNAL_RESUME:
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "RESUME");
+ provider_signal = IREE_IO_PARAMETER_PROVIDER_SIGNAL_RESUME;
+ break;
+ case IREE_VM_SIGNAL_SUSPEND:
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "SUSPEND");
+ provider_signal = IREE_IO_PARAMETER_PROVIDER_SIGNAL_SUSPEND;
+ break;
+ case IREE_VM_SIGNAL_LOW_MEMORY:
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "LOW_MEMORY");
+ provider_signal = IREE_IO_PARAMETER_PROVIDER_SIGNAL_LOW_MEMORY;
+ break;
+ default:
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "(unhandled)");
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+ }
+ for (iree_host_size_t i = 0; i < module->provider_count; ++i) {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_provider_notify(module->providers[i],
+ provider_signal));
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Scans the provider list to find one that handles |scope|.
+static iree_status_t iree_io_parameters_module_resolve_provider(
+ iree_io_parameters_module_t* module, iree_string_view_t scope,
+ iree_io_parameter_provider_t** out_provider) {
+ for (iree_host_size_t i = 0; i < module->provider_count; ++i) {
+ iree_io_parameter_provider_t* provider = module->providers[i];
+ if (iree_io_parameter_provider_query_support(provider, scope)) {
+ *out_provider = provider;
+ return iree_ok_status();
+ }
+ }
+ return iree_make_status(
+ IREE_STATUS_NOT_FOUND,
+ "no provider registered that handles scopes like '%.*s'", (int)scope.size,
+ scope.data);
+}
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+// Casts a VM value to a HAL device size.
+static iree_device_size_t iree_hal_cast_device_size(int64_t value) {
+ // TODO(benvanik): make this return status and check for overflow if device
+ // size is 32-bits.
+ return (iree_device_size_t)value;
+}
+
+//===----------------------------------------------------------------------===//
+// Exported functions
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_io_parameters_module_load, //
+ iree_io_parameters_module_state_t, //
+ rIrrrrIIiiI, r) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_queue_affinity_t queue_affinity =
+ (iree_hal_queue_affinity_t)args->i1;
+ iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
+ iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
+ iree_vm_buffer_t* source_scope = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_vm_buffer_check_deref_or_null(args->r4, &source_scope));
+ iree_vm_buffer_t* source_key = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r5, &source_key));
+ uint64_t source_offset = args->i6;
+ iree_hal_queue_affinity_t target_queue_affinity =
+ (iree_hal_queue_affinity_t)args->i7;
+ iree_hal_memory_type_t target_memory_types = (iree_hal_memory_type_t)args->i8;
+ iree_hal_buffer_usage_t target_buffer_usage =
+ (iree_hal_buffer_usage_t)args->i9;
+ iree_device_size_t length = iree_hal_cast_device_size(args->i10);
+
+ iree_io_parameter_provider_t* provider = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_module_resolve_provider(
+ IREE_IO_PARAMETERS_MODULE_CAST(module),
+ iree_vm_buffer_as_string(source_scope), &provider));
+
+ const iree_hal_buffer_params_t target_params = {
+ .type = target_memory_types,
+ .usage = target_buffer_usage,
+ .queue_affinity = target_queue_affinity,
+ };
+ iree_hal_buffer_t* target_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameter_provider_load(
+ provider, device, queue_affinity,
+ iree_hal_fence_semaphore_list(wait_fence),
+ iree_hal_fence_semaphore_list(signal_fence),
+ iree_vm_buffer_as_string(source_scope),
+ iree_vm_buffer_as_string(source_key), source_offset, target_params,
+ length, &target_buffer));
+
+ rets->r0 = iree_hal_buffer_move_ref(target_buffer);
+ return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_io_parameters_module_read, //
+ iree_io_parameters_module_state_t, //
+ rIrrrrIrII, v) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_queue_affinity_t queue_affinity =
+ (iree_hal_queue_affinity_t)args->i1;
+ iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
+ iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
+ iree_vm_buffer_t* source_scope = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_vm_buffer_check_deref_or_null(args->r4, &source_scope));
+ iree_vm_buffer_t* source_key = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r5, &source_key));
+ uint64_t source_offset = args->i6;
+ iree_hal_buffer_t* target_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r7, &target_buffer));
+ iree_device_size_t target_offset = iree_hal_cast_device_size(args->i8);
+ iree_device_size_t length = iree_hal_cast_device_size(args->i9);
+
+ iree_io_parameter_provider_t* provider = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_module_resolve_provider(
+ IREE_IO_PARAMETERS_MODULE_CAST(module),
+ iree_vm_buffer_as_string(source_scope), &provider));
+
+ return iree_io_parameter_provider_read(
+ provider, device, queue_affinity,
+ iree_hal_fence_semaphore_list(wait_fence),
+ iree_hal_fence_semaphore_list(signal_fence),
+ iree_vm_buffer_as_string(source_scope),
+ iree_vm_buffer_as_string(source_key), source_offset, target_buffer,
+ target_offset, length);
+}
+
+IREE_VM_ABI_EXPORT(iree_io_parameters_module_write, //
+ iree_io_parameters_module_state_t, //
+ rIrrrrIrII, v) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_queue_affinity_t queue_affinity =
+ (iree_hal_queue_affinity_t)args->i1;
+ iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
+ iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
+ iree_vm_buffer_t* target_scope = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_vm_buffer_check_deref_or_null(args->r4, &target_scope));
+ iree_vm_buffer_t* target_key = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r5, &target_key));
+ uint64_t target_offset = args->i6;
+ iree_hal_buffer_t* source_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r7, &source_buffer));
+ iree_device_size_t source_offset = iree_hal_cast_device_size(args->i8);
+ iree_device_size_t length = iree_hal_cast_device_size(args->i9);
+
+ iree_io_parameter_provider_t* provider = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_module_resolve_provider(
+ IREE_IO_PARAMETERS_MODULE_CAST(module),
+ iree_vm_buffer_as_string(target_scope), &provider));
+
+ return iree_io_parameter_provider_write(
+ provider, device, queue_affinity,
+ iree_hal_fence_semaphore_list(wait_fence),
+ iree_hal_fence_semaphore_list(signal_fence), source_buffer, source_offset,
+ iree_vm_buffer_as_string(target_scope),
+ iree_vm_buffer_as_string(target_key), target_offset, length);
+}
+
+typedef struct iree_io_parameters_string_entry_t {
+ uint32_t offset;
+ uint32_t length;
+} iree_io_parameters_string_entry_t;
+
+typedef struct iree_io_parameters_span_entry_t {
+ uint64_t parameter_offset;
+ uint64_t buffer_offset;
+ uint64_t length;
+} iree_io_parameters_span_entry_t;
+
+typedef struct iree_io_parameters_indirect_args_t {
+ iree_host_size_t count;
+ const iree_io_parameters_string_entry_t* string_table;
+ iree_const_byte_span_t string_data;
+ const iree_io_parameters_span_entry_t* spans;
+} iree_io_parameters_indirect_args_t;
+
+static iree_status_t iree_io_parameters_prepare_indirect_args(
+ iree_vm_buffer_t* key_table, iree_vm_buffer_t* key_data,
+ iree_vm_buffer_t* spans, iree_io_parameters_indirect_args_t* out_args) {
+ // Span count is defined by the number of entries that storage contains.
+ if (iree_vm_buffer_length(spans) % sizeof(iree_io_parameters_span_entry_t) !=
+ 0) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "buffer span storage length must be a multiple of "
+ "sizeof(iree_io_parameters_span_entry_t)");
+ }
+ const iree_host_size_t count =
+ iree_vm_buffer_length(spans) / sizeof(iree_io_parameters_span_entry_t);
+
+ // Verify there's enough space in the key string table for the entries we
+ // need.
+ if (iree_vm_buffer_length(key_table) <
+ count * sizeof(iree_io_parameters_string_entry_t)) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "key string table must have enough data to service all defined spans");
+ }
+
+ // Map string table; note that the offsets are validated during enumeration.
+ iree_const_byte_span_t key_table_ptr = iree_const_byte_span_empty();
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(key_table, 0,
+ iree_vm_buffer_length(key_table),
+ sizeof(uint32_t), &key_table_ptr));
+ out_args->string_table =
+ (const iree_io_parameters_string_entry_t*)key_table_ptr.data;
+ IREE_RETURN_IF_ERROR(
+ iree_vm_buffer_map_ro(key_data, 0, iree_vm_buffer_length(key_data),
+ sizeof(char), &out_args->string_data));
+
+ // Map span data; the offsets/lengths are validated in the parameter provider
+ // implementation.
+ iree_host_size_t span_list_size =
+ count * sizeof(iree_io_parameters_span_entry_t);
+ iree_const_byte_span_t span_list_ptr = iree_const_byte_span_empty();
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(spans, 0, span_list_size,
+ sizeof(uint64_t), &span_list_ptr));
+ out_args->spans = (const iree_io_parameters_span_entry_t*)span_list_ptr.data;
+
+ out_args->count = count;
+ return iree_ok_status();
+}
+
+static iree_status_t iree_io_parameters_resolve_string(
+ iree_io_parameters_string_entry_t key, iree_const_byte_span_t string_data,
+ iree_string_view_t* out_key) {
+ *out_key = iree_string_view_empty();
+
+ // Check if the start of the range runs off the end of the buffer.
+ if (IREE_UNLIKELY(key.offset > string_data.data_length)) {
+ return iree_make_status(
+ IREE_STATUS_OUT_OF_RANGE,
+ "attempted to access an address off the end of the valid buffer range "
+ "(offset=%u, length=%u, data_capacity=%" PRIhsz ")",
+ key.offset, key.length, string_data.data_length);
+ }
+
+ if (key.length == 0) {
+ // Fine to have a zero length.
+ return iree_ok_status();
+ }
+
+ // Check if the end runs over the allocation (64-bit math avoids wrap).
+ uint64_t end = (uint64_t)key.offset + (uint64_t)key.length;
+ if (IREE_UNLIKELY(end > string_data.data_length)) {
+ return iree_make_status(
+ IREE_STATUS_OUT_OF_RANGE,
+ "attempted to access an address outside of the valid buffer range "
+ "(offset=%u, length=%u, end(inc)=%" PRIu64 ", data_capacity=%" PRIhsz ")",
+ key.offset, key.length, end - 1, string_data.data_length);
+ }
+
+ out_key->data = (const char*)string_data.data + key.offset;
+ out_key->size = (iree_host_size_t)key.length;
+ return iree_ok_status();
+}
+
+static iree_status_t iree_io_parameters_indirect_enumerator(
+ void* user_data, iree_host_size_t i, iree_string_view_t* out_key,
+ iree_io_parameter_span_t* out_span) {
+ const iree_io_parameters_indirect_args_t* args =
+ (const iree_io_parameters_indirect_args_t*)user_data;
+ if (IREE_UNLIKELY(i >= args->count)) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "parameter out of bounds");
+ }
+ IREE_RETURN_IF_ERROR(iree_io_parameters_resolve_string(
+ args->string_table[i], args->string_data, out_key));
+ const iree_io_parameters_span_entry_t span = args->spans[i];
+ out_span->parameter_offset = span.parameter_offset;
+ out_span->buffer_offset = iree_hal_cast_device_size(span.buffer_offset);
+ out_span->length = iree_hal_cast_device_size(span.length);
+ return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_io_parameters_module_gather, //
+ iree_io_parameters_module_state_t, //
+ rIrrrrrrr, v) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_queue_affinity_t queue_affinity =
+ (iree_hal_queue_affinity_t)args->i1;
+ iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
+ iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
+ iree_vm_buffer_t* source_scope = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_vm_buffer_check_deref_or_null(args->r4, &source_scope));
+ iree_hal_buffer_t* target_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r5, &target_buffer));
+ iree_vm_buffer_t* key_table = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r6, &key_table));
+ iree_vm_buffer_t* key_data = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r7, &key_data));
+ iree_vm_buffer_t* spans = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r8, &spans));
+
+ iree_io_parameter_provider_t* provider = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_module_resolve_provider(
+ IREE_IO_PARAMETERS_MODULE_CAST(module),
+ iree_vm_buffer_as_string(source_scope), &provider));
+
+ iree_io_parameters_indirect_args_t enumerator_args;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_prepare_indirect_args(
+ key_table, key_data, spans, &enumerator_args));
+ iree_io_parameter_enumerator_t enumerator = {
+ .fn = iree_io_parameters_indirect_enumerator,
+ .user_data = &enumerator_args,
+ };
+ return iree_io_parameter_provider_gather(
+ provider, device, queue_affinity,
+ iree_hal_fence_semaphore_list(wait_fence),
+ iree_hal_fence_semaphore_list(signal_fence),
+ iree_vm_buffer_as_string(source_scope), target_buffer,
+ enumerator_args.count, enumerator);
+}
+
+IREE_VM_ABI_EXPORT(iree_io_parameters_module_scatter, //
+ iree_io_parameters_module_state_t, //
+ rIrrrrrrr, v) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_queue_affinity_t queue_affinity =
+ (iree_hal_queue_affinity_t)args->i1;
+ iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
+ iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
+ iree_hal_buffer_t* source_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r4, &source_buffer));
+ iree_vm_buffer_t* target_scope = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_vm_buffer_check_deref_or_null(args->r5, &target_scope));
+ iree_vm_buffer_t* key_table = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r6, &key_table));
+ iree_vm_buffer_t* key_data = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r7, &key_data));
+ iree_vm_buffer_t* spans = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r8, &spans));
+
+ iree_io_parameter_provider_t* provider = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_module_resolve_provider(
+ IREE_IO_PARAMETERS_MODULE_CAST(module),
+ iree_vm_buffer_as_string(target_scope), &provider));
+
+ iree_io_parameters_indirect_args_t enumerator_args;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_prepare_indirect_args(
+ key_table, key_data, spans, &enumerator_args));
+ iree_io_parameter_enumerator_t enumerator = {
+ .fn = iree_io_parameters_indirect_enumerator,
+ .user_data = &enumerator_args,
+ };
+ return iree_io_parameter_provider_scatter(
+ provider, device, queue_affinity,
+ iree_hal_fence_semaphore_list(wait_fence),
+ iree_hal_fence_semaphore_list(signal_fence), source_buffer,
+ iree_vm_buffer_as_string(target_scope), enumerator_args.count,
+ enumerator);
+}
+
+//===----------------------------------------------------------------------===//
+// VM module interface implementation
+//===----------------------------------------------------------------------===//
+
+// NOTE: this must match the ordering of the iree_io_parameters_module_exports_
+// table.
+static const iree_vm_native_function_ptr_t iree_io_parameters_module_funcs_[] =
+ {
+#define EXPORT_FN(name, target_fn, arg_types, ret_types) \
+ { \
+ .shim = (iree_vm_native_function_shim_t) \
+ iree_vm_shim_##arg_types##_##ret_types, \
+ .target = (iree_vm_native_function_target_t)(target_fn), \
+ },
+#include "iree/modules/io/parameters/exports.inl" // IWYU pragma: keep
+#undef EXPORT_FN
+};
+
+// NOTE: 0 length, but can't express that in C.
+static const iree_vm_native_import_descriptor_t
+ iree_io_parameters_module_imports_[1];
+
+static const iree_vm_native_export_descriptor_t
+ iree_io_parameters_module_exports_[] = {
+#define EXPORT_FN(name, target_fn, arg_types, ret_types) \
+ { \
+ .local_name = iree_string_view_literal(name), \
+ .calling_convention = \
+ iree_string_view_literal("0" #arg_types "_" #ret_types), \
+ .attr_count = 0, \
+ .attrs = NULL, \
+ },
+#include "iree/modules/io/parameters/exports.inl" // IWYU pragma: keep
+#undef EXPORT_FN
+};
+static_assert(IREE_ARRAYSIZE(iree_io_parameters_module_funcs_) ==
+ IREE_ARRAYSIZE(iree_io_parameters_module_exports_),
+ "function pointer table must be 1:1 with exports");
+
+static const iree_vm_native_module_descriptor_t
+ iree_io_parameters_module_descriptor_ = {
+ .name = iree_string_view_literal("io_parameters"),
+ .version = IREE_IO_PARAMETERS_MODULE_VERSION_LATEST,
+ .attr_count = 0,
+ .attrs = NULL,
+ .dependency_count = 0,
+ .dependencies = NULL,
+ .import_count = 0, // workaround for 0-length C struct
+ .imports = iree_io_parameters_module_imports_,
+ .export_count = IREE_ARRAYSIZE(iree_io_parameters_module_exports_),
+ .exports = iree_io_parameters_module_exports_,
+ .function_count = IREE_ARRAYSIZE(iree_io_parameters_module_funcs_),
+ .functions = iree_io_parameters_module_funcs_,
+};
+
+IREE_API_EXPORT iree_status_t iree_io_parameters_module_create(
+ iree_vm_instance_t* instance, iree_host_size_t provider_count,
+ iree_io_parameter_provider_t* const* providers,
+ iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
+ IREE_ASSERT_ARGUMENT(instance);
+ IREE_ASSERT_ARGUMENT(!provider_count || providers);
+ IREE_ASSERT_ARGUMENT(out_module);
+ *out_module = NULL;
+
+ // Setup the interface with the functions we implement ourselves. Any function
+ // we omit will be handled by the base native module.
+ static const iree_vm_module_t interface = {
+ .destroy = iree_io_parameters_module_destroy,
+ .alloc_state = iree_io_parameters_module_alloc_state,
+ .free_state = iree_io_parameters_module_free_state,
+ .notify = iree_io_parameters_module_notify,
+ };
+
+ // Allocate shared module state.
+ iree_host_size_t total_size =
+ iree_vm_native_module_size() + sizeof(iree_io_parameters_module_t) +
+ provider_count * sizeof(iree_io_parameter_provider_t*);
+ iree_vm_module_t* base_module = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_allocator_malloc(host_allocator, total_size, (void**)&base_module));
+ memset(base_module, 0, total_size);
+ iree_status_t status = iree_vm_native_module_initialize(
+ &interface, &iree_io_parameters_module_descriptor_, instance,
+ host_allocator, base_module);
+ if (!iree_status_is_ok(status)) {
+ iree_allocator_free(host_allocator, base_module);
+ return status;
+ }
+
+ iree_io_parameters_module_t* module =
+ IREE_IO_PARAMETERS_MODULE_CAST(base_module);
+ module->host_allocator = host_allocator;
+ module->provider_count = provider_count;
+ for (iree_host_size_t i = 0; i < provider_count; ++i) {
+ module->providers[i] = providers[i];
+ iree_io_parameter_provider_retain(providers[i]);
+ }
+
+ *out_module = base_module;
+ return iree_ok_status();
+}
diff --git a/runtime/src/iree/modules/io/parameters/module.h b/runtime/src/iree/modules/io/parameters/module.h
new file mode 100644
index 0000000..e066d45
--- /dev/null
+++ b/runtime/src/iree/modules/io/parameters/module.h
@@ -0,0 +1,31 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_MODULES_IO_PARAMETERS_MODULE_H_
+#define IREE_MODULES_IO_PARAMETERS_MODULE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/io/parameter_provider.h"
+#include "iree/vm/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates a module for accessing parameters via a set of |providers|.
+// The providers are retained for the lifetime of the module.
+IREE_API_EXPORT iree_status_t iree_io_parameters_module_create(
+ iree_vm_instance_t* instance, iree_host_size_t provider_count,
+ iree_io_parameter_provider_t* const* providers,
+ iree_allocator_t host_allocator,
+ iree_vm_module_t** IREE_RESTRICT out_module);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_MODULES_IO_PARAMETERS_MODULE_H_
diff --git a/runtime/src/iree/tooling/BUILD.bazel b/runtime/src/iree/tooling/BUILD.bazel
index 51b7aa0..55d4d3c 100644
--- a/runtime/src/iree/tooling/BUILD.bazel
+++ b/runtime/src/iree/tooling/BUILD.bazel
@@ -75,6 +75,7 @@
hdrs = ["context_util.h"],
deps = [
":device_util",
+ ":parameter_util",
"//runtime/src/iree/base",
"//runtime/src/iree/base/internal:file_io",
"//runtime/src/iree/base/internal:flags",
@@ -146,6 +147,27 @@
)
iree_runtime_cc_library(
+ name = "parameter_util",
+ srcs = ["parameter_util.c"],
+ hdrs = ["parameter_util.h"],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:file_io",
+ "//runtime/src/iree/base/internal:flags",
+ "//runtime/src/iree/base/internal:path",
+ "//runtime/src/iree/hal",
+ "//runtime/src/iree/io:parameter_index",
+ "//runtime/src/iree/io:parameter_index_provider",
+ "//runtime/src/iree/io:parameter_provider",
+ "//runtime/src/iree/io:scope_map",
+ "//runtime/src/iree/io/formats/gguf",
+ "//runtime/src/iree/io/formats/safetensors",
+ "//runtime/src/iree/modules/io/parameters",
+ "//runtime/src/iree/vm",
+ ],
+)
+
+iree_runtime_cc_library(
name = "run_module",
srcs = ["run_module.c"],
hdrs = ["run_module.h"],
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index 3d73a0c..a8d0711 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -84,6 +84,7 @@
"context_util.c"
DEPS
::device_util
+ ::parameter_util
iree::base
iree::base::internal::file_io
iree::base::internal::flags
@@ -167,6 +168,30 @@
iree_cc_library(
NAME
+ parameter_util
+ HDRS
+ "parameter_util.h"
+ SRCS
+ "parameter_util.c"
+ DEPS
+ iree::base
+ iree::base::internal::file_io
+ iree::base::internal::flags
+ iree::base::internal::path
+ iree::hal
+ iree::io::formats::gguf
+ iree::io::formats::safetensors
+ iree::io::parameter_index
+ iree::io::parameter_index_provider
+ iree::io::parameter_provider
+ iree::io::scope_map
+ iree::modules::io::parameters
+ iree::vm
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
run_module
HDRS
"run_module.h"
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c
index 6dcf47e..135a111 100644
--- a/runtime/src/iree/tooling/context_util.c
+++ b/runtime/src/iree/tooling/context_util.c
@@ -20,6 +20,7 @@
#include "iree/modules/hal/module.h"
#include "iree/tooling/device_util.h"
#include "iree/tooling/modules/resolver.h"
+#include "iree/tooling/parameter_util.h"
#include "iree/vm/bytecode/module.h"
#include "iree/vm/dynamic/module.h"
@@ -424,6 +425,10 @@
} else if (iree_string_view_equal(dependency->name, IREE_SV("hal_loader"))) {
IREE_RETURN_IF_ERROR(iree_tooling_load_hal_loader_module(
state->instance, state->host_allocator, &module));
+ } else if (iree_string_view_equal(dependency->name,
+ IREE_SV("io_parameters"))) {
+ IREE_RETURN_IF_ERROR(iree_tooling_create_parameters_module_from_flags(
+ state->instance, state->host_allocator, &module));
} else {
// Defer to the generic module resolver registry.
IREE_RETURN_IF_ERROR(iree_tooling_resolve_module_dependency(
diff --git a/runtime/src/iree/tooling/parameter_util.c b/runtime/src/iree/tooling/parameter_util.c
new file mode 100644
index 0000000..0ad4a0e
--- /dev/null
+++ b/runtime/src/iree/tooling/parameter_util.c
@@ -0,0 +1,203 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/tooling/parameter_util.h"
+
+#include "iree/base/internal/file_io.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/internal/path.h"
+#include "iree/io/formats/gguf/gguf_format.h"
+#include "iree/io/formats/safetensors/safetensors_format.h"
+#include "iree/io/parameter_index.h"
+#include "iree/io/parameter_index_provider.h"
+#include "iree/io/scope_map.h"
+#include "iree/modules/io/parameters/module.h"
+
+//===----------------------------------------------------------------------===//
+// Parameter file I/O
+//===----------------------------------------------------------------------===//
+
+IREE_FLAG(
+ string, parameter_mode, "mmap",
+ "A parameter I/O mode of ['preload', 'mmap'].\n"
+ " preload: read entire parameter files into wired memory on startup.\n"
+ " mmap: maps the parameter files into discardable memory - can increase\n"
+ " warm-up time and variance as mapped pages are swapped\n"
+ " by the OS.");
+
+static void iree_file_contents_release_callback(
+ void* user_data, iree_io_file_handle_primitive_t handle_primitive) {
+ iree_file_contents_t* file_contents = (iree_file_contents_t*)user_data;
+ iree_file_contents_free(file_contents);
+}
+
+// Opens the parameter file at |path| with the mode specified by the
+// --parameter_mode flag and returns its handle.
+static iree_status_t iree_io_open_parameter_file(
+ iree_string_view_t path, iree_allocator_t host_allocator,
+ iree_io_file_handle_t** out_file_handle) {
+ IREE_ASSERT_ARGUMENT(out_file_handle);
+ *out_file_handle = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, path.data, path.size);
+
+ char path_str[2048] = {0};
+ iree_string_view_to_cstring(path, path_str, sizeof(path_str));
+ iree_file_read_flags_t read_flags = 0;
+ if (strcmp(FLAG_parameter_mode, "mmap") == 0) {
+ read_flags |= IREE_FILE_READ_FLAG_MMAP;
+ } else if (strcmp(FLAG_parameter_mode, "preload") == 0) {
+ read_flags |= IREE_FILE_READ_FLAG_PRELOAD;
+ } else {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "unrecognized --parameter_mode= value '%s'",
+ FLAG_parameter_mode);
+ }
+
+ iree_file_contents_t* file_contents = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_file_read_contents(path_str, read_flags, host_allocator,
+ &file_contents));
+
+ iree_io_file_handle_release_callback_t release_callback = {
+ .fn = iree_file_contents_release_callback,
+ .user_data = file_contents,
+ };
+ iree_io_file_handle_t* file_handle = NULL;
+ iree_status_t status = iree_io_file_handle_wrap_host_allocation(
+ IREE_IO_FILE_ACCESS_READ, file_contents->buffer, release_callback,
+ host_allocator, &file_handle);
+ if (iree_status_is_ok(status)) {
+ *out_file_handle = file_handle;
+ } else {
+ iree_file_contents_free(file_contents);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+//===----------------------------------------------------------------------===//
+// Parameter file format parsing
+//===----------------------------------------------------------------------===//
+
+IREE_FLAG_LIST(
+ string, parameters,
+ "Specifies a parameter file to make available to programs with either an\n"
+ "anonymous global scope (`some_file.gguf`) or a named scope like\n"
+ "`my_scope=some_file.gguf`.\n"
+ "\n"
+ "Supported formats:\n"
+ "- .irpa (IREE parameter archive)\n"
+ "- .gguf (https://github.com/ggerganov/ggml/blob/master/docs/gguf.md)\n"
+ "- .safetensors (https://github.com/huggingface/safetensors)");
+
+// Appends the parameter file located at |path| to |index|.
+static iree_status_t iree_io_append_parameter_file_to_index(
+ iree_string_view_t path, iree_io_parameter_index_t* index,
+ iree_allocator_t host_allocator) {
+ IREE_ASSERT_ARGUMENT(index);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Open the file.
+ iree_io_file_handle_t* file_handle = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_open_parameter_file(path, host_allocator, &file_handle));
+
+ // Index the file based on its (inferred) format.
+ iree_status_t status = iree_ok_status();
+ iree_string_view_t path_ext = iree_file_path_extension(path);
+ if (iree_string_view_equal_case(path_ext, IREE_SV("gguf"))) {
+ status = iree_io_parse_gguf_index(file_handle, index);
+ } else if (iree_string_view_equal_case(path_ext, IREE_SV("safetensors"))) {
+ status = iree_io_parse_safetensors_index(file_handle, index);
+ } else {
+ status = iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "unhandled parameter file format: .%.*s",
+ (int)path_ext.size, path_ext.data);
+ }
+
+ // Release our file reference - it's still retained by the index if it had any
+ // parameters in it.
+ iree_io_file_handle_release(file_handle);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+iree_status_t iree_tooling_build_parameter_indices_from_flags(
+ iree_io_scope_map_t* scope_map) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Create one index per scope and add parameters to each.
+ for (iree_host_size_t i = 0; i < FLAG_parameters_list().count; ++i) {
+ // Parse the `scope=path` flag. Note that the scope is optional.
+ iree_string_view_t flag = FLAG_parameters_list().values[i];
+ iree_string_view_t scope, path;
+ if (iree_string_view_split(flag, '=', &scope, &path) == -1) {
+ // No scope provided (that's ok).
+ path = scope;
+ scope = iree_string_view_empty();
+ }
+
+ // Lookup (or create) the index for the given scope.
+ iree_io_parameter_index_t* index = NULL; // unowned
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_scope_map_lookup(scope_map, scope, &index));
+
+ // Index the file.
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_append_parameter_file_to_index(path, index,
+ scope_map->host_allocator));
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+iree_status_t iree_tooling_create_parameters_module_from_flags(
+ iree_vm_instance_t* instance, iree_allocator_t host_allocator,
+ iree_vm_module_t** out_module) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_io_scope_map_t scope_map;
+ iree_io_scope_map_initialize(host_allocator, &scope_map);
+
+ // Parse all parameter files and build out their indices.
+ iree_status_t status =
+ iree_tooling_build_parameter_indices_from_flags(&scope_map);
+
+ // Create one provider per scope.
+ iree_host_size_t provider_count = 0;
+ iree_io_parameter_provider_t** providers =
+ (iree_io_parameter_provider_t**)iree_alloca(
+ scope_map.count * sizeof(iree_io_parameter_provider_t*));
+ if (iree_status_is_ok(status)) {
+ for (iree_host_size_t i = 0; i < scope_map.count; ++i) {
+ status = iree_io_parameter_index_provider_create(
+ scope_map.entries[i]->scope, scope_map.entries[i]->index,
+ IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS,
+ host_allocator, &providers[i]);
+ if (!iree_status_is_ok(status)) break;
+ ++provider_count;
+ }
+ }
+
+ // Create the module with the list of providers.
+ if (iree_status_is_ok(status)) {
+ status = iree_io_parameters_module_create(
+ instance, provider_count, providers, host_allocator, out_module);
+ }
+
+ // Cleanup (module owns providers which own indices/etc).
+ for (iree_host_size_t i = 0; i < provider_count; ++i) {
+ iree_io_parameter_provider_release(providers[i]);
+ }
+ iree_io_scope_map_deinitialize(&scope_map);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
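Review note: the `--parameters=` handling in `iree_tooling_build_parameter_indices_from_flags` amounts to splitting each flag into an optional scope and a path, then appending every file to the index of its scope, with the empty string standing in for the anonymous global scope. The Python sketch below only illustrates that grouping. The names `split_parameter_flag` and `group_by_scope` are hypothetical and not part of IREE, the split is assumed to happen at the first `=`, and no file is actually opened or parsed.

```python
from collections import OrderedDict


def split_parameter_flag(flag):
    """Splits `[scope=]path` into (scope, path); scope is '' when omitted."""
    scope, separator, path = flag.partition("=")
    if not separator:
        # No `=` present: the whole flag is the path (anonymous scope).
        return "", flag
    return scope, path


def group_by_scope(flags):
    """Groups parameter file paths by scope, preserving flag order."""
    scopes = OrderedDict()
    for flag in flags:
        scope, path = split_parameter_flag(flag)
        # Look up (or create) the per-scope list, mirroring one index per scope.
        scopes.setdefault(scope, []).append(path)
    return scopes


flags = [
    "a=parameters_a.safetensors",
    "b=parameters_b.safetensors",
    "extra.gguf",
    "a=more.gguf",
]
grouped = group_by_scope(flags)
for scope, paths in grouped.items():
    print(repr(scope), paths)
```

Multiple files may feed one scope (as `a` does here), while each flag names exactly one scope, which matches the constraint described in the `parameters_scoped.mlir` test comment.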
diff --git a/runtime/src/iree/tooling/parameter_util.h b/runtime/src/iree/tooling/parameter_util.h
new file mode 100644
index 0000000..5342b89
--- /dev/null
+++ b/runtime/src/iree/tooling/parameter_util.h
@@ -0,0 +1,32 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_TOOLING_PARAMETER_UTIL_H_
+#define IREE_TOOLING_PARAMETER_UTIL_H_
+
+#include "iree/base/api.h"
+#include "iree/vm/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef struct iree_io_scope_map_t iree_io_scope_map_t;
+
+// Populates |scope_map| with parameter indices as specified by flags.
+iree_status_t iree_tooling_build_parameter_indices_from_flags(
+ iree_io_scope_map_t* scope_map);
+
+// Builds an I/O parameters module based on the runtime flags provided.
+iree_status_t iree_tooling_create_parameters_module_from_flags(
+ iree_vm_instance_t* instance, iree_allocator_t host_allocator,
+ iree_vm_module_t** out_module);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_TOOLING_PARAMETER_UTIL_H_
diff --git a/runtime/src/iree/tooling/testdata/npy/generate_npy_files.py b/runtime/src/iree/tooling/testdata/npy/generate_npy_files.py
index 2d96ac9..5ac1b34 100644
--- a/runtime/src/iree/tooling/testdata/npy/generate_npy_files.py
+++ b/runtime/src/iree/tooling/testdata/npy/generate_npy_files.py
@@ -1,4 +1,3 @@
-# Lint as: python3
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index c7f18a9..5a4b2a7c 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -66,6 +66,9 @@
IREE_VM_ABI_DEFINE_SHIM(rrrIii, v);
IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r);
IREE_VM_ABI_DEFINE_SHIM(rIrrrIrIIi, v);
+IREE_VM_ABI_DEFINE_SHIM(rIrrrrrrr, v);
+IREE_VM_ABI_DEFINE_SHIM(rIrrrrIrII, v);
+IREE_VM_ABI_DEFINE_SHIM(rIrrrrIIiiI, r);
IREE_VM_ABI_DEFINE_SHIM(rIrrr, v);
IREE_VM_ABI_DEFINE_SHIM(rIrrCrD, v);
IREE_VM_ABI_DEFINE_SHIM(CrID, r);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index 2a1e60c..95aa5e4 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -446,6 +446,45 @@
int32_t i9;
});
+IREE_VM_ABI_FIXED_STRUCT(rIrrrrrrr, {
+ iree_vm_ref_t r0;
+ int64_t i1;
+ iree_vm_ref_t r2;
+ iree_vm_ref_t r3;
+ iree_vm_ref_t r4;
+ iree_vm_ref_t r5;
+ iree_vm_ref_t r6;
+ iree_vm_ref_t r7;
+ iree_vm_ref_t r8;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rIrrrrIrII, {
+ iree_vm_ref_t r0;
+ int64_t i1;
+ iree_vm_ref_t r2;
+ iree_vm_ref_t r3;
+ iree_vm_ref_t r4;
+ iree_vm_ref_t r5;
+ int64_t i6;
+ iree_vm_ref_t r7;
+ int64_t i8;
+ int64_t i9;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rIrrrrIIiiI, {
+ iree_vm_ref_t r0;
+ int64_t i1;
+ iree_vm_ref_t r2;
+ iree_vm_ref_t r3;
+ iree_vm_ref_t r4;
+ iree_vm_ref_t r5;
+ int64_t i6;
+ int64_t i7;
+ int32_t i8;
+ int32_t i9;
+ int64_t i10;
+});
+
IREE_VM_ABI_FIXED_STRUCT(rIrrr, {
iree_vm_ref_t r0;
int64_t i1;
@@ -656,6 +695,9 @@
IREE_VM_ABI_DECLARE_SHIM(rrrIii, v);
IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r);
IREE_VM_ABI_DECLARE_SHIM(rIrrrIrIIi, v);
+IREE_VM_ABI_DECLARE_SHIM(rIrrrrrrr, v);
+IREE_VM_ABI_DECLARE_SHIM(rIrrrrIrII, v);
+IREE_VM_ABI_DECLARE_SHIM(rIrrrrIIiiI, r);
IREE_VM_ABI_DECLARE_SHIM(rIrrr, v);
IREE_VM_ABI_DECLARE_SHIM(rIrrCrD, v);
IREE_VM_ABI_DECLARE_SHIM(CrID, r);
diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel
index 11af40c..d6316c1 100644
--- a/tools/BUILD.bazel
+++ b/tools/BUILD.bazel
@@ -116,6 +116,19 @@
)
iree_runtime_cc_binary(
+ name = "iree-dump-parameters",
+ srcs = ["iree-dump-parameters-main.c"],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:file_io",
+ "//runtime/src/iree/base/internal:flags",
+ "//runtime/src/iree/io:parameter_index",
+ "//runtime/src/iree/io:scope_map",
+ "//runtime/src/iree/tooling:parameter_util",
+ ],
+)
+
+iree_runtime_cc_binary(
name = "iree-fatelf",
srcs = ["iree-fatelf.c"],
deps = [
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 2445774..65608e3 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -150,6 +150,20 @@
iree::vm::bytecode::module
)
+iree_cc_binary(
+ NAME
+ iree-dump-parameters
+ SRCS
+ "iree-dump-parameters-main.c"
+ DEPS
+ iree::base
+ iree::base::internal::file_io
+ iree::base::internal::flags
+ iree::io::parameter_index
+ iree::io::scope_map
+ iree::tooling::parameter_util
+)
+
# Only enable fatelf tool when we're compiling it in.
# Currently it requires that the host and target both support embedded ELFs as
# the ELF implementation is only compiled when the target supports it.
diff --git a/tools/iree-dump-module-main.c b/tools/iree-dump-module-main.c
index e1f70a9..cc39827 100644
--- a/tools/iree-dump-module-main.c
+++ b/tools/iree-dump-module-main.c
@@ -550,7 +550,8 @@
// Parse command line flags.
iree_flags_set_usage("iree-dump-module",
- "Dumps IREE VM module details to stdout.\n");
+ "Dumps IREE VM module details to stdout.\n"
+ "$ iree-dump-module [--output=...] module.vmfb\n");
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
if (argc < 2) {
diff --git a/tools/iree-dump-parameters-main.c b/tools/iree-dump-parameters-main.c
new file mode 100644
index 0000000..eb08561
--- /dev/null
+++ b/tools/iree-dump-parameters-main.c
@@ -0,0 +1,196 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Dumps parsed parameter file information.
+// We intentionally use the same flags and parsing behavior as the rest of the
+// runtime tools so that users can pass their parameter flags and see exactly
+// what the runtime tools would see post-indexing. We don't try to dump original
+// file metadata as we only support parsing enough information for what we need.
+//
+// We also support basic extraction of individual parameters in cases where
+// users want to dump them out. Since we don't parse or preserve metadata we
+// can only write out their raw binary contents.
+//
+// # List all available parameters and their index information:
+// $ iree-dump-parameters --parameters=my_scope=my_file.gguf [--parameters=...]
+// # Extract parameter binary contents from a file:
+// $ iree-dump-parameters ... --extract=scope::key0=file0.bin [--extract=...]
+
+#include <ctype.h>
+#include <stdio.h>
+
+#include "iree/base/api.h"
+#include "iree/base/internal/file_io.h"
+#include "iree/base/internal/flags.h"
+#include "iree/io/parameter_index.h"
+#include "iree/io/scope_map.h"
+#include "iree/tooling/parameter_util.h"
+
+//===----------------------------------------------------------------------===//
+// Parameter index information
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_tooling_dump_parameter_index(
+ iree_string_view_t scope, iree_io_parameter_index_t* index) {
+ iree_host_size_t entry_count = iree_io_parameter_index_count(index);
+ uint64_t total_bytes = 0;
+ for (iree_host_size_t i = 0; i < entry_count; ++i) {
+ const iree_io_parameter_index_entry_t* entry = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameter_index_get(index, i, &entry));
+ total_bytes += entry->length;
+ }
+ fprintf(stdout,
+ "//"
+ "===-----------------------------------------------------------------"
+ "---------------------------------------------===//\n");
+ fprintf(stdout,
+ "// Parameter scope `%.*s` (%" PRIhsz " entries, %" PRIu64
+ " total bytes)\n",
+ (int)scope.size, scope.data, entry_count, total_bytes);
+ fprintf(stdout,
+ "//===------------+------------------+------------------+------------"
+ "-----------------------------------------------===//\n");
+ fprintf(stdout,
+ "// Start | End | Length | Key\n");
+ fprintf(stdout,
+ "//---------------+------------------+------------------+------------"
+ "--------------------------------------------------//\n");
+ for (iree_host_size_t i = 0; i < entry_count; ++i) {
+ const iree_io_parameter_index_entry_t* entry = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameter_index_get(index, i, &entry));
+ fprintf(stdout, "%16" PRIu64 " | %16" PRIu64 " | %16" PRIu64 " | `%.*s`\n",
+ entry->offset, entry->offset + entry->length, entry->length,
+ (int)entry->key.size, entry->key.data);
+ }
+ fprintf(stdout, "\n");
+ return iree_ok_status();
+}
+
+static iree_status_t iree_tooling_dump_scope_map(
+ iree_io_scope_map_t* scope_map) {
+ for (iree_host_size_t i = 0; i < scope_map->count; ++i) {
+ iree_string_view_t scope = scope_map->entries[i]->scope;
+ iree_io_parameter_index_t* index = scope_map->entries[i]->index;
+ IREE_RETURN_IF_ERROR(iree_tooling_dump_parameter_index(scope, index));
+ }
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// Parameter extraction
+//===----------------------------------------------------------------------===//
+
+IREE_FLAG_LIST(string, extract,
+ "Extracts a parameter to a file as `[scope::]key=file.bin`.");
+
+static iree_status_t iree_tooling_extract_parameter(
+ iree_io_scope_map_t* scope_map, iree_string_view_t scope,
+ iree_string_view_t key, iree_string_view_t path,
+ iree_allocator_t host_allocator) {
+ // Lookup the index for the given scope.
+ iree_io_parameter_index_t* index = NULL; // unowned
+ IREE_RETURN_IF_ERROR(iree_io_scope_map_lookup(scope_map, scope, &index));
+
+ // Lookup the entry within the index.
+ const iree_io_parameter_index_entry_t* entry = NULL; // unowned
+ IREE_RETURN_IF_ERROR(iree_io_parameter_index_lookup(index, key, &entry));
+
+ fprintf(stdout, "Extracting parameter `");
+ if (!iree_string_view_is_empty(scope)) {
+ fprintf(stdout, "%.*s::", (int)scope.size, scope.data);
+ }
+ fprintf(stdout, "%.*s` (%" PRIu64 "b) to `%.*s`...\n", (int)key.size,
+ key.data, entry->length, (int)path.size, path.data);
+
+ // TODO(benvanik): support generic file handle IO instead of memory-only.
+ if (iree_io_file_handle_type(entry->file_handle) !=
+ IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
+ return iree_make_status(
+ IREE_STATUS_UNIMPLEMENTED,
+ "only host allocation file handles are supported today");
+ }
+ iree_byte_span_t file_contents =
+ iree_io_file_handle_value(entry->file_handle).host_allocation;
+ iree_const_byte_span_t entry_contents = iree_make_const_byte_span(
+ file_contents.data + entry->offset, entry->length);
+ char* path_str = (char*)iree_alloca(path.size + 1);
+ memcpy(path_str, path.data, path.size);
+ path_str[path.size] = 0;
+ return iree_file_write_contents(path_str, entry_contents);
+}
+
+static iree_status_t iree_tooling_extract_parameters(
+ iree_io_scope_map_t* scope_map, iree_allocator_t host_allocator) {
+ for (iree_host_size_t i = 0; i < FLAG_extract_list().count; ++i) {
+ iree_string_view_t flag = FLAG_extract_list().values[i];
+ iree_string_view_t identifier, path;
+ iree_string_view_split(flag, '=', &identifier, &path);
+
+ iree_host_size_t separator_pos =
+ iree_string_view_find_first_of(identifier, IREE_SV("::"), 0);
+ iree_string_view_t scope = iree_string_view_empty();
+ iree_string_view_t key = iree_string_view_empty();
+ if (separator_pos != IREE_STRING_VIEW_NPOS) {
+ scope = iree_string_view_substr(identifier, 0, separator_pos);
+ key = iree_string_view_substr(identifier, separator_pos + 2,
+ IREE_HOST_SIZE_MAX);
+ } else {
+ key = identifier;
+ }
+
+ IREE_RETURN_IF_ERROR(iree_tooling_extract_parameter(scope_map, scope, key,
+ path, host_allocator),
+ "extracting parameter with flag `%.*s`",
+ (int)flag.size, flag.data);
+ }
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// main
+//===----------------------------------------------------------------------===//
+
+int main(int argc, char** argv) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_t host_allocator = iree_allocator_system();
+ int exit_code = EXIT_SUCCESS;
+
+ // Parse command line flags.
+ iree_flags_set_usage("iree-dump-parameters",
+ "Dumps information about parsed parameter files.\n");
+ iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
+
+ iree_io_scope_map_t scope_map = {0};
+ iree_io_scope_map_initialize(host_allocator, &scope_map);
+
+ // Parse parameters using the common tooling flags.
+ iree_status_t status =
+ iree_tooling_build_parameter_indices_from_flags(&scope_map);
+
+ // Dump parameter information.
+ if (iree_status_is_ok(status)) {
+ status = iree_tooling_dump_scope_map(&scope_map);
+ }
+
+ // Extract parameters as requested, if any.
+ if (iree_status_is_ok(status)) {
+ status = iree_tooling_extract_parameters(&scope_map, host_allocator);
+ }
+
+ iree_io_scope_map_deinitialize(&scope_map);
+
+ fflush(stdout);
+ if (!iree_status_is_ok(status)) {
+ iree_status_fprint(stderr, status);
+ iree_status_free(status);
+ exit_code = EXIT_FAILURE;
+ }
+ fflush(stderr);
+
+ IREE_TRACE_ZONE_END(z0);
+ return exit_code;
+}
diff --git a/tools/test/BUILD.bazel b/tools/test/BUILD.bazel
index 37fe960..6eea2e3 100644
--- a/tools/test/BUILD.bazel
+++ b/tools/test/BUILD.bazel
@@ -18,29 +18,38 @@
name = "lit",
srcs = enforce_glob(
[
+ "benchmark_flags.txt",
"compile_pipelines.mlir",
"compile_to_continuation.mlir",
"compile_to_phase.mlir",
"executable_benchmarks.mlir",
"executable_sources.mlir",
"iree-benchmark-module.mlir",
+ "iree-dump-parameters.txt",
"iree-run-mlir.mlir",
- "iree-run-module.mlir",
"iree-run-module-expected.mlir",
"iree-run-module-outputs.mlir",
+ "iree-run-module.mlir",
"iree-run-trace.mlir",
"multiple_args.mlir",
"multiple_exported_functions.mlir",
"null_values.mlir",
+ "parameters_scoped.mlir",
+ "parameters_unscoped.mlir",
"repeated_return.mlir",
"scalars.mlir",
],
- include = ["*.mlir"],
+ include = [
+ "*.mlir",
+ "iree-dump-parameters.txt",
+ ],
),
cfg = "//tools:lit.cfg.py",
data = [
"echo_npy.py",
"iree-run-trace.yml",
+ "parameters_a.safetensors",
+ "parameters_b.safetensors",
],
tags = [
"driver=local-task",
@@ -60,17 +69,3 @@
"@llvm-project//llvm:not",
],
)
-
-iree_lit_test_suite(
- name = "benchmark_flags",
- srcs = ["benchmark_flags.txt"],
- cfg = "//tools:lit.cfg.py",
- tags = [
- "hostonly",
- ],
- tools = [
- "//tools:iree-benchmark-module",
- "//tools:iree-compile",
- "@llvm-project//llvm:FileCheck",
- ],
-)
diff --git a/tools/test/CMakeLists.txt b/tools/test/CMakeLists.txt
index faa6e3a..20973c7 100644
--- a/tools/test/CMakeLists.txt
+++ b/tools/test/CMakeLists.txt
@@ -14,12 +14,14 @@
NAME
lit
SRCS
+ "benchmark_flags.txt"
"compile_pipelines.mlir"
"compile_to_continuation.mlir"
"compile_to_phase.mlir"
"executable_benchmarks.mlir"
"executable_sources.mlir"
"iree-benchmark-module.mlir"
+ "iree-dump-parameters.txt"
"iree-run-mlir.mlir"
"iree-run-module-expected.mlir"
"iree-run-module-outputs.mlir"
@@ -28,6 +30,8 @@
"multiple_args.mlir"
"multiple_exported_functions.mlir"
"null_values.mlir"
+ "parameters_scoped.mlir"
+ "parameters_unscoped.mlir"
"repeated_return.mlir"
"scalars.mlir"
TOOLS
@@ -44,25 +48,14 @@
DATA
echo_npy.py
iree-run-trace.yml
+ parameters_a.safetensors
+ parameters_b.safetensors
LABELS
"driver=local-task"
"driver=vulkan"
"hostonly"
)
-iree_lit_test_suite(
- NAME
- benchmark_flags
- SRCS
- "benchmark_flags.txt"
- TOOLS
- FileCheck
- iree-benchmark-module
- iree-compile
- LABELS
- "hostonly"
-)
-
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
if(IREE_BUILD_TESTS AND
diff --git a/tools/test/iree-dump-parameters.txt b/tools/test/iree-dump-parameters.txt
new file mode 100644
index 0000000..77b4294
--- /dev/null
+++ b/tools/test/iree-dump-parameters.txt
@@ -0,0 +1,12 @@
+// RUN: (iree-dump-parameters \
+// RUN: --parameters=a=%p/parameters_a.safetensors \
+// RUN: --parameters=b=%p/parameters_b.safetensors) | \
+// RUN: FileCheck %s
+
+// CHECK: Parameter scope `a`
+// CHECK: 120 | 152 | 32 | `a0`
+// CHECK: 152 | 184 | 32 | `a1`
+
+// CHECK: Parameter scope `b`
+// CHECK: 128 | 192 | 64 | `b0`
+// CHECK: 192 | 320 | 128 | `b1`
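Review note: the Start/End/Length columns checked above are absolute byte ranges within each file. In the safetensors layout an 8-byte little-endian header size is followed by a JSON header whose `data_offsets` are relative to the start of the data section, so an entry's absolute start is `8 + header_size + data_offsets[0]`. The Python sketch below builds a tiny safetensors-style blob by hand, indexes it, prints rows in the same column layout as the dump tool, and slices one entry out the way `--extract=` writes raw bytes. The helpers `build_safetensors` and `index_safetensors` are hypothetical and use only the standard library. Because the real `safetensors` writer formats and pads its header differently, the offsets printed here will presumably not match the 120/152 values in the CHECK lines; only the lengths (4 x i64 = 32 bytes) are expected to agree.

```python
import json
import struct


def build_safetensors(tensors):
    """Builds a minimal safetensors-style blob from {name: (dtype, shape, bytes)}."""
    header, data = {}, b""
    for name, (dtype, shape, payload) in tensors.items():
        header[name] = {
            "dtype": dtype,
            "shape": shape,
            # Offsets are relative to the start of the data section.
            "data_offsets": [len(data), len(data) + len(payload)],
        }
        data += payload
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data


def index_safetensors(blob):
    """Returns [(key, absolute_offset, length)] sorted by offset."""
    (header_size,) = struct.unpack_from("<Q", blob, 0)
    header = json.loads(blob[8:8 + header_size].decode("utf-8"))
    base = 8 + header_size  # Absolute position where tensor data begins.
    entries = []
    for key, info in header.items():
        if key == "__metadata__":
            continue  # Optional free-form metadata, not a tensor.
        begin, end = info["data_offsets"]
        entries.append((key, base + begin, end - begin))
    return sorted(entries, key=lambda entry: entry[1])


a0 = struct.pack("<4q", 0, 1, 2, 3)
a1 = struct.pack("<4q", 4, 5, 6, 7)
blob = build_safetensors({"a0": ("I64", [4], a0), "a1": ("I64", [4], a1)})
entries = index_safetensors(blob)
for key, offset, length in entries:
    print(f"{offset:16d} | {offset + length:16d} | {length:16d} | `{key}`")

# Extraction is a plain byte slice of the file contents.
key, offset, length = entries[1]
extracted = blob[offset:offset + length]
```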
diff --git a/tools/test/parameters.py b/tools/test/parameters.py
new file mode 100644
index 0000000..fa4eeec
--- /dev/null
+++ b/tools/test/parameters.py
@@ -0,0 +1,31 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# https://huggingface.co/docs/safetensors/index
+#
+# To regenerate:
+# $ pip install torch safetensors
+# $ cd tools/test/
+# $ python parameters.py
+
+import torch
+from safetensors.torch import save_file
+
+save_file(
+ {
+ "a0": torch.arange(0, 4),
+ "a1": torch.arange(4, 8),
+ },
+ "parameters_a.safetensors",
+)
+
+save_file(
+ {
+ "b0": torch.arange(8, 16),
+ "b1": torch.arange(16, 32),
+ },
+ "parameters_b.safetensors",
+)
diff --git a/tools/test/parameters_a.safetensors b/tools/test/parameters_a.safetensors
new file mode 100644
index 0000000..1d92fe8
--- /dev/null
+++ b/tools/test/parameters_a.safetensors
Binary files differ
diff --git a/tools/test/parameters_b.safetensors b/tools/test/parameters_b.safetensors
new file mode 100644
index 0000000..348d3c4
--- /dev/null
+++ b/tools/test/parameters_b.safetensors
Binary files differ
diff --git a/tools/test/parameters_scoped.mlir b/tools/test/parameters_scoped.mlir
new file mode 100644
index 0000000..2bc7c9e
--- /dev/null
+++ b/tools/test/parameters_scoped.mlir
@@ -0,0 +1,27 @@
+// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-run-module --device=local-sync --module=- --function=echo \
+// RUN: --parameters=a=%p/parameters_a.safetensors \
+// RUN: --parameters=b=%p/parameters_b.safetensors \
+// RUN: --expected_output=4xi64=0,1,2,3 \
+// RUN: --expected_output=4xi64=4,5,6,7 \
+// RUN: --expected_output=8xi64=8,9,10,11,12,13,14,15 \
+// RUN: --expected_output=8xi64=16,17,18,19,20,21,22,23) | \
+// RUN: FileCheck %s
+// CHECK: [SUCCESS]
+
+// Parameters are scoped to allow for separating parameters from multiple models or
+// model stages in a compiled pipeline. It's possible to have multiple files
+// provide content for a single scope but not to have a single file provide
+// content for multiple scopes. Since parameter keys only need to be unique
+// within a scope this test could use the same name for both scopes if needed.
+util.global private @a0 = #stream.parameter.named<"a"::"a0"> : tensor<4xi64>
+util.global private @a1 = #stream.parameter.named<"a"::"a1"> : tensor<4xi64>
+util.global private @b0 = #stream.parameter.named<"b"::"b0"> : tensor<8xi64>
+util.global private @b1 = #stream.parameter.named<"b"::"b1"> : tensor<8xi64>
+func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) {
+ %a0 = util.global.load @a0 : tensor<4xi64>
+ %a1 = util.global.load @a1 : tensor<4xi64>
+ %b0 = util.global.load @b0 : tensor<8xi64>
+ %b1 = util.global.load @b1 : tensor<8xi64>
+ return %a0, %a1, %b0, %b1 : tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>
+}
diff --git a/tools/test/parameters_unscoped.mlir b/tools/test/parameters_unscoped.mlir
new file mode 100644
index 0000000..933cb77
--- /dev/null
+++ b/tools/test/parameters_unscoped.mlir
@@ -0,0 +1,25 @@
+// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-run-module --device=local-sync --module=- --function=echo \
+// RUN: --parameters=%p/parameters_a.safetensors \
+// RUN: --parameters=%p/parameters_b.safetensors \
+// RUN: --expected_output=4xi64=0,1,2,3 \
+// RUN: --expected_output=4xi64=4,5,6,7 \
+// RUN: --expected_output=8xi64=8,9,10,11,12,13,14,15 \
+// RUN: --expected_output=8xi64=16,17,18,19,20,21,22,23) | \
+// RUN: FileCheck %s
+// CHECK: [SUCCESS]
+
+// Simple named parameters with no scope. Parameter files are combined at
+// runtime to allow for filesystem sharding while still providing a flat set of
+// parameters in the compiler input.
+util.global private @a0 = #stream.parameter.named<"a0"> : tensor<4xi64>
+util.global private @a1 = #stream.parameter.named<"a1"> : tensor<4xi64>
+util.global private @b0 = #stream.parameter.named<"b0"> : tensor<8xi64>
+util.global private @b1 = #stream.parameter.named<"b1"> : tensor<8xi64>
+func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) {
+ %a0 = util.global.load @a0 : tensor<4xi64>
+ %a1 = util.global.load @a1 : tensor<4xi64>
+ %b0 = util.global.load @b0 : tensor<8xi64>
+ %b1 = util.global.load @b1 : tensor<8xi64>
+ return %a0, %a1, %b0, %b1 : tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>
+}