Batching parameter load operations and cleaning up gather/scatter. (#15706)
This makes loads look like gathers/scatters and allows us to move the
(relatively) tricky concurrency scheduling logic to the runtime. A
single load operation can now return any number of parameters with
unique storage buffers (hopefully imported/zero-copy) so long as they
have matching buffer parameters (which in general they all do). The core
logic for scheduling the batched operations has been shared such that
load/gather/scatter are all going down the same path meaning that as we
add new parameter types and optimize scheduling we only have one code
path to tweak. Some minor optimizations have been done to elide batch
overhead, but many have been deferred because, compared to staging even
10MB of parameters, the current profile is in the noise. The standalone
read/write methods were removed to simplify the compiler<->runtime
interface and implementations of `iree_io_parameter_provider_t` -
currently the only overhead incurred is an additional queue join barrier
that we can optimize away in the future in most cases.
Since load/gather use the same code path now, we shouldn't have
correctness issues unique to any particular path and can turn back on
the gather path which has much less overhead in the compiler/vmfb that
otherwise needs to handle independent buffers per parameter. We can
eventually optimize the load path to batch device buffer allocations but
the compiler/vmfb still needs to treat each as independent and we won't
get savings there. The rule is that the unified memory model should only
be used when building a vmfb that targets devices that can do zero-copy
loads from memory mapped files - every other case should use discrete.
Progress on #15521.
Progress on #15522.
Works around several issues in #15674.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 7a760d1..0c22c10 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -993,43 +993,51 @@
namespace {
-struct FoldParameterLoadTargetSubview
+struct FoldParameterLoadTargetSubviews
: public OpRewritePattern<ParameterLoadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ParameterLoadOp op,
PatternRewriter &rewriter) const override {
- auto loadResult = op.getResult();
- if (!loadResult.hasOneUse())
- return failure();
- Operation *user = *loadResult.getUsers().begin();
-
auto ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPoint(op);
bool needsUpdate = false;
- auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset());
- auto newResultSize = op.getResultSize();
- if (auto subviewOp = dyn_cast<IREE::Stream::ResourceSubviewOp>(user)) {
- auto viewSourceOffset = subviewOp.getSourceOffset();
- auto viewResultSize = subviewOp.getResultSize();
- if (IREE::Util::tryMoveProducerBefore(viewSourceOffset, op) &&
- IREE::Util::tryMoveProducerBefore(viewResultSize, op)) {
- newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
- subviewOp.getLoc(), newSourceOffset,
- rewriter.createOrFold<mlir::arith::IndexCastOp>(
- subviewOp.getLoc(), rewriter.getI64Type(), viewSourceOffset));
- newResultSize = viewResultSize;
- rewriter.replaceAllUsesWith(subviewOp.getResult(), op.getResult());
- needsUpdate = true;
+ SmallVector<Value> newSourceOffsets;
+ SmallVector<Value> newResultSizes;
+ size_t resultCount = op.getResults().size();
+ newSourceOffsets.reserve(resultCount);
+ newResultSizes.reserve(resultCount);
+
+ for (auto [loadResult, newSourceOffset, newResultSize] : llvm::zip_equal(
+ op.getResults(), op.getSourceOffsets(), op.getResultSizes())) {
+ if (loadResult.hasOneUse()) {
+ Operation *user = *loadResult.getUsers().begin();
+ if (auto subviewOp = dyn_cast<IREE::Stream::ResourceSubviewOp>(user)) {
+ auto viewSourceOffset = subviewOp.getSourceOffset();
+ auto viewResultSize = subviewOp.getResultSize();
+ if (IREE::Util::tryMoveProducerBefore(viewSourceOffset, op) &&
+ IREE::Util::tryMoveProducerBefore(viewResultSize, op)) {
+ newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subviewOp.getLoc(), newSourceOffset,
+ rewriter.createOrFold<mlir::arith::IndexCastOp>(
+ subviewOp.getLoc(), rewriter.getI64Type(),
+ viewSourceOffset));
+ newResultSize = viewResultSize;
+ rewriter.replaceAllUsesWith(subviewOp.getResult(), loadResult);
+ needsUpdate = true;
+ }
+ }
}
+ newSourceOffsets.push_back(newSourceOffset);
+ newResultSizes.push_back(newResultSize);
}
rewriter.restoreInsertionPoint(ip);
if (!needsUpdate)
return failure();
rewriter.updateRootInPlace(op, [&]() {
- op.getSourceOffsetMutable().assign(newSourceOffset);
- op.getResultSizeMutable().assign(newResultSize);
+ op.getSourceOffsetsMutable().assign(newSourceOffsets);
+ op.getResultSizesMutable().assign(newResultSizes);
});
return success();
}
@@ -1040,7 +1048,7 @@
void ParameterLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<ElideUnusedOp<ParameterLoadOp>>(context);
- results.insert<FoldParameterLoadTargetSubview>(context);
+ results.insert<FoldParameterLoadTargetSubviews>(context);
results.insert<ElideImmediateTimepointWait<ParameterLoadOp>>(context);
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 09b6847..b67f6d1 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -372,6 +372,81 @@
}
//===----------------------------------------------------------------------===//
+// custom<ParameterLoadOperations>(
+// $source_scope, $source_keys, $source_offsets,
+// type($results), $result_sizes)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseParameterLoadOperations(
+ OpAsmParser &parser, StringAttr &sourceScopeAttr, ArrayAttr &sourceKeysAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &sourceOffsets,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultSizes) {
+ auto builder = parser.getBuilder();
+ SmallVector<Attribute> sourceKeyAttrs;
+ do {
+ StringAttr rowSourceScopeAttr;
+ StringAttr sourceKeyAttr;
+ OpAsmParser::UnresolvedOperand sourceOffset;
+ Type resultType;
+ OpAsmParser::UnresolvedOperand resultSize;
+ if (failed(parseParameterReference(parser, rowSourceScopeAttr,
+ sourceKeyAttr)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(sourceOffset)) ||
+ failed(parser.parseRSquare()) ||
+ failed(parser.parseColonType(resultType)) ||
+ failed(parser.parseLBrace()) ||
+ failed(parser.parseOperand(resultSize)) ||
+ failed(parser.parseRBrace())) {
+ return failure();
+ }
+ if (!sourceScopeAttr) {
+ sourceScopeAttr = rowSourceScopeAttr;
+ } else if (rowSourceScopeAttr != sourceScopeAttr) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "each operation must use the same scope");
+ }
+ sourceKeyAttrs.push_back(sourceKeyAttr);
+ sourceOffsets.push_back(sourceOffset);
+ resultTypes.push_back(resultType);
+ resultSizes.push_back(resultSize);
+ } while (succeeded(parser.parseOptionalComma()));
+ sourceKeysAttr = builder.getArrayAttr(sourceKeyAttrs);
+ return success();
+}
+
+static void printParameterLoadOperations(OpAsmPrinter &p, Operation *op,
+ StringAttr sourceScopeAttr,
+ ArrayAttr sourceKeysAttr,
+ ValueRange sourceOffsets,
+ TypeRange resultTypes,
+ ValueRange resultSizes) {
+ p.increaseIndent();
+ p.printNewline();
+ llvm::interleave(
+ llvm::zip_equal(sourceKeysAttr.getAsRange<StringAttr>(), sourceOffsets,
+ resultTypes, resultSizes),
+ [&](std::tuple<StringAttr, Value, Type, Value> it) {
+ auto [sourceKeyAttr, sourceOffset, resultType, resultSize] = it;
+ printParameterReference(p, op, sourceScopeAttr, sourceKeyAttr);
+ p << "[";
+ p.printOperand(sourceOffset);
+ p << "] : ";
+ p.printType(resultType);
+ p << "{";
+ p.printOperand(resultSize);
+ p << "}";
+ },
+ [&]() {
+ p << ',';
+ p.printNewline();
+ });
+ p.decreaseIndent();
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
// custom<ParameterGatherOperations>(
// $source_scope, $source_keys, $source_offsets,
// $target, type($target), $target_size, $target_offsets, $target_lengths)
@@ -1184,6 +1259,17 @@
// stream.parameter.load
//===----------------------------------------------------------------------===//
+LogicalResult ParameterLoadOp::verify() {
+ ParameterLoadOp op = *this;
+ size_t expectedCount = op.getSourceKeys().size();
+ if (op.getSourceOffsets().size() != expectedCount ||
+ op.getResultSizes().size() != expectedCount) {
+ return op.emitOpError() << "requires that the source keys, source offsets, "
+ "and result sizes are all 1:1";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// stream.parameter.read
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 1389eb2..07be713 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -644,6 +644,8 @@
let opDocGroup = OpGroupParameterOps in {
def Stream_ParameterLoadOp : Stream_PureOp<"parameter.load", [
+ AttrSizedOperandSegments,
+ AllTypesMatch<["results"]>,
DeclareOpInterfaceMethods<Stream_AffinityOp, [
"getAffinity",
"setAffinity",
@@ -652,34 +654,36 @@
Stream_TimelineOp,
Util_SizeAwareOp,
]> {
- let summary = [{reads a resource from a parameter scope}];
+ let summary = [{reads one or more resources from a parameter scope}];
let description = [{
- Asynchronously reads a resource from an external parameter provider and
- returns the resulting stream resource. Depending on the resource type this
- may alias existing cached storage or be directly mapped to the parameter
- origin or result in a copy as if `stream.resource.alloca` and
- `stream.parameter.read` had been used.
+ Asynchronously reads one or more resources from an external parameter
+ provider and returns the resulting stream resources. Depending on the
+ resource type this may alias existing cached storage or be directly mapped
+ to the parameter origin or result in a copy as if `stream.resource.alloca`
+ and `stream.parameter.read` had been used per parameter.
}];
let arguments = (ins
OptionalAttr<StrAttr>:$source_scope,
- StrAttr:$source_key,
- I64:$source_offset,
- Stream_Size:$result_size,
+ StrArrayAttr:$source_keys,
+ Variadic<I64>:$source_offsets,
+ Variadic<Stream_Size>:$result_sizes,
Optional<Stream_Timepoint>:$await_timepoint,
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
- Stream_AnyStreamResource:$result,
+ Variadic<Stream_AnyStreamResource>:$results,
Stream_Timepoint:$result_timepoint
);
let assemblyFormat = [{
(`on` `(` $affinity^ `)`)?
(`await` `(` $await_timepoint^ `)` `=` `` `>`)?
- custom<ParameterReference>($source_scope, $source_key)
- `` `[` $source_offset `]` `:`
- type($result) `` `{` $result_size `}`
+ `{`
+ custom<ParameterLoadOperations>(
+ $source_scope, $source_keys, $source_offsets,
+ type($results), $result_sizes)
+ `}`
`=` `` `>`
type($result_timepoint)
attr-dict-with-keyword
@@ -687,12 +691,14 @@
let extraClassDeclaration = [{
Value getOperandSize(unsigned idx) { return {}; }
- Value getResultSize(unsigned idx) { return getResultSize(); }
+ Value getResultSize(unsigned idx) { return getResultSizes()[idx]; }
SmallVector<Value> getAwaitTimepoints() {
if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {};
}
}];
+ let hasVerifier = 1;
+
let hasCanonicalizer = 1;
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_folding.mlir
index ddeb54d..cb9bf80 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_folding.mlir
@@ -1,18 +1,30 @@
// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
// CHECK-LABEL: @FoldParameterLoadTargetSubview
-// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[OFFSET:.+]]: index, %[[LENGTH:.+]]: index)
-func.func @FoldParameterLoadTargetSubview(%wait: !stream.timepoint, %offset: index, %length: index) -> (!stream.resource<constant>, !stream.timepoint) {
+// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[OFFSET0:.+]]: index, %[[LENGTH0:.+]]: index, %[[OFFSET1:.+]]: index, %[[LENGTH1:.+]]: index)
+func.func @FoldParameterLoadTargetSubview(%wait: !stream.timepoint, %offset0: index, %length0: index, %offset1: index, %length1: index) -> (!stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) {
%c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
%c100 = arith.constant 100 : index
- // CHECK-DAG: %[[OFFSET_I64:.+]] = arith.index_cast %[[OFFSET]] : index to i64
- // CHECK-DAG: %[[PARAMETER_OFFSET:.+]] = arith.addi %[[OFFSET_I64]], %c50_i64
- // CHECK: %[[RESULT:.+]] = stream.parameter.load await(%[[WAIT]]) => "scope"::"key"[%[[PARAMETER_OFFSET]]] : !stream.resource<constant>{%[[LENGTH]]} => !stream.timepoint
- %result, %result_timepoint = stream.parameter.load await(%wait) => "scope"::"key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ %c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[OFFSET0_I64:.+]] = arith.index_cast %[[OFFSET0]] : index to i64
+ // CHECK-DAG: %[[PARAMETER_OFFSET0:.+]] = arith.addi %[[OFFSET0_I64]], %c50_i64
+ // CHECK-DAG: %[[OFFSET1_I64:.+]] = arith.index_cast %[[OFFSET1]] : index to i64
+ // CHECK-DAG: %[[PARAMETER_OFFSET1:.+]] = arith.addi %[[OFFSET1_I64]], %c51_i64
+ // CHECK: %[[RESULTS:.+]]:2, %[[SIGNAL:.+]] = stream.parameter.load await(%[[WAIT]]) => {
+ // CHECK-NEXT: "scope"::"key0"[%[[PARAMETER_OFFSET0]]] : !stream.resource<constant>{%[[LENGTH0]]},
+ // CHECK-NEXT: "scope"::"key1"[%[[PARAMETER_OFFSET1]]] : !stream.resource<constant>{%[[LENGTH1]]}
+ // CHECK-NEXT: } => !stream.timepoint
+ %results:2, %result_timepoint = stream.parameter.load await(%wait) => {
+ "scope"::"key0"[%c50_i64] : !stream.resource<constant>{%c100},
+ "scope"::"key1"[%c51_i64] : !stream.resource<constant>{%c200}
+ } => !stream.timepoint
// CHECK-NOT: stream.resource.subview
- %subview = stream.resource.subview %result[%offset] : !stream.resource<constant>{%c100} -> !stream.resource<constant>{%length}
- // CHECK: return %[[RESULT]]
- return %subview, %result_timepoint : !stream.resource<constant>, !stream.timepoint
+ %subview0 = stream.resource.subview %results#0[%offset0] : !stream.resource<constant>{%c100} -> !stream.resource<constant>{%length0}
+ // CHECK-NOT: stream.resource.subview
+ %subview1 = stream.resource.subview %results#1[%offset1] : !stream.resource<constant>{%c200} -> !stream.resource<constant>{%length1}
+ // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1, %[[SIGNAL]]
+ return %subview0, %subview1, %result_timepoint : !stream.resource<constant>, !stream.resource<constant>, !stream.timepoint
}
// -----
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_ops.mlir
index 0e6ba69..eff0bb0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/parameter_ops.mlir
@@ -11,12 +11,20 @@
// CHECK-LABEL: @parameterLoad
// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint)
-func.func @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
+func.func @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) {
%c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
%c100 = arith.constant 100 : index
- // CHECK: = stream.parameter.load await(%[[WAIT]]) => "scope"::"key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
- %result, %result_timepoint = stream.parameter.load await(%wait) => "scope"::"key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
- return %result, %result_timepoint : !stream.resource<constant>, !stream.timepoint
+ %c200 = arith.constant 200 : index
+ // CHECK: = stream.parameter.load await(%[[WAIT]]) => {
+ // CHECK-NEXT: "scope"::"key0"[%c50_i64] : !stream.resource<constant>{%c100},
+ // CHECK-NEXT: "scope"::"key1"[%c51_i64] : !stream.resource<constant>{%c200}
+ // CHECK-NEXT: } => !stream.timepoint
+ %results:2, %result_timepoint = stream.parameter.load await(%wait) => {
+ "scope"::"key0"[%c50_i64] : !stream.resource<constant>{%c100},
+ "scope"::"key1"[%c51_i64] : !stream.resource<constant>{%c200}
+ } => !stream.timepoint
+ return %results#0, %results#1, %result_timepoint : !stream.resource<constant>, !stream.resource<constant>, !stream.timepoint
}
// -----
@@ -26,8 +34,12 @@
func.func @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
%c50_i64 = arith.constant 50 : i64
%c100 = arith.constant 100 : index
- // CHECK: = stream.parameter.load await(%[[WAIT]]) => "key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
- %result, %result_timepoint = stream.parameter.load await(%wait) => "key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ // CHECK: = stream.parameter.load await(%[[WAIT]]) => {
+ // CHECK-NEXT: "key"[%c50_i64] : !stream.resource<constant>{%c100}
+ // CHECK-NEXT: } => !stream.timepoint
+ %result, %result_timepoint = stream.parameter.load await(%wait) => {
+ "key"[%c50_i64] : !stream.resource<constant>{%c100}
+ } => !stream.timepoint
return %result, %result_timepoint : !stream.resource<constant>, !stream.timepoint
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
index de82242..5d4d3a1 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
@@ -275,26 +275,60 @@
return ParameterSlice{parameterAttr, sourceOffset, sourceLength};
}
-static TimepointResource
-buildParameterLoad(Location loc, Value awaitTimepoint,
- IREE::Stream::AffinityAttr affinityAttr, Type targetType,
- Value targetSize, const PackedSpan &packedSpan,
- IndexSet &indexSet, OpBuilder &builder) {
- auto parameterSlice =
- getParameterSlice(loc, packedSpan.slice.value, indexSet, builder);
+static Value buildParameterLoad(Value awaitTimepoint,
+ IREE::Stream::AffinityAttr affinityAttr,
+ Type targetType, StringAttr scope,
+ ArrayRef<StorageResource *> storageResources,
+ IndexSet &indexSet, OpBuilder &builder) {
+ SmallVector<Location> spanLocs;
+ SmallVector<Attribute> sourceKeys;
+ SmallVector<Value> sourceOffsets;
+ SmallVector<Type> targetTypes;
+ SmallVector<Value> targetLengths;
+ for (auto *storageResource : storageResources) {
+ assert(storageResource->spans.size() == 1 &&
+ "expected single span per resource for load");
+ for (auto &packedSpan : storageResource->spans) {
+ auto spanLoc = packedSpan.slice.result.getLoc();
+ auto parameterSlice =
+ getParameterSlice(spanLoc, packedSpan.slice.value, indexSet, builder);
+ spanLocs.push_back(spanLoc);
+ sourceKeys.push_back(parameterSlice.parameterAttr.getKey());
+ sourceOffsets.push_back(parameterSlice.sourceOffset);
+ targetTypes.push_back(targetType);
+ targetLengths.push_back(indexSet.get(packedSpan.length));
+ }
+ }
+
+ // Load all in a batch. One resource is returned per parameter but they may
+ // alias depending on the runtime implementation.
auto loadOp = builder.create<IREE::Stream::ParameterLoadOp>(
- loc, targetType, builder.getType<IREE::Stream::TimepointType>(),
- parameterSlice.parameterAttr.getScope(),
- parameterSlice.parameterAttr.getKey(), parameterSlice.sourceOffset,
- parameterSlice.sourceLength, awaitTimepoint, affinityAttr);
- return TimepointResource{loadOp.getResultTimepoint(), loadOp.getResult(),
- loadOp.getResultSize()};
+ builder.getFusedLoc(spanLocs), targetTypes,
+ builder.getType<IREE::Stream::TimepointType>(), scope,
+ builder.getArrayAttr(sourceKeys), sourceOffsets, targetLengths,
+ awaitTimepoint, affinityAttr);
+
+ // Slice out each span from the allocation.
+ // Note that access must be guarded by the final ready timepoint.
+ unsigned resultIndex = 0;
+ for (auto *storageResource : storageResources) {
+ for (auto &packedSpan : storageResource->spans) {
+ auto subviewOp = builder.create<IREE::Stream::ResourceSubviewOp>(
+ packedSpan.slice.result.getLoc(), loadOp.getResult(resultIndex),
+ loadOp.getResultSize(resultIndex), indexSet.get(packedSpan.offset),
+ packedSpan.slice.resultSize);
+ packedSpan.slice.result.replaceAllUsesWith(subviewOp.getResult());
+ ++resultIndex;
+ }
+ }
+
+ return loadOp.getResultTimepoint();
}
static TimepointResource
buildParameterGather(Location loc, Value awaitTimepoint,
IREE::Stream::AffinityAttr affinityAttr, Type targetType,
- Value targetSize, ArrayRef<PackedSpan> packedSpans,
+ Value targetSize, MutableArrayRef<PackedSpan> packedSpans,
IndexSet &indexSet, OpBuilder &builder) {
// Allocate the resulting storage resource of the final resource type.
auto allocOp = builder.create<IREE::Stream::ResourceAllocOp>(
@@ -334,6 +368,16 @@
gatherTimepoints.push_back(gatherOp.getResultTimepoint());
}
+ // Slice out each span from the allocation.
+ // Note that access must be guarded by the final ready timepoint.
+ for (auto &packedSpan : packedSpans) {
+ auto subviewOp = builder.create<IREE::Stream::ResourceSubviewOp>(
+ packedSpan.slice.result.getLoc(), allocOp.getResult(),
+ allocOp.getResultSize(0), indexSet.get(packedSpan.offset),
+ packedSpan.slice.resultSize);
+ packedSpan.slice.result.replaceAllUsesWith(subviewOp.getResult());
+ }
+
// Wait until all gathers have completed.
Value readyTimepoint =
IREE::Stream::TimepointJoinOp::join(gatherTimepoints, builder);
@@ -507,33 +551,40 @@
if (storageResources.empty())
return nullptr;
- // Emit the parameter loads or gathers for each unique resource.
- SmallVector<Value> uploadTimepoints;
+ // Sort resources by type so we can batch them.
+ // Loads are only possible if we are using the parameter as a constant and
+ // it is a single span as we can't pack externally owned parameters.
+ // A batch of loads happens from a single source scope so we bucket here.
+  // Note that we do this separately from the walk above as we may pack parameters
+ // such that they have a single parameter per resource and introduce more that
+ // we can load than if just looking at the original pre-packed state.
+ llvm::MapVector<StringAttr, SmallVector<StorageResource *>> resourceLoads;
+ SmallVector<StorageResource *> resourceGathers;
for (auto &storageResource : storageResources) {
- // Parameter-backed resource that we can either load or gather.
- // Loads are only possible if we are using the parameter as a constant and
- // it is a single span as we can't pack externally owned parameters.
- TimepointResource uploadedResource;
- auto resourceSize = indexSet.get(storageResource.totalSize);
- if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant &&
- storageResource.spans.size() == 1) {
- uploadedResource = buildParameterLoad(
- storageResource.loc, awaitTimepoint, affinityAttr, resourceType,
- resourceSize, storageResource.spans.front(), indexSet, builder);
+ if (storageResource.spans.size() == 1) {
+ auto parameterAttr = cast<IREE::Stream::NamedParameterAttr>(
+ storageResource.spans.front().slice.value);
+ resourceLoads[parameterAttr.getScope()].push_back(&storageResource);
} else {
- uploadedResource = buildParameterGather(
- storageResource.loc, awaitTimepoint, affinityAttr, resourceType,
- resourceSize, storageResource.spans, indexSet, builder);
+ resourceGathers.push_back(&storageResource);
}
+ }
- for (auto &span : storageResource.spans) {
- auto loc = span.slice.result.getLoc();
- auto subviewOp = builder.create<IREE::Stream::ResourceSubviewOp>(
- loc, uploadedResource.resource, uploadedResource.resourceSize,
- indexSet.get(span.offset), span.slice.resultSize);
- span.slice.result.replaceAllUsesWith(subviewOp.getResult());
- }
+ // Emit all loads as a single operation per scope.
+ SmallVector<Value> uploadTimepoints;
+ for (auto &[scope, scopeResources] : resourceLoads) {
+ uploadTimepoints.push_back(
+ buildParameterLoad(awaitTimepoint, affinityAttr, resourceType, scope,
+ scopeResources, indexSet, builder));
+ }
+ // Emit gathers, of which there may be multiple batches based on the target
+ // resource as gathers are 1:1 per target.
+ for (auto *storageResource : resourceGathers) {
+ auto resourceSize = indexSet.get(storageResource->totalSize);
+ auto uploadedResource = buildParameterGather(
+ storageResource->loc, awaitTimepoint, affinityAttr, resourceType,
+ resourceSize, storageResource->spans, indexSet, builder);
uploadTimepoints.push_back(uploadedResource.timepoint);
}
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.cpp
index f63f413..b34beb3 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.cpp
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/Patterns.cpp
@@ -33,45 +33,6 @@
return builder.create<IREE::VM::RodataInlineOp>(loc, attr);
}
-struct LoadOpConversion
- : public OpConversionPattern<IREE::IO::Parameters::LoadOp> {
- LoadOpConversion(MLIRContext *context, SymbolTable &importSymbols,
- TypeConverter &typeConverter, StringRef importName)
- : OpConversionPattern(context) {
- importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
- assert(importOp);
- }
- LogicalResult
- matchAndRewrite(IREE::IO::Parameters::LoadOp loadOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
- loadOp, importOp.getSymNameAttr(),
- IREE::VM::RefType::get(loadOp.getResult().getType()),
- ValueRange{
- adaptor.getDevice(),
- adaptor.getQueueAffinity(),
- adaptor.getWaitFence(),
- adaptor.getSignalFence(),
- getStringRodata(loadOp.getLoc(), adaptor.getSourceScopeAttr(),
- rewriter),
- getStringRodata(loadOp.getLoc(), adaptor.getSourceKeyAttr(),
- rewriter),
- adaptor.getSourceOffset(),
- adaptor.getQueueAffinity(),
- rewriter.create<IREE::VM::ConstI32Op>(
- loadOp.getLoc(), (uint32_t)adaptor.getMemoryTypes()),
- rewriter.create<IREE::VM::ConstI32Op>(
- loadOp.getLoc(), (uint32_t)adaptor.getBufferUsage()),
- adaptor.getLength(),
- });
- copyImportAttrs(importOp, callOp);
- return success();
- }
-
-private:
- mutable IREE::VM::ImportOp importOp;
-};
-
// TODO(benvanik): make a vm.rodata.table or something that returns the
// offset/length and data buffers. We could then do a whole-program analysis to
// build a single data table with multiple views into it.
@@ -145,6 +106,65 @@
return clonedBuffer;
}
+struct LoadOpConversion
+ : public OpConversionPattern<IREE::IO::Parameters::LoadOp> {
+ LoadOpConversion(MLIRContext *context, SymbolTable &importSymbols,
+ TypeConverter &typeConverter, StringRef importName)
+ : OpConversionPattern(context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+ LogicalResult
+ matchAndRewrite(IREE::IO::Parameters::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto [keyTable, keyData] =
+ buildKeyTable(loadOp.getLoc(), adaptor.getSourceKeysAttr(), rewriter);
+ SmallVector<Value> targetOffsets(
+ adaptor.getSourceOffsets().size(),
+ rewriter.create<IREE::VM::ConstI64Op>(loadOp.getLoc(), 0));
+ auto spans =
+ buildIndirectSpans(loadOp.getLoc(), adaptor.getSourceOffsets(),
+ targetOffsets, adaptor.getLengths(), rewriter);
+ auto bufferType =
+ IREE::VM::RefType::get(rewriter.getType<IREE::HAL::BufferType>());
+ auto listType = IREE::VM::RefType::get(IREE::VM::ListType::get(bufferType));
+ auto callOp = rewriter.create<IREE::VM::CallOp>(
+ loadOp.getLoc(), importOp.getSymNameAttr(),
+ TypeRange{
+ listType,
+ },
+ ValueRange{
+ adaptor.getDevice(),
+ adaptor.getQueueAffinity(),
+ adaptor.getWaitFence(),
+ adaptor.getSignalFence(),
+ getStringRodata(loadOp.getLoc(), adaptor.getSourceScopeAttr(),
+ rewriter),
+ adaptor.getQueueAffinity(),
+ rewriter.create<IREE::VM::ConstI32Op>(
+ loadOp.getLoc(), (uint32_t)adaptor.getMemoryTypes()),
+ rewriter.create<IREE::VM::ConstI32Op>(
+ loadOp.getLoc(), (uint32_t)adaptor.getBufferUsage()),
+ keyTable,
+ keyData,
+ spans,
+ });
+ copyImportAttrs(importOp, callOp);
+ SmallVector<Value> buffers;
+ buffers.reserve(targetOffsets.size());
+ for (size_t i = 0; i < targetOffsets.size(); ++i) {
+ buffers.push_back(rewriter.create<IREE::VM::ListGetRefOp>(
+ loadOp.getLoc(), bufferType, callOp.getResult(0),
+ rewriter.create<IREE::VM::ConstI32Op>(loadOp.getLoc(), (int32_t)i)));
+ }
+ rewriter.replaceOp(loadOp, buffers);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
struct GatherOpConversion
: public OpConversionPattern<IREE::IO::Parameters::GatherOp> {
GatherOpConversion(MLIRContext *context, SymbolTable &importSymbols,
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/parameter_ops.mlir
index eaed16f..412e810 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/parameter_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/parameter_ops.mlir
@@ -3,17 +3,31 @@
// CHECK-LABEL: @parameterLoad
// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[QUEUE_AFFINITY:.+]]: i64, %[[WAIT:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL:.+]]: !vm.ref<!hal.fence>)
-func.func @parameterLoad(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence) -> !hal.buffer {
+func.func @parameterLoad(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence) -> (!hal.buffer, !hal.buffer) {
%c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
%c100 = arith.constant 100 : index
- // CHECK-DAG: %[[SCOPE:.+]] = vm.rodata.inline {{.+}} = "scope"
- // CHECK-DAG: %[[KEY:.+]] = vm.rodata.inline {{.+}} = "key"
- // CHECK: %[[TARGET_BUFFER:.+]] = vm.call @io_parameters.load
+ %c101 = arith.constant 101 : index
+ // CHECK-DAG: %[[KEY_TABLE:.+]] = vm.rodata.inline : !vm.buffer = dense<[0, 4, 4, 4]> : vector<4xi32>
+ // CHECK-DAG: %[[KEY_DATA:.+]] = vm.rodata.inline : !vm.buffer = #util.composite<8xi8, [
+ // CHECK-NEXT: "key0",
+ // CHECK-NEXT: "key1",
+ // CHECK-NEXT: ]>
+ // CHECK-DAG: %[[SPANS:.+]] = vm.rodata.inline : !vm.buffer = dense<[50, 0, 100, 51, 0, 101]> : vector<6xi64>
+ // CHECK-DAG: %[[SCOPE:.+]] = vm.rodata.inline {{.+}} = "scope"
+ // CHECK: %[[TARGET_BUFFERS:.+]] = vm.call @io_parameters.load
// CHECK-SAME: (%[[DEVICE]], %[[QUEUE_AFFINITY]], %[[WAIT]], %[[SIGNAL]],
- // CHECK-SAME: %[[SCOPE]], %[[KEY]], %c50, %[[QUEUE_AFFINITY]], %c48, %c527363, %c100)
- %target_buffer = io_parameters.load<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) source("scope"::"key")[%c50_i64] type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") : !hal.buffer{%c100}
- // CHECK: return %[[TARGET_BUFFER]]
- return %target_buffer : !hal.buffer
+ // CHECK-SAME: %[[SCOPE]], %[[QUEUE_AFFINITY]], %c48, %c527363, %[[KEY_TABLE]], %[[KEY_DATA]], %[[SPANS]])
+ %target_buffers:2 = io_parameters.load<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") {
+ "scope"::"key0"[%c50_i64] : !hal.buffer{%c100},
+ "scope"::"key1"[%c51_i64] : !hal.buffer{%c101}
+ }
+ // CHECK-DAG: %[[C0:.+]] = vm.const.i32 0
+ // CHECK-DAG: %[[TARGET_BUFFER0:.+]] = vm.list.get.ref %[[TARGET_BUFFERS]], %[[C0]]
+ // CHECK-DAG: %[[C1:.+]] = vm.const.i32 1
+ // CHECK-DAG: %[[TARGET_BUFFER1:.+]] = vm.list.get.ref %[[TARGET_BUFFERS]], %[[C1]]
+ // CHECK: return %[[TARGET_BUFFER0]], %[[TARGET_BUFFER1]]
+ return %target_buffers#0, %target_buffers#1 : !hal.buffer, !hal.buffer
}
// -----
@@ -23,12 +37,20 @@
func.func @parameterLoadNoScope(%device: !hal.device, %queue_affinity: i64, %wait: !hal.fence, %signal: !hal.fence) -> !hal.buffer {
%c50_i64 = arith.constant 50 : i64
%c100 = arith.constant 100 : index
- // CHECK-DAG: %[[SCOPE:.+]] = vm.const.ref.zero : !vm.buffer
- // CHECK-DAG: %[[KEY:.+]] = vm.rodata.inline {{.+}} = "key"
- // CHECK: %[[TARGET_BUFFER:.+]] = vm.call @io_parameters.load
+ // CHECK-DAG: %[[KEY_TABLE:.+]] = vm.rodata.inline : !vm.buffer = dense<[0, 3]> : vector<2xi32>
+ // CHECK-DAG: %[[KEY_DATA:.+]] = vm.rodata.inline : !vm.buffer = #util.composite<3xi8, [
+ // CHECK-NEXT: "key",
+ // CHECK-NEXT: ]>
+ // CHECK-DAG: %[[SPANS:.+]] = vm.rodata.inline : !vm.buffer = dense<[50, 0, 100]> : vector<3xi64>
+ // CHECK-DAG: %[[SCOPE:.+]] = vm.const.ref.zero : !vm.buffer
+ // CHECK: %[[TARGET_BUFFERS:.+]] = vm.call @io_parameters.load
// CHECK-SAME: (%[[DEVICE]], %[[QUEUE_AFFINITY]], %[[WAIT]], %[[SIGNAL]],
- // CHECK-SAME: %[[SCOPE]], %[[KEY]], %c50, %[[QUEUE_AFFINITY]], %c48, %c527363, %c100)
- %target_buffer = io_parameters.load<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) source("key")[%c50_i64] type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") : !hal.buffer{%c100}
+ // CHECK-SAME: %[[SCOPE]], %[[QUEUE_AFFINITY]], %c48, %c527363, %[[KEY_TABLE]], %[[KEY_DATA]], %[[SPANS]])
+ %target_buffer = io_parameters.load<%device : !hal.device> affinity(%queue_affinity) wait(%wait) signal(%signal) type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") {
+ "key"[%c50_i64] : !hal.buffer{%c100}
+ }
+ // CHECK-DAG: %[[C0:.+]] = vm.const.i32 0
+ // CHECK-DAG: %[[TARGET_BUFFER:.+]] = vm.list.get.ref %[[TARGET_BUFFERS]], %[[C0]]
// CHECK: return %[[TARGET_BUFFER]]
return %target_buffer : !hal.buffer
}
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.cpp
index 817fb36..e483458 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.cpp
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/Patterns.cpp
@@ -40,7 +40,7 @@
// Derive the allocation requirements.
auto resourceType =
- llvm::cast<IREE::Stream::ResourceType>(loadOp.getResult().getType());
+ cast<IREE::Stream::ResourceType>(loadOp.getResults().front().getType());
auto memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
auto bufferUsage = IREE::HAL::BufferUsageBitfield::None;
if (failed(deriveAllowedResourceBufferBits(loc, resourceType, memoryTypes,
@@ -49,13 +49,18 @@
}
// Queue operation, which acts like an allocation.
- Value result = rewriter.create<IREE::IO::Parameters::LoadOp>(
- loc, rewriter.getType<IREE::HAL::BufferType>(), device, queueAffinity,
- waitFence, signalFence, adaptor.getSourceScopeAttr(),
- adaptor.getSourceKeyAttr(), adaptor.getSourceOffset(), memoryTypes,
- bufferUsage, adaptor.getResultSize());
+ SmallVector<Type> newResultTypes(loadOp.getResults().size(),
+ rewriter.getType<IREE::HAL::BufferType>());
+ auto newOp = rewriter.create<IREE::IO::Parameters::LoadOp>(
+ loc, newResultTypes, device, queueAffinity, waitFence, signalFence,
+ adaptor.getSourceScopeAttr(), adaptor.getSourceKeysAttr(),
+ adaptor.getSourceOffsets(), memoryTypes, bufferUsage,
+ adaptor.getResultSizes());
- rewriter.replaceOp(loadOp, {result, signalFence});
+ SmallVector<Value> resultReplacements;
+ llvm::append_range(resultReplacements, newOp.getResults());
+ resultReplacements.push_back(signalFence);
+ rewriter.replaceOp(loadOp, resultReplacements);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
index aab010c..eb75719 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
@@ -1,21 +1,26 @@
// RUN: iree-opt --split-input-file --iree-hal-conversion --canonicalize %s | FileCheck %s
// CHECK-LABEL: @parameterLoad
-// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.fence)
-func.func @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.buffer, !hal.fence)
+func.func @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) {
%c50_i64 = arith.constant 50 : i64
+ %c51_i64 = arith.constant 51 : i64
%c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
// CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
- // CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
+ // CHECK: %[[BUFFERS:.+]]:2 = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
- // CHECK-SAME: source("scope"::"key")[%c50_i64]
// CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
- // CHECK-SAME: : !hal.buffer
- %result, %result_timepoint = stream.parameter.load await(%wait) => "scope"::"key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
- // CHECK: return %[[BUFFER]], %[[SIGNAL]]
- return %result, %result_timepoint : !stream.resource<constant>, !stream.timepoint
+ // CHECK-NEXT: "scope"::"key0"[%c50_i64] : !hal.buffer{%c100}
+ // CHECK-NEXT: "scope"::"key1"[%c51_i64] : !hal.buffer{%c101}
+ %results:2, %result_timepoint = stream.parameter.load await(%wait) => {
+ "scope"::"key0"[%c50_i64] : !stream.resource<constant>{%c100},
+ "scope"::"key1"[%c51_i64] : !stream.resource<constant>{%c101}
+ } => !stream.timepoint
+ // CHECK: return %[[BUFFERS]]#0, %[[BUFFERS]]#1, %[[SIGNAL]]
+ return %results#0, %results#1, %result_timepoint : !stream.resource<constant>, !stream.resource<constant>, !stream.timepoint
}
// -----
@@ -30,10 +35,11 @@
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
- // CHECK-SAME: source("key")[%c50_i64]
// CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
- // CHECK-SAME: : !hal.buffer
- %result, %result_timepoint = stream.parameter.load await(%wait) => "key"[%c50_i64] : !stream.resource<constant>{%c100} => !stream.timepoint
+ // CHECK-NEXT: "key"[%c50_i64] : !hal.buffer{%c100}
+ %result, %result_timepoint = stream.parameter.load await(%wait) => {
+ "key"[%c50_i64] : !stream.resource<constant>{%c100}
+ } => !stream.timepoint
// CHECK: return %[[BUFFER]], %[[SIGNAL]]
return %result, %result_timepoint : !stream.resource<constant>, !stream.timepoint
}
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp
index bae1187..050aaf5 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.cpp
@@ -56,9 +56,95 @@
}
//===----------------------------------------------------------------------===//
+// custom<ParameterLoadOperations>(
+// $source_scope, $source_keys, $source_offsets,
+// type($results), $result_sizes)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseParameterLoadOperations(
+ OpAsmParser &parser, StringAttr &sourceScopeAttr, ArrayAttr &sourceKeysAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &sourceOffsets,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultSizes) {
+ auto builder = parser.getBuilder();
+ SmallVector<Attribute> sourceKeyAttrs;
+ do {
+ StringAttr rowSourceScopeAttr;
+ StringAttr sourceKeyAttr;
+ OpAsmParser::UnresolvedOperand sourceOffset;
+ Type resultType;
+ OpAsmParser::UnresolvedOperand resultSize;
+ if (failed(parseParameterReference(parser, rowSourceScopeAttr,
+ sourceKeyAttr)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(sourceOffset)) ||
+ failed(parser.parseRSquare()) ||
+ failed(parser.parseColonType(resultType)) ||
+ failed(parser.parseLBrace()) ||
+ failed(parser.parseOperand(resultSize)) ||
+ failed(parser.parseRBrace())) {
+ return failure();
+ }
+ if (!sourceScopeAttr) {
+ sourceScopeAttr = rowSourceScopeAttr;
+ } else if (rowSourceScopeAttr != sourceScopeAttr) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "each operation must use the same scope");
+ }
+ sourceKeyAttrs.push_back(sourceKeyAttr);
+ sourceOffsets.push_back(sourceOffset);
+ resultTypes.push_back(resultType);
+ resultSizes.push_back(resultSize);
+ } while (succeeded(parser.parseOptionalComma()));
+ sourceKeysAttr = builder.getArrayAttr(sourceKeyAttrs);
+ return success();
+}
+
+static void printParameterLoadOperations(OpAsmPrinter &p, Operation *op,
+ StringAttr sourceScopeAttr,
+ ArrayAttr sourceKeysAttr,
+ ValueRange sourceOffsets,
+ TypeRange resultTypes,
+ ValueRange resultSizes) {
+ p.increaseIndent();
+ p.printNewline();
+ llvm::interleave(
+ llvm::zip_equal(sourceKeysAttr.getAsRange<StringAttr>(), sourceOffsets,
+ resultTypes, resultSizes),
+ [&](std::tuple<StringAttr, Value, Type, Value> it) {
+ auto [sourceKeyAttr, sourceOffset, resultType, resultSize] = it;
+ printParameterReference(p, op, sourceScopeAttr, sourceKeyAttr);
+ p << "[";
+ p.printOperand(sourceOffset);
+ p << "] : ";
+ p.printType(resultType);
+ p << "{";
+ p.printOperand(resultSize);
+ p << "}";
+ },
+ [&]() {
+ p << ',';
+ p.printNewline();
+ });
+ p.decreaseIndent();
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
// io_parameters.load
//===----------------------------------------------------------------------===//
+LogicalResult LoadOp::verify() {
+ LoadOp op = *this;
+ size_t expectedCount = op.getSourceKeys().size();
+ if (op.getSourceOffsets().size() != expectedCount ||
+ op.getLengths().size() != expectedCount) {
+ return op.emitOpError() << "requires that the source keys, source offsets, "
+ "and result sizes are all 1:1";
+ }
+ return success();
+}
+
void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO(benvanik): fold hal.buffer.subspan on the result into parameters.
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td
index 9ef6726..1f03b8f 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td
@@ -30,14 +30,16 @@
let opDocGroup = OpGroupParameterOps in {
def IOParameters_LoadOp : IOParameters_Op<"load", [
+ AttrSizedOperandSegments,
Util_SizeAwareOp,
]> {
- let summary = [{reads a parameter from a parameter scope}];
+ let summary = [{reads one or more parameters from a parameter scope}];
let description = [{
- Asynchronously reads a parameter from an external parameter provider and
- returns the resulting buffer. Depending on the parameter and buffer types
- this may alias existing cached storage or be directly mapped to the
- parameter origin or result in a copy as if an allocate + read had been used.
+ Asynchronously reads one or more parameters from an external parameter
+ provider and returns the resulting buffers. Depending on the parameter and
+ buffer types this may alias existing cached storage or be directly mapped to
+ the parameter origin or result in a copy as if an allocate + read had been
+ used.
}];
let arguments = (ins
@@ -46,14 +48,14 @@
HAL_Fence:$wait_fence,
HAL_Fence:$signal_fence,
OptionalAttr<StrAttr>:$source_scope,
- StrAttr:$source_key,
- I64:$source_offset,
+ StrArrayAttr:$source_keys,
+ Variadic<I64>:$source_offsets,
HAL_MemoryTypeBitfieldAttr:$memory_types,
HAL_BufferUsageBitfieldAttr:$buffer_usage,
- HAL_DeviceSize:$length
+ Variadic<HAL_DeviceSize>:$lengths
);
let results = (outs
- HAL_Buffer:$result
+ Variadic<HAL_Buffer>:$results
);
let assemblyFormat = [{
@@ -61,19 +63,23 @@
`affinity` `(` $queue_affinity `)`
`wait` `(` $wait_fence `)`
`signal` `(` $signal_fence `)`
- `source` `(` custom<ParameterReference>($source_scope, $source_key) `)`
- `` `[` $source_offset `]`
`type` `(` $memory_types `)`
`usage` `(` $buffer_usage `)`
- `:` custom<SizeAwareType>(type($result), $length)
+ `{`
+ custom<ParameterLoadOperations>(
+ $source_scope, $source_keys, $source_offsets,
+ type($results), $lengths)
+ `}`
attr-dict-with-keyword
}];
let extraClassDeclaration = [{
Value getOperandSize(unsigned idx) { return {}; }
- Value getResultSize(unsigned idx) { return getLength(); }
+ Value getResultSize(unsigned idx) { return getLengths()[idx]; }
}];
+ let hasVerifier = 1;
+
let hasCanonicalizer = 1;
}
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/parameter_ops.mlir
index 98e6bbc..e60c73e 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/parameter_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/parameter_ops.mlir
@@ -13,17 +13,16 @@
// CHECK-SAME: affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]])
// CHECK-SAME: signal(%[[SIGNAL]])
- // CHECK-SAME: source("scope"::"w0")[%[[OFFSET]]]
// CHECK-SAME: type("DeviceVisible|DeviceLocal")
// CHECK-SAME: usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
- // CHECK-SAME: : !hal.buffer{%[[LENGTH]]}
+ // CHECK-NEXT: "scope"::"w0"[%[[OFFSET]]] : !hal.buffer{%[[LENGTH]]}
%0 = io_parameters.load<%device : !hal.device>
affinity(%affinity)
wait(%wait)
signal(%signal)
- source("scope"::"w0")[%offset]
type("DeviceVisible|DeviceLocal")
- usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
- : !hal.buffer{%length}
+ usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") {
+ "scope"::"w0"[%offset] : !hal.buffer{%length}
+ }
return
}
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/io_parameters.imports.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/io_parameters.imports.mlir
index a062262..4bfda08 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/io_parameters.imports.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/io_parameters.imports.mlir
@@ -6,13 +6,13 @@
%wait_fence : !vm.ref<!hal.fence>,
%signal_fence : !vm.ref<!hal.fence>,
%source_scope : !vm.buffer,
- %source_key : !vm.buffer,
- %source_offset : i64,
%target_queue_affinity : i64,
%target_memory_type : i32,
%target_buffer_usage : i32,
- %length : i64
-) -> !vm.ref<!hal.buffer>
+ %key_table : !vm.buffer,
+ %key_data : !vm.buffer,
+ %spans : !vm.buffer
+) -> !vm.list<!vm.ref<!hal.buffer>>
vm.import private @gather(
%device : !vm.ref<!hal.device>,
diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c
index 09f7415..71eccb9 100644
--- a/runtime/src/iree/hal/command_buffer_validation.c
+++ b/runtime/src/iree/hal/command_buffer_validation.c
@@ -25,6 +25,10 @@
const iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_command_category_t required_categories) {
+ if (IREE_UNLIKELY(!validation_state->is_recording)) {
+ return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+ "command buffer is not in a recording state");
+ }
if (!iree_all_bits_set(command_buffer->allowed_categories,
required_categories)) {
#if IREE_STATUS_MODE
diff --git a/runtime/src/iree/hal/drivers/local_task/task_queue.c b/runtime/src/iree/hal/drivers/local_task/task_queue.c
index 02f7942..4609d9b 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_queue.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_queue.c
@@ -201,21 +201,19 @@
// NOTE: it's ok for there to be no command buffers - in that case the
// submission was purely for synchronization.
- if (cmd->command_buffer_count > 0) {
- for (iree_host_size_t i = 0; i < cmd->command_buffer_count; ++i) {
- if (iree_hal_task_command_buffer_isa(cmd->command_buffers[i])) {
- status = iree_hal_task_command_buffer_issue(
- cmd->command_buffers[i], &cmd->queue->state,
- cmd->task.header.completion_task, cmd->arena, pending_submission);
- iree_hal_command_buffer_release(cmd->command_buffers[i]);
- cmd->command_buffers[i] = NULL;
- } else {
- status = iree_make_status(
- IREE_STATUS_UNIMPLEMENTED,
- "unsupported command buffer type for task queue submission");
- }
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
+ for (iree_host_size_t i = 0; i < cmd->command_buffer_count; ++i) {
+ if (iree_hal_task_command_buffer_isa(cmd->command_buffers[i])) {
+ status = iree_hal_task_command_buffer_issue(
+ cmd->command_buffers[i], &cmd->queue->state,
+ cmd->task.header.completion_task, cmd->arena, pending_submission);
+ iree_hal_command_buffer_release(cmd->command_buffers[i]);
+ cmd->command_buffers[i] = NULL;
+ } else {
+ status = iree_make_status(
+ IREE_STATUS_UNIMPLEMENTED,
+ "unsupported command buffer type for task queue submission");
}
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
}
IREE_TRACE_ZONE_END(z0);
diff --git a/runtime/src/iree/hal/utils/file_cache.c b/runtime/src/iree/hal/utils/file_cache.c
index c2bde9d..010de5c 100644
--- a/runtime/src/iree/hal/utils/file_cache.c
+++ b/runtime/src/iree/hal/utils/file_cache.c
@@ -159,6 +159,8 @@
entry->file = file;
iree_hal_file_retain(entry->file);
+ file_cache->entries[file_cache->entry_count++] = entry;
+
return iree_ok_status();
}
diff --git a/runtime/src/iree/io/parameter_index_provider.c b/runtime/src/iree/io/parameter_index_provider.c
index f36a1c2..38b83c4 100644
--- a/runtime/src/iree/io/parameter_index_provider.c
+++ b/runtime/src/iree/io/parameter_index_provider.c
@@ -11,7 +11,7 @@
// Limit concurrent operations to avoid blowing the stack. This is arbitrary and
// if we wanted to support more we could switch to using heap allocations or
// a growable stack scratchpad.
-#define IREE_IO_PARAMETER_INDEX_PROVIDER_CONCURRENT_OPERATION_LIMIT 128
+#define IREE_IO_PARAMETER_OP_BATCH_MAX_CONCURRENCY 8
typedef struct iree_io_parameter_index_provider_t {
iree_io_parameter_provider_t base;
@@ -42,8 +42,8 @@
IREE_TRACE_ZONE_APPEND_TEXT(z0, scope.data, scope.size);
max_concurrent_operations =
- iree_min(max_concurrent_operations,
- IREE_IO_PARAMETER_INDEX_PROVIDER_CONCURRENT_OPERATION_LIMIT);
+ iree_max(1, iree_min(max_concurrent_operations,
+ IREE_IO_PARAMETER_OP_BATCH_MAX_CONCURRENCY));
iree_io_parameter_index_provider_t* provider = NULL;
iree_host_size_t total_size = sizeof(*provider) + scope.size;
@@ -193,8 +193,9 @@
iree_hal_memory_access_format(required_access, &temp1);
return iree_make_status(
IREE_STATUS_PERMISSION_DENIED,
- "parameter storage does not support the requested access "
+ "parameter `%.*s` storage does not support the requested access "
"type; parameter allows %.*s, operation requires %.*s",
+ (int)entry->key.size, entry->key.data,
(int)allowed_memory_access_str.size, allowed_memory_access_str.data,
(int)required_memory_access_str.size, required_memory_access_str.data);
#else
@@ -203,15 +204,470 @@
}
if (offset + length > entry->length) {
- return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
- "parameter range out of bounds (offset=%" PRIu64
- ", length=%" PRIu64 ", size=%" PRIu64 ")",
- offset, length, entry->length);
+ return iree_make_status(
+ IREE_STATUS_OUT_OF_RANGE,
+ "parameter `%.*s` range out of bounds (offset=%" PRIu64
+ ", length=%" PRIu64 ", size=%" PRIu64 ")",
+ (int)entry->key.size, entry->key.data, offset, length, entry->length);
}
return iree_ok_status();
}
+// Stateful batch management of multiple parameter operations.
+//
+// The batch distributes operations over multiple timelines based on how much
+// I/O is done on each. Pathologically bad latency is possible if large and
+// small operations are interleaved, as all large operations may end up
+// serialized on one timeline and all small ones on the other.
+// We distribute by tracking the total bytes outstanding on each timeline and
+// always placing the next operation on the one with the fewest. This assumes
+// that all I/O happens at roughly the same speed but if parameters come from
+// different files on different devices that may not be the case. It's better
+// than doing nothing, though.
+//
+// Though we mostly focus on file I/O we also support DMA operations such as
+// splats (synthetic parameters) and copies (device->device transfers) by way
+// of a command buffer we build and submit when needed. Today there is just a
+// single command buffer we submit with zero barriers which should be equivalent
+// to spreading it out over multiple timelines.
+//
+// NOTE: we expect count == 0 to have been handled by callers to avoid the
+// overhead of the batch setup and submission but it's valid to have a zero
+// count.
+typedef struct iree_io_parameter_op_batch_t {
+ // Parameter provider sourcing the parameter metadata.
+ iree_io_parameter_index_provider_t* provider; // unretained
+ // Device hosting the batch operation.
+ iree_hal_device_t* device; // unretained
+ // Queue affinity indicating where batch operations can run.
+ iree_hal_queue_affinity_t queue_affinity;
+
+ // Semaphores that must be waited on before any operations begin.
+ iree_hal_semaphore_list_t wait_semaphore_list;
+ // Semaphores that must be signaled after all operations complete.
+ iree_hal_semaphore_list_t signal_semaphore_list;
+
+ // Number of concurrent timelines available for processing the batch.
+ // Expects 0 < concurrency <= IREE_IO_PARAMETER_OP_BATCH_MAX_CONCURRENCY.
+ // Not all timelines may be used and timeline_live_count should be checked to
+ // see which are active.
+ iree_host_size_t concurrency;
+ // Number of timelines which have had operations scheduled against them.
+ // This allows us to filter out timelines that are idle upon completion
+ // and avoid unneeded waits. 0 if no timeline was used (count = 0, etc).
+ iree_host_size_t timeline_live_count;
+ // Sum of byte read/writes outstanding per timeline as a proxy for the amount
+ // of work performed in a particular timeline.
+ uint64_t
+ timeline_bytes_outstanding[IREE_IO_PARAMETER_OP_BATCH_MAX_CONCURRENCY];
+ // Semaphore per timeline.
+ iree_hal_semaphore_t*
+ timeline_semaphores[IREE_IO_PARAMETER_OP_BATCH_MAX_CONCURRENCY];
+ // Current payload per timeline; when the semaphore reaches this value the
+ // timeline has quiesced.
+ uint64_t timeline_values[IREE_IO_PARAMETER_OP_BATCH_MAX_CONCURRENCY];
+
+ // On-demand allocated command buffer used for transfer operations (usually
+ // splats, but could be copies to/from device buffers).
+ iree_hal_command_buffer_t* transfer_command_buffer;
+ // Sum of byte read/writes in the transfer command buffer as a proxy for
+ // the amount of work performed within the command buffer. We expect transfer
+ // operations to be cheaper than file I/O operations but are not trying to be
+ // precise here.
+ uint64_t transfer_bytes_outstanding;
+} iree_io_parameter_op_batch_t;
+
+// Begins a parameter operation batch against the given |provider|.
+// Operations will be scheduled on |device| with |queue_affinity|. All batch
+// operations will wait until |wait_semaphore_list| has been reached and after
+// all batch operations complete |signal_semaphore_list| will be signaled.
+// Upon return callers are required to call iree_io_parameter_op_batch_end
+// regardless of whether they encounter an error.
+static void iree_io_parameter_op_batch_begin(
+ iree_io_parameter_index_provider_t* provider, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_io_parameter_op_batch_t* IREE_RESTRICT out_batch) {
+ IREE_ASSERT_ARGUMENT(provider);
+ IREE_ASSERT_ARGUMENT(device);
+ IREE_ASSERT_ARGUMENT(out_batch);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ memset(out_batch, 0, sizeof(*out_batch));
+
+ out_batch->provider = provider;
+ out_batch->device = device;
+ out_batch->queue_affinity = queue_affinity;
+
+ out_batch->wait_semaphore_list = wait_semaphore_list;
+ out_batch->signal_semaphore_list = signal_semaphore_list;
+
+ // We could limit the concurrency from the max based on the batch size but
+ // since the compiler batches everything and most models go over the default
+ // max concurrency this is fine for now.
+ out_batch->concurrency =
+ iree_max(1, iree_min(provider->max_concurrent_operations,
+ IREE_IO_PARAMETER_OP_BATCH_MAX_CONCURRENCY));
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+// Resolves the parameter entry from |enumerator| at index |i|.
+// |access| indicates the required access permissions to the parameter storage.
+// Returns the entry, the span indicating source/target ranges, and optionally
+// a file (NULL if a splat). |out_file| is retained and must be released by the
+// caller if set.
+static iree_status_t iree_io_parameter_op_batch_resolve_entry(
+ const iree_io_parameter_op_batch_t* batch, iree_string_view_t scope,
+ iree_io_parameter_enumerator_t enumerator, iree_host_size_t i,
+ iree_hal_memory_access_t access,
+ const iree_io_parameter_index_entry_t** IREE_RESTRICT out_entry,
+ iree_io_parameter_span_t* IREE_RESTRICT out_span,
+ iree_hal_file_t** IREE_RESTRICT out_file) {
+ IREE_ASSERT_ARGUMENT(out_entry);
+ IREE_ASSERT_ARGUMENT(out_span);
+ IREE_ASSERT_ARGUMENT(out_file);
+ *out_entry = NULL;
+ memset(out_span, 0, sizeof(*out_span));
+ *out_file = NULL;
+
+ // Fetch the next parameter to copy and its buffer range.
+ iree_string_view_t key = iree_string_view_empty();
+ iree_io_parameter_span_t span = {0};
+ IREE_RETURN_IF_ERROR(enumerator.fn(enumerator.user_data, i, &key, &span));
+
+ // Lookup the parameter metadata and get its backing file.
+ const iree_io_parameter_index_entry_t* entry = NULL;
+ iree_hal_file_t* file = NULL; // retained, NULL if splat
+ IREE_RETURN_IF_ERROR(iree_io_parameter_index_provider_resolve(
+ batch->provider, batch->device, batch->queue_affinity, scope, key, access,
+ &entry, &file));
+
+ // Validate the parameter range is in-bounds.
+ iree_status_t status = iree_io_validate_parameter_range(
+ access, entry, span.parameter_offset, span.length);
+
+ if (iree_status_is_ok(status)) {
+ *out_entry = entry;
+ *out_span = span;
+ *out_file = file;
+ } else {
+ iree_hal_file_release(file);
+ }
+ return status;
+}
+
+typedef struct {
+ iree_hal_semaphore_list_t wait_semaphore_list;
+ iree_hal_semaphore_list_t signal_semaphore_list;
+ uint64_t scratch_values[2]; // wait/signal payload values
+} iree_io_parameter_op_step_t;
+
+// Selects a timeline with the fewest bytes outstanding and accounts for the new
+// |op_byte_length| bytes on that timeline. Returns semaphore lists the caller
+// must wait on before performing their operation and signal after their
+// operation completes.
+static iree_status_t iree_io_parameter_op_batch_advance_timeline(
+ iree_io_parameter_op_batch_t* batch, uint64_t op_byte_length,
+ iree_io_parameter_op_step_t* IREE_RESTRICT out_step) {
+ IREE_ASSERT_ARGUMENT(batch);
+ IREE_ASSERT_ARGUMENT(out_step);
+ memset(out_step, 0, sizeof(*out_step));
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, op_byte_length);
+
+ // Find the timeline with the fewest outstanding bytes.
+ // Linear scan as the number of timelines is expected to be small.
+ uint64_t smallest_value = batch->timeline_bytes_outstanding[0];
+ iree_host_size_t smallest_index = 0;
+ for (iree_host_size_t i = 1; i < batch->concurrency; ++i) {
+ if (batch->timeline_bytes_outstanding[i] < smallest_value) {
+ smallest_value = batch->timeline_bytes_outstanding[i];
+ smallest_index = i;
+ }
+ }
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)smallest_index);
+ const iree_host_size_t timeline_index = smallest_index;
+
+ // Acquire the timeline semaphore used for this operation.
+ // We create the semaphores on-demand so that in cases where we don't perform
+ // any operations (loads that perform synchronous imports) or only a small
+ // amount (one transfer command buffer or just a handful of operations) we
+ // don't create so much garbage. The intent is that HAL devices pool
+ // semaphores but not all do - if we made that an expected requirement we
+ // could simplify this.
+ iree_hal_semaphore_t* timeline_semaphore =
+ batch->timeline_semaphores[timeline_index];
+ const bool is_first_timeline_use = timeline_semaphore == NULL;
+ if (!timeline_semaphore) {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_semaphore_create(
+ batch->device, batch->timeline_values[timeline_index],
+ &batch->timeline_semaphores[timeline_index]));
+ timeline_semaphore = batch->timeline_semaphores[timeline_index];
+ }
+ const uint64_t previous_timeline_value =
+ batch->timeline_values[timeline_index];
+ const uint64_t next_timeline_value = ++batch->timeline_values[timeline_index];
+
+ // Account for the bytes processed by the operation, which is a good enough
+ // metric for distribution (assuming all operations take about the same amount
+ // of memory or I/O bandwidth).
+ batch->timeline_bytes_outstanding[timeline_index] += op_byte_length;
+
+ // Select the wait semaphore list; the first wave of operations all wait on
+ // the original wait semaphore list provided by the initiator.
+ if (is_first_timeline_use) {
+ // First use of this timeline; wait on incoming list and begin the timeline.
+ IREE_ASSERT_EQ(timeline_index, batch->timeline_live_count);
+ ++batch->timeline_live_count;
+ out_step->wait_semaphore_list = batch->wait_semaphore_list;
+ } else {
+ // Continuation of the selected timeline.
+ out_step->scratch_values[0] = previous_timeline_value;
+ out_step->wait_semaphore_list.count = 1;
+ out_step->wait_semaphore_list.semaphores =
+ &batch->timeline_semaphores[timeline_index];
+ out_step->wait_semaphore_list.payload_values = &out_step->scratch_values[0];
+ }
+
+ // Signal the continuation of the timeline.
+ out_step->scratch_values[1] = next_timeline_value;
+ out_step->signal_semaphore_list.count = 1;
+ out_step->signal_semaphore_list.semaphores =
+ &batch->timeline_semaphores[timeline_index];
+ out_step->signal_semaphore_list.payload_values = &out_step->scratch_values[1];
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Enqueues a queue-ordered allocation.
+// A timeline is selected based on utilization and the following operation is
+// guaranteed to select the same timeline to ensure the allocation and
+// operation are serialized and the wait has a higher chance of being elided.
+// On success |out_buffer| receives the allocated buffer; the caller assumes
+// ownership of the reference and must release it when no longer needed.
+// On failure |out_buffer| is reset to NULL and the status is returned.
+static iree_status_t iree_io_parameter_op_batch_enqueue_alloca(
+ iree_io_parameter_op_batch_t* batch, iree_hal_allocator_pool_t pool,
+ iree_hal_buffer_params_t params, iree_device_size_t allocation_size,
+ iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
+ IREE_ASSERT_ARGUMENT(batch);
+ IREE_ASSERT_ARGUMENT(out_buffer);
+ *out_buffer = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // By passing 0 for the operation size we ensure that subsequent operations
+ // select the same timeline because the timeline with the fewest outstanding
+ // bytes will be returned after this. This intuitively seems good as we keep
+ // the allocation and operation using it serialized within a single timeline
+ // and allow devices that can elide back-to-back barriers to do so, but we may
+ // find that we also want to distribute allocations.
+ iree_io_parameter_op_step_t step;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_io_parameter_op_batch_advance_timeline(
+ batch, /*op_byte_length=*/0, &step));
+
+ // Enqueue the allocation chained on the wait/signal points the timeline
+ // step provided; the signal will be the wait of the next enqueued op.
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_device_queue_alloca(batch->device, batch->queue_affinity,
+ step.wait_semaphore_list,
+ step.signal_semaphore_list, pool, params,
+ allocation_size, out_buffer));
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Enqueues a splat operation in the batch into the |buffer| range.
+// Splats get routed to a transfer command buffer that we'll submit at the end
+// of the batch. This avoids the need for us to check all of the operations
+// ahead of time at the cost of potentially acquiring more semaphores than we
+// need in cases where everything is a splat. Splats are pretty much only useful
+// for testing/development, though, so it's ok to not be super efficient here.
+// |pattern|/|pattern_length| are captured by the fill recording; the filled
+// bytes are accounted in transfer_bytes_outstanding so that flush can place
+// the transfer submission on the least-loaded timeline.
+static iree_status_t iree_io_parameter_op_batch_enqueue_splat(
+ iree_io_parameter_op_batch_t* batch, iree_hal_buffer_t* buffer,
+ iree_device_size_t buffer_offset, iree_device_size_t length,
+ const void* pattern, iree_host_size_t pattern_length) {
+ IREE_ASSERT_ARGUMENT(batch);
+ IREE_ASSERT_ARGUMENT(buffer);
+ IREE_ASSERT_ARGUMENT(pattern);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Create the transfer command buffer on first use.
+ // NOTE(review): if create succeeds but begin fails we return with a created
+ // but never-begun command buffer; the batch fails as a whole so this is only
+ // reachable on an already-failing path — confirm no re-entry is possible.
+ if (!batch->transfer_command_buffer) {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_command_buffer_create(
+ batch->device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+ IREE_HAL_COMMAND_CATEGORY_TRANSFER, batch->queue_affinity, 0,
+ &batch->transfer_command_buffer));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_command_buffer_begin(batch->transfer_command_buffer));
+ }
+
+ // Add the splat fill to the command buffer.
+ // Parameter ranges cannot overlap so there's no barrier required.
+ batch->transfer_bytes_outstanding += length;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_command_buffer_fill_buffer(batch->transfer_command_buffer,
+ buffer, buffer_offset, length,
+ pattern, pattern_length));
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Enqueues a file read operation in the batch.
+// Reads |length| bytes from |source_file| at |source_file_offset| into
+// |target_buffer| at |target_buffer_offset|. The operation is placed on the
+// internal timeline with the fewest outstanding bytes and chained on the
+// wait/signal semaphores that timeline step provides.
+static iree_status_t iree_io_parameter_op_batch_enqueue_file_read(
+ iree_io_parameter_op_batch_t* batch, iree_hal_file_t* source_file,
+ uint64_t source_file_offset, iree_hal_buffer_t* target_buffer,
+ iree_device_size_t target_buffer_offset, iree_device_size_t length,
+ uint32_t flags) {
+ IREE_ASSERT_ARGUMENT(batch);
+ IREE_ASSERT_ARGUMENT(source_file);
+ IREE_ASSERT_ARGUMENT(target_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Advance the least-loaded timeline, accounting for |length| bytes of I/O.
+ iree_io_parameter_op_step_t step;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_op_batch_advance_timeline(batch, length, &step));
+
+ // Status captured locally so the trace zone ends on both success and failure.
+ iree_status_t status = iree_hal_device_queue_read(
+ batch->device, batch->queue_affinity, step.wait_semaphore_list,
+ step.signal_semaphore_list, source_file, source_file_offset,
+ target_buffer, target_buffer_offset, length, flags);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Enqueues a file write operation in the batch.
+// Writes |length| bytes from |source_buffer| at |source_buffer_offset| to
+// |target_file| at |target_file_offset|. Mirrors enqueue_file_read: the
+// operation is distributed to the least-loaded internal timeline and chained
+// on the wait/signal semaphores that timeline step provides.
+static iree_status_t iree_io_parameter_op_batch_enqueue_file_write(
+ iree_io_parameter_op_batch_t* batch, iree_hal_buffer_t* source_buffer,
+ iree_device_size_t source_buffer_offset, iree_hal_file_t* target_file,
+ uint64_t target_file_offset, iree_device_size_t length, uint32_t flags) {
+ IREE_ASSERT_ARGUMENT(batch);
+ IREE_ASSERT_ARGUMENT(source_buffer);
+ IREE_ASSERT_ARGUMENT(target_file);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Advance the least-loaded timeline, accounting for |length| bytes of I/O.
+ iree_io_parameter_op_step_t step;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_io_parameter_op_batch_advance_timeline(batch, length, &step));
+
+ // Status captured locally so the trace zone ends on both success and failure.
+ iree_status_t status = iree_hal_device_queue_write(
+ batch->device, batch->queue_affinity, step.wait_semaphore_list,
+ step.signal_semaphore_list, source_buffer, source_buffer_offset,
+ target_file, target_file_offset, length, flags);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Flushes any outstanding work in the |batch| and signals the user timeline.
+// Must only be called once at the end of the batch.
+// Submits the deferred transfer command buffer (if any splats/copies were
+// recorded), then joins all live internal timelines into the user-provided
+// signal semaphore list via a queue barrier. Returns the first failure
+// encountered; on failure the signal semaphores are NOT failed here — that is
+// the caller's (batch_end's) responsibility.
+static iree_status_t iree_io_parameter_op_batch_flush(
+ iree_io_parameter_op_batch_t* batch) {
+ IREE_ASSERT_ARGUMENT(batch);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // If any transfers were performed we'll need to submit the command buffer we
+ // built during recording. Order doesn't matter so we can issue it alongside
+ // all of the other work by just appending it to an arbitrary timeline. We try
+ // to still balance things by selecting a timeline with the fewest operation
+ // bytes outstanding even if the cost of a byte differs between file I/O and
+ // pure DMA operations.
+ iree_status_t status = iree_ok_status();
+ if (batch->transfer_command_buffer) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(z_transfer,
+ "iree_io_parameter_op_batch_flush_transfer");
+ status = iree_hal_command_buffer_end(batch->transfer_command_buffer);
+ iree_io_parameter_op_step_t step;
+ if (iree_status_is_ok(status)) {
+ // Advance with the accumulated transfer byte count so the submission
+ // lands on the least-loaded timeline.
+ status = iree_io_parameter_op_batch_advance_timeline(
+ batch, batch->transfer_bytes_outstanding, &step);
+ }
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_queue_execute(
+ batch->device, batch->queue_affinity, step.wait_semaphore_list,
+ step.signal_semaphore_list, 1, &batch->transfer_command_buffer);
+ }
+ IREE_TRACE_ZONE_END(z_transfer);
+ }
+
+ // Join all concurrent timelines and continue the user-provided timeline.
+ if (iree_status_is_ok(status)) {
+ // If no queue operations were performed (all load imports, 0 entries, etc)
+ // we need to issue a barrier to link the wait->signal semaphore lists.
+ if (batch->timeline_live_count == 0) {
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "pass-through wait-signal");
+ status = iree_hal_device_queue_barrier(
+ batch->device, batch->queue_affinity, batch->wait_semaphore_list,
+ batch->signal_semaphore_list);
+ } else {
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, "timeline set wait chain");
+ // Note that we allocate timelines on-demand up to timeline_live_count so
+ // we can just pass the [0, timeline_live_count) range here.
+ // timeline_values holds the last payload signaled on each timeline so
+ // this barrier waits for every enqueued operation to complete.
+ iree_hal_semaphore_list_t join_semaphore_list = {
+ .count = batch->timeline_live_count,
+ .semaphores = batch->timeline_semaphores,
+ .payload_values = batch->timeline_values,
+ };
+ status = iree_hal_device_queue_barrier(
+ batch->device, batch->queue_affinity, join_semaphore_list,
+ batch->signal_semaphore_list);
+ }
+ }
+
+ // Report the total number of bytes transferred by the batch.
+ // Compiled out (along with the loop) when tracing is disabled.
+ IREE_TRACE({
+ uint64_t total_bytes = batch->transfer_bytes_outstanding;
+ for (iree_host_size_t i = 0; i < batch->concurrency; ++i) {
+ total_bytes += batch->timeline_bytes_outstanding[i];
+ }
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, total_bytes);
+ });
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Ends a parameter operation batch.
+// The provided |status| must be set to any failure that may have occurred
+// between when the batch began and this method was called. The status will be
+// passed on to the initiating caller by way of the signal semaphores being
+// immediately failed. Returns the status provided to allow for propagating
+// failures.
+// Always releases batch-owned resources (timeline semaphores and the transfer
+// command buffer) regardless of success or failure.
+static iree_status_t iree_io_parameter_op_batch_end(
+ iree_io_parameter_op_batch_t* batch, iree_status_t status) {
+ IREE_ASSERT_ARGUMENT(batch);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_TEXT(
+ z0, iree_status_code_string(iree_status_code(status)));
+
+ // If the batch recording succeeded we flush now to finish any pending
+ // operations and signal the semaphores.
+ if (iree_status_is_ok(status)) {
+ status = iree_io_parameter_op_batch_flush(batch);
+ }
+
+ // If the batch recording failed (or our flush did) we need to propagate that
+ // to the downstream user semaphores.
+ // The status is cloned because we also return it to the caller below.
+ if (!iree_status_is_ok(status)) {
+ iree_hal_semaphore_list_fail(batch->signal_semaphore_list,
+ iree_status_clone(status));
+ }
+
+ // Resources are safe to release even if there are pending device operations
+ // as the device guarantees the resources remain live.
+ // Semaphores are created on-demand so entries past timeline_live_count are
+ // NULL here; release is presumably NULL-tolerant — confirm against HAL docs.
+ for (iree_host_size_t i = 0; i < batch->concurrency; ++i) {
+ iree_hal_semaphore_release(batch->timeline_semaphores[i]);
+ }
+ iree_hal_command_buffer_release(batch->transfer_command_buffer);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
static void iree_io_file_handle_buffer_release(void* user_data,
iree_hal_buffer_t* buffer) {
iree_io_file_handle_release((iree_io_file_handle_t*)user_data);
@@ -222,489 +678,157 @@
iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_string_view_t source_scope, iree_string_view_t source_key,
- uint64_t source_offset, iree_hal_buffer_params_t target_params,
- iree_device_size_t length,
- iree_hal_buffer_t** IREE_RESTRICT out_target_buffer) {
+ iree_string_view_t source_scope, iree_hal_buffer_params_t target_params,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator,
+ iree_io_parameter_emitter_t emitter) {
iree_io_parameter_index_provider_t* provider =
iree_io_parameter_index_provider_cast(base_provider);
IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
- // Lookup the parameter metadata and get its backing file.
- const iree_io_parameter_index_entry_t* source_entry = NULL;
- iree_hal_file_t* source_file = NULL; // retained, NULL if splat
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_io_parameter_index_provider_resolve(
- provider, device, queue_affinity, source_scope, source_key,
- IREE_HAL_MEMORY_ACCESS_READ, &source_entry, &source_file));
+ // Initialize the batch state.
+ iree_io_parameter_op_batch_t batch;
+ iree_io_parameter_op_batch_begin(provider, device, queue_affinity,
+ wait_semaphore_list, signal_semaphore_list,
+ &batch);
- // Validate the parameter range is in-bounds.
- iree_status_t status = iree_io_validate_parameter_range(
- IREE_HAL_MEMORY_ACCESS_READ, source_entry, source_offset, length);
-
- // Try first to reuse the file backing store directly as a buffer. This only
- // works with specific file types and with specific target usage. The most
- // common cases for this are when using parameters as staging sources (so host
- // memory is ok) or on unified memory systems (where host memory is device
- // memory) and the file was originally mapped. We could extend the conditions
- // in which we use this with some better file handle helpers that allow us to
- // map files that we already have open via other mechanisms (FILE, fd, etc).
- iree_hal_buffer_t* target_buffer = NULL;
- if (iree_status_is_ok(status) &&
- source_entry->type == IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE &&
- iree_io_file_handle_type(source_entry->storage.file.handle) ==
- IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
- iree_byte_span_t host_allocation =
- iree_io_file_handle_primitive(source_entry->storage.file.handle)
- .value.host_allocation;
- iree_hal_external_buffer_t external_buffer = {
- .type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION,
- .flags = IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE,
- .size = host_allocation.data_length,
- .handle =
- {
- .host_allocation =
- {
- .ptr = host_allocation.data +
- source_entry->storage.file.offset,
- },
- },
- };
- iree_hal_buffer_release_callback_t release_callback = {
- .fn = iree_io_file_handle_buffer_release,
- .user_data = source_entry->storage.file.handle,
- };
- iree_io_file_handle_retain(source_entry->storage.file.handle);
- iree_status_t import_status = iree_hal_allocator_import_buffer(
- iree_hal_device_allocator(device), target_params, &external_buffer,
- release_callback, &target_buffer);
- if (iree_status_is_ok(import_status)) {
- // Import succeeded - issue a barrier to preserve the async timeline.
- IREE_TRACE_ZONE_APPEND_TEXT(z0, "import succeeded");
- status = iree_hal_device_queue_barrier(
- device, queue_affinity, wait_semaphore_list, signal_semaphore_list);
- } else {
- // Failed to import - that's ok as we'll just do the full allocate + read.
- IREE_TRACE_ZONE_APPEND_TEXT(z0, "import failed");
- import_status = iree_status_ignore(import_status);
- iree_io_file_handle_release(source_entry->storage.file.handle);
- }
- }
-
- if (!target_buffer) {
- // Temporary semaphore for chaining the allocation and read.
- iree_hal_semaphore_t* temporary_semaphore = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_hal_semaphore_create(device, 0ull, &temporary_semaphore);
- }
- uint64_t temporary_semaphore_value = 1ull;
- const iree_hal_semaphore_list_t alloca_semaphore_list = {
- .count = 1,
- .semaphores = &temporary_semaphore,
- .payload_values = &temporary_semaphore_value,
- };
-
- // Allocate the target buffer.
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_queue_alloca(
- device, queue_affinity, wait_semaphore_list, alloca_semaphore_list,
- IREE_HAL_ALLOCATOR_POOL_DEFAULT, target_params, length,
- &target_buffer);
- }
-
- // Queue the file read into the target buffer.
- if (iree_status_is_ok(status)) {
- switch (source_entry->type) {
- case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT:
- status = iree_hal_device_queue_fill(
- device, queue_affinity, alloca_semaphore_list,
- signal_semaphore_list, target_buffer, 0, length,
- source_entry->storage.splat.pattern,
- source_entry->storage.splat.pattern_length);
- break;
- case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE:
- status = iree_hal_device_queue_read(
- device, queue_affinity, alloca_semaphore_list,
- signal_semaphore_list, source_file,
- source_entry->storage.file.offset + source_offset, target_buffer,
- 0, length, 0);
- break;
- default:
- status =
- iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
- "reads not supported from parameters of type %d",
- (int)source_entry->type);
- break;
- }
- }
-
- iree_hal_semaphore_release(temporary_semaphore);
- }
-
- iree_hal_file_release(source_file);
- if (iree_status_is_ok(status)) {
- IREE_ASSERT_NE(target_buffer, NULL);
- *out_target_buffer = target_buffer;
- } else {
- iree_hal_buffer_release(target_buffer);
- }
- IREE_TRACE_ZONE_END(z0);
- return status;
-}
-
-static iree_status_t iree_io_parameter_index_provider_read(
- iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
- iree_hal_queue_affinity_t queue_affinity,
- const iree_hal_semaphore_list_t wait_semaphore_list,
- const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_string_view_t source_scope, iree_string_view_t source_key,
- uint64_t source_offset, iree_hal_buffer_t* target_buffer,
- iree_device_size_t target_offset, iree_device_size_t length) {
- iree_io_parameter_index_provider_t* provider =
- iree_io_parameter_index_provider_cast(base_provider);
- IREE_TRACE_ZONE_BEGIN(z0);
-
- // Lookup the parameter metadata and get its backing file.
- const iree_io_parameter_index_entry_t* source_entry = NULL;
- iree_hal_file_t* source_file = NULL; // retained, NULL if splat
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_io_parameter_index_provider_resolve(
- provider, device, queue_affinity, source_scope, source_key,
- IREE_HAL_MEMORY_ACCESS_READ, &source_entry, &source_file));
-
- // Validate the parameter range is in-bounds.
- iree_status_t status = iree_io_validate_parameter_range(
- IREE_HAL_MEMORY_ACCESS_READ, source_entry, source_offset, length);
-
- // Queue the file read into the target buffer.
- if (iree_status_is_ok(status)) {
- switch (source_entry->type) {
- case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT:
- status = iree_hal_device_queue_fill(
- device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
- target_buffer, 0, length, source_entry->storage.splat.pattern,
- source_entry->storage.splat.pattern_length);
- break;
- case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE:
- status = iree_hal_device_queue_read(
- device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
- source_file, source_entry->storage.file.offset + source_offset,
- target_buffer, target_offset, length, 0);
- break;
- default:
- status =
- iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
- "reads not supported from parameters of type %d",
- (int)source_entry->type);
- break;
- }
- }
-
- iree_hal_file_release(source_file);
-
- IREE_TRACE_ZONE_END(z0);
- return status;
-}
-
-static iree_status_t iree_io_parameter_index_provider_write(
- iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
- iree_hal_queue_affinity_t queue_affinity,
- const iree_hal_semaphore_list_t wait_semaphore_list,
- const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
- iree_string_view_t target_scope, iree_string_view_t target_key,
- uint64_t target_offset, iree_device_size_t length) {
- iree_io_parameter_index_provider_t* provider =
- iree_io_parameter_index_provider_cast(base_provider);
- IREE_TRACE_ZONE_BEGIN(z0);
-
- // Lookup the parameter metadata and get its backing file.
- const iree_io_parameter_index_entry_t* target_entry = NULL;
- iree_hal_file_t* target_file = NULL; // retained
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_io_parameter_index_provider_resolve(
- provider, device, queue_affinity, target_scope, target_key,
- IREE_HAL_MEMORY_ACCESS_WRITE, &target_entry, &target_file));
-
- // Validate the parameter range is in-bounds.
- iree_status_t status = iree_io_validate_parameter_range(
- IREE_HAL_MEMORY_ACCESS_WRITE, target_entry, target_offset, length);
-
- // Queue the file write from the source buffer.
- if (iree_status_is_ok(status)) {
- switch (target_entry->type) {
- case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE:
- status = iree_hal_device_queue_write(
- device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
- source_buffer, source_offset, target_file,
- target_entry->storage.file.offset + target_offset, length, 0);
- break;
- default:
- status =
- iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
- "writes not supported to parameters of type %d",
- (int)target_entry->type);
- break;
- }
- }
-
- iree_hal_file_release(target_file);
-
- IREE_TRACE_ZONE_END(z0);
- return status;
-}
-
-typedef iree_status_t(IREE_API_PTR* iree_io_parameter_index_file_operation_t)(
- iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
- const iree_hal_semaphore_list_t wait_semaphore_list,
- const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_hal_file_t* file, uint64_t file_offset, iree_hal_buffer_t* buffer,
- iree_device_size_t buffer_offset, iree_device_size_t length,
- uint32_t flags);
-
-// Returns the index of the smallest value in |values|.
-// Linear scan as the number of values is expected to be small.
-static iree_host_size_t iree_io_select_timeline_bucket(iree_host_size_t count,
- const uint64_t* values) {
- IREE_ASSERT_GT(count, 0);
- uint64_t smallest_value = values[0];
- iree_host_size_t smallest_index = 0;
- for (iree_host_size_t i = 1; i < count; ++i) {
- if (values[i] < smallest_value) {
- smallest_value = values[i];
- smallest_index = i;
- }
- }
- return smallest_index;
-}
-
-static iree_status_t iree_io_parameter_index_provider_gather_scatter(
- iree_io_parameter_index_provider_t* provider, iree_hal_device_t* device,
- iree_hal_queue_affinity_t queue_affinity,
- const iree_hal_semaphore_list_t wait_semaphore_list,
- const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_string_view_t scope, iree_hal_buffer_t* buffer, iree_host_size_t count,
- iree_io_parameter_enumerator_t enumerator, iree_hal_memory_access_t access,
- iree_io_parameter_index_file_operation_t operation) {
- // Decide how many operations we'll keep in-flight at a time. Each concurrent
- // stream of operations requires its own semaphore.
- //
- // NOTE: we expect count == 0 and count == 1 to have been handled by callers
- // and assume that if we've hit this method we're doing something significant
- // and it's worth it to do all this.
- const iree_host_size_t concurrency =
- iree_min(count, provider->max_concurrent_operations);
-
- // Distribute operations over each timeline based on how much
- // I/O is done. It's possible for pathologically bad latency if there are
- // large and small operations interleaved as all large operations may end up
- // serialized on one timeline and all small ones on the other.
- // We distribute by tracking the total bytes outstanding on each timeline and
- // always placing the next operation on the one with the fewest. This assumes
- // that all I/O happens at roughly the same speed but if parameters come from
- // different files on different devices that may not be the case. It's better
- // than doing nothing, though.
- uint64_t* timeline_bytes_outstanding =
- (uint64_t*)iree_alloca(concurrency * sizeof(uint64_t));
- memset(timeline_bytes_outstanding, 0,
- concurrency * sizeof(*timeline_bytes_outstanding));
-
- // Allocate one semaphore per concurrent timeline.
- IREE_TRACE_ZONE_BEGIN_NAMED(
- z_init, "iree_io_parameter_index_provider_semaphore_pool_initialize");
- iree_hal_semaphore_t** timeline_semaphores =
- (iree_hal_semaphore_t**)iree_alloca(concurrency *
- sizeof(iree_hal_semaphore_t*));
- uint64_t* timeline_values =
- (uint64_t*)iree_alloca(concurrency * sizeof(uint64_t));
+ // Process each entry by enqueuing the appropriate operation.
iree_status_t status = iree_ok_status();
- for (iree_host_size_t i = 0; i < concurrency; ++i) {
- timeline_values[i] = 0ull;
- status = iree_hal_semaphore_create(device, timeline_values[i],
- &timeline_semaphores[i]);
- if (!iree_status_is_ok(status)) break;
- }
- IREE_TRACE_ZONE_END(z_init);
+ for (iree_host_size_t i = 0; i < count; ++i) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(z_entry,
+ "iree_io_parameter_index_provider_load_entry");
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, i);
- // If any of the operations are splats or copies we'll record them into a
- // single command buffer allocated on first use. When done walking all entries
- // we'll submit the command buffer if it was used.
- iree_hal_command_buffer_t* transfer_command_buffer = NULL;
- uint64_t transfer_bytes = 0;
+ // Fetch the next parameter to process.
+ const iree_io_parameter_index_entry_t* source_entry = NULL;
+ iree_io_parameter_span_t span;
+ iree_hal_file_t* source_file = NULL; // retained, NULL if splat
+ status = iree_io_parameter_op_batch_resolve_entry(
+ &batch, source_scope, enumerator, i, IREE_HAL_MEMORY_ACCESS_READ,
+ &source_entry, &span, &source_file);
+ if (iree_status_is_ok(status)) {
+ IREE_TRACE_ZONE_APPEND_TEXT(z_entry, source_entry->key.data,
+ source_entry->key.size);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, span.length);
+ }
- if (iree_status_is_ok(status)) {
- for (iree_host_size_t i = 0; i < count; ++i) {
- IREE_TRACE_ZONE_BEGIN_NAMED(
- z_entry, "iree_io_parameter_index_provider_gather_scatter_entry");
- IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, i);
+ // TODO(benvanik): refactor iree_io_parameter_index_provider_resolve so that
+ // it doesn't resolve the HAL file. Today if we hit the perfect case where
+ // all loads are able to be imported directly then we don't end up using the
+ // file we get back and instead have just wasted time/resources managing it.
+ // On CPU it's relatively cheap (a few mallocs) but on GPU it may require
+ // extremely expensive driver handling. Startup paths with parameters aren't
+ // usually critical, though, so it's (probably) fine today as-is.
- // Fetch the next parameter to copy and its buffer range.
- iree_string_view_t key;
- iree_io_parameter_span_t span;
- status = enumerator.fn(enumerator.user_data, i, &key, &span);
-
- // Lookup the parameter metadata and get its backing file.
- const iree_io_parameter_index_entry_t* entry = NULL;
- iree_hal_file_t* file = NULL; // retained, NULL if splat
- if (iree_status_is_ok(status)) {
- IREE_TRACE_ZONE_APPEND_TEXT(z_entry, key.data, key.size);
- status = iree_io_parameter_index_provider_resolve(
- provider, device, queue_affinity, scope, key, access, &entry,
- &file);
+ // Try first to reuse the file backing store directly as a buffer. This only
+ // works with specific file types and with specific target usage. The most
+ // common cases for this are when using parameters as staging sources (so
+ // host memory is ok) or on unified memory systems (where host memory is
+ // device memory) and the file was originally mapped. We could extend the
+ // conditions in which we use this with some better file handle helpers that
+ // allow us to map files that we already have open via other mechanisms
+ // (FILE, fd, etc).
+ iree_hal_buffer_t* target_buffer = NULL;
+ if (iree_status_is_ok(status) &&
+ source_entry->type == IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE &&
+ iree_io_file_handle_type(source_entry->storage.file.handle) ==
+ IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
+ iree_byte_span_t host_allocation =
+ iree_io_file_handle_primitive(source_entry->storage.file.handle)
+ .value.host_allocation;
+ iree_hal_external_buffer_t external_buffer = {
+ .type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION,
+ .flags = IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE,
+ .size = host_allocation.data_length,
+ .handle =
+ {
+ .host_allocation =
+ {
+ .ptr = host_allocation.data +
+ source_entry->storage.file.offset,
+ },
+ },
+ };
+ iree_hal_buffer_release_callback_t release_callback = {
+ .fn = iree_io_file_handle_buffer_release,
+ .user_data = source_entry->storage.file.handle,
+ };
+ iree_io_file_handle_retain(source_entry->storage.file.handle);
+ iree_status_t import_status = iree_hal_allocator_import_buffer(
+ iree_hal_device_allocator(device), target_params, &external_buffer,
+ release_callback, &target_buffer);
+ if (iree_status_is_ok(import_status)) {
+ // Import succeeded - issue a barrier to preserve the async timeline.
+ IREE_TRACE_ZONE_APPEND_TEXT(z_entry, "import succeeded");
+ } else {
+ // Failed to import - that's ok as we'll just do the full allocate +
+ // read.
+ IREE_TRACE_ZONE_APPEND_TEXT(z_entry, "import failed");
+ import_status = iree_status_ignore(import_status);
+ iree_io_file_handle_release(source_entry->storage.file.handle);
}
+ }
- // Validate the parameter range is in-bounds.
- if (iree_status_is_ok(status)) {
- status = iree_io_validate_parameter_range(
- access, entry, span.parameter_offset, span.length);
- }
+ // When the import path above fails we fall back to alloca + fill/read.
+ if (iree_status_is_ok(status) && !target_buffer) {
+ // Enqueue an allocation of the target buffer on a timeline.
+ // The next operation we enqueue will go on the same timeline.
+ status = iree_io_parameter_op_batch_enqueue_alloca(
+ &batch, IREE_HAL_ALLOCATOR_POOL_DEFAULT, target_params, span.length,
+ &target_buffer);
- // Queue the file operation or append to the fill command buffer.
+ // Enqueue the operation on the same timeline as the allocation.
if (iree_status_is_ok(status)) {
- switch (entry->type) {
+ switch (source_entry->type) {
case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT: {
- // Splats get routed to a fill command buffer that we'll submit at
- // the end. This avoids the need for us to check all of the
- // operations ahead of time at the cost of potentially acquiring
- // more semaphores than we need in cases where everything is a
- // splat. Splats are pretty much only useful for
- // testing/development, though, so it's ok to not be super efficient
- // here.
- if (!transfer_command_buffer) {
- status = iree_hal_command_buffer_create(
- device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
- IREE_HAL_COMMAND_CATEGORY_TRANSFER, queue_affinity, 0,
- &transfer_command_buffer);
- }
- if (iree_status_is_ok(status)) {
- transfer_bytes += span.length;
- status = iree_hal_command_buffer_fill_buffer(
- transfer_command_buffer, buffer, span.buffer_offset,
- span.length, entry->storage.splat.pattern,
- entry->storage.splat.pattern_length);
- }
+ IREE_ASSERT(!source_file);
+ status = iree_io_parameter_op_batch_enqueue_splat(
+ &batch, target_buffer, span.buffer_offset, span.length,
+ source_entry->storage.splat.pattern,
+ source_entry->storage.splat.pattern_length);
break;
}
case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE: {
- // Operations are tracked on as many timelines as there is
- // concurrency. We distribute operations onto timelines based on
- // which has the fewest outstanding I/O bytes.
- const iree_host_size_t timeline_index =
- iree_io_select_timeline_bucket(concurrency,
- timeline_bytes_outstanding);
- timeline_bytes_outstanding[timeline_index] += span.length;
- iree_hal_semaphore_t* timeline_semaphore =
- timeline_semaphores[timeline_index];
- uint64_t previous_timeline_value = timeline_values[timeline_index];
- uint64_t next_timeline_value = ++timeline_values[timeline_index];
- IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, (uint64_t)timeline_index);
-
- // The first wave of operations all wait on the provided wait
- // semaphores. All others wait on their own internal concurrent
- // timelines.
- iree_hal_semaphore_list_t entry_wait_semaphore_list;
- if (i < concurrency) {
- entry_wait_semaphore_list = wait_semaphore_list;
- } else {
- entry_wait_semaphore_list = (iree_hal_semaphore_list_t){
- .count = 1,
- .semaphores = &timeline_semaphore,
- .payload_values = &previous_timeline_value,
- };
- }
-
- // All operations signal their concurrency timelines and we'll put a
- // barrier at the end so that we can join them all.
- iree_hal_semaphore_list_t entry_signal_semaphore_list = {
- .count = 1,
- .semaphores = &timeline_semaphore,
- .payload_values = &next_timeline_value,
- };
-
- // Perform the operation.
- status =
- operation(device, queue_affinity, entry_wait_semaphore_list,
- entry_signal_semaphore_list, file,
- entry->storage.file.offset + span.parameter_offset,
- buffer, span.buffer_offset, span.length, 0);
+ IREE_ASSERT(source_file);
+ status = iree_io_parameter_op_batch_enqueue_file_read(
+ &batch, source_file,
+ source_entry->storage.file.offset + span.parameter_offset,
+ target_buffer, span.buffer_offset, span.length, 0);
break;
}
default: {
status = iree_make_status(
IREE_STATUS_FAILED_PRECONDITION,
- "gather/scatter not supported with parameters of type %d",
- (int)entry->type);
+ "load not supported with parameters of type %d",
+ (int)source_entry->type);
break;
}
}
-
- iree_hal_file_release(file);
-
- IREE_TRACE_ZONE_END(z_entry);
- if (!iree_status_is_ok(status)) break;
}
}
+
+ iree_hal_file_release(source_file);
+
+ // Emit the target buffer so the caller can handle it. The callee must
+ // retain it if they want to keep it live. We're allowed to emit out of
+ // order but are currently always 1:1 with enumeration (which may be useful
+ // in the future if we decide to make enumeration non-indexing).
+ if (iree_status_is_ok(status)) {
+ status = emitter.fn(emitter.user_data, i, target_buffer);
+ }
+ iree_hal_buffer_release(target_buffer);
+
+ IREE_TRACE_ZONE_END(z_entry);
+ if (!iree_status_is_ok(status)) break;
}
- // If any transfers were performed we'll need to submit the command buffer we
- // built during enumeration above. Order doesn't matter so we can issue it
- // alongside all of the other work by just appending it to a random timeline.
- if (iree_status_is_ok(status) && transfer_command_buffer) {
- IREE_TRACE_ZONE_BEGIN_NAMED(
- z_transfer, "iree_io_parameter_index_provider_gather_scatter_transfer");
- const iree_host_size_t timeline_index =
- iree_io_select_timeline_bucket(concurrency, timeline_bytes_outstanding);
- timeline_bytes_outstanding[timeline_index] += transfer_bytes;
- iree_hal_semaphore_t* timeline_semaphore =
- timeline_semaphores[timeline_index];
- uint64_t next_timeline_value = ++timeline_values[timeline_index];
- IREE_TRACE_ZONE_APPEND_VALUE_I64(z_transfer, (uint64_t)timeline_index);
- iree_hal_semaphore_list_t transfer_signal_semaphore_list = {
- .count = 1,
- .semaphores = &timeline_semaphore,
- .payload_values = &next_timeline_value,
- };
- status = iree_hal_device_queue_execute(
- device, queue_affinity, wait_semaphore_list,
- transfer_signal_semaphore_list, 1, &transfer_command_buffer);
- IREE_TRACE_ZONE_END(z_transfer);
- }
- iree_hal_command_buffer_release(transfer_command_buffer);
+ // Flush any outstanding batch operations and end the batch.
+ status = iree_io_parameter_op_batch_end(&batch, status);
- // Join all concurrent timelines and continue the user-provided timeline.
- if (iree_status_is_ok(status)) {
- iree_hal_semaphore_list_t join_semaphore_list = {
- .count = concurrency,
- .semaphores = timeline_semaphores,
- .payload_values = timeline_values,
- };
- status = iree_hal_device_queue_barrier(
- device, queue_affinity, join_semaphore_list, signal_semaphore_list);
- }
-
- // Release temporary semaphores.
- IREE_TRACE_ZONE_BEGIN_NAMED(
- z_deinit, "iree_io_parameter_index_provider_semaphore_pool_deinitialize");
- for (iree_host_size_t i = 0; i < concurrency; ++i) {
- iree_hal_semaphore_release(timeline_semaphores[i]);
- }
- IREE_TRACE_ZONE_END(z_deinit);
-
+ IREE_TRACE_ZONE_END(z0);
return status;
}
-static iree_status_t iree_io_parameter_index_file_read(
- iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
- const iree_hal_semaphore_list_t wait_semaphore_list,
- const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_hal_file_t* file, uint64_t file_offset, iree_hal_buffer_t* buffer,
- iree_device_size_t buffer_offset, iree_device_size_t length,
- uint32_t flags) {
- return iree_hal_device_queue_read(device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, file, file_offset,
- buffer, buffer_offset, length, flags);
-}
-
static iree_status_t iree_io_parameter_index_provider_gather(
iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
iree_hal_queue_affinity_t queue_affinity,
@@ -715,28 +839,76 @@
iree_io_parameter_index_provider_t* provider =
iree_io_parameter_index_provider_cast(base_provider);
IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
- iree_status_t status = iree_io_parameter_index_provider_gather_scatter(
- provider, device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, source_scope, target_buffer, count, enumerator,
- IREE_HAL_MEMORY_ACCESS_READ, iree_io_parameter_index_file_read);
+ // Initialize the batch state.
+ iree_io_parameter_op_batch_t batch;
+ iree_io_parameter_op_batch_begin(provider, device, queue_affinity,
+ wait_semaphore_list, signal_semaphore_list,
+ &batch);
+
+ // Process each entry by enqueuing the appropriate operation.
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < count; ++i) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(
+ z_entry, "iree_io_parameter_index_provider_gather_entry");
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, i);
+
+ // Fetch the next parameter to process.
+ const iree_io_parameter_index_entry_t* source_entry = NULL;
+ iree_io_parameter_span_t span;
+ iree_hal_file_t* source_file = NULL; // retained, NULL if splat
+ status = iree_io_parameter_op_batch_resolve_entry(
+ &batch, source_scope, enumerator, i, IREE_HAL_MEMORY_ACCESS_READ,
+ &source_entry, &span, &source_file);
+ if (iree_status_is_ok(status)) {
+ IREE_TRACE_ZONE_APPEND_TEXT(z_entry, source_entry->key.data,
+ source_entry->key.size);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, span.length);
+ }
+
+ // Enqueue the transfer/file operation.
+ if (iree_status_is_ok(status)) {
+ switch (source_entry->type) {
+ case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT: {
+ IREE_ASSERT(!source_file);
+ status = iree_io_parameter_op_batch_enqueue_splat(
+ &batch, target_buffer, span.buffer_offset, span.length,
+ source_entry->storage.splat.pattern,
+ source_entry->storage.splat.pattern_length);
+ break;
+ }
+ case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE: {
+ IREE_ASSERT(source_file);
+ status = iree_io_parameter_op_batch_enqueue_file_read(
+ &batch, source_file,
+ source_entry->storage.file.offset + span.parameter_offset,
+ target_buffer, span.buffer_offset, span.length, 0);
+ break;
+ }
+ default: {
+ status = iree_make_status(
+ IREE_STATUS_FAILED_PRECONDITION,
+ "gather not supported with parameters of type %d",
+ (int)source_entry->type);
+ break;
+ }
+ }
+ }
+
+ iree_hal_file_release(source_file);
+
+ IREE_TRACE_ZONE_END(z_entry);
+ if (!iree_status_is_ok(status)) break;
+ }
+
+ // Flush any outstanding batch operations and end the batch.
+ status = iree_io_parameter_op_batch_end(&batch, status);
IREE_TRACE_ZONE_END(z0);
return status;
}
-static iree_status_t iree_io_parameter_index_file_write(
- iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
- const iree_hal_semaphore_list_t wait_semaphore_list,
- const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_hal_file_t* file, uint64_t file_offset, iree_hal_buffer_t* buffer,
- iree_device_size_t buffer_offset, iree_device_size_t length,
- uint32_t flags) {
- return iree_hal_device_queue_write(
- device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
- buffer, buffer_offset, file, file_offset, length, flags);
-}
-
static iree_status_t iree_io_parameter_index_provider_scatter(
iree_io_parameter_provider_t* base_provider, iree_hal_device_t* device,
iree_hal_queue_affinity_t queue_affinity,
@@ -747,11 +919,63 @@
iree_io_parameter_index_provider_t* provider =
iree_io_parameter_index_provider_cast(base_provider);
IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
- iree_status_t status = iree_io_parameter_index_provider_gather_scatter(
- provider, device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, target_scope, source_buffer, count, enumerator,
- IREE_HAL_MEMORY_ACCESS_WRITE, iree_io_parameter_index_file_write);
+ // Initialize the batch state.
+ iree_io_parameter_op_batch_t batch;
+ iree_io_parameter_op_batch_begin(provider, device, queue_affinity,
+ wait_semaphore_list, signal_semaphore_list,
+ &batch);
+
+ // Process each entry by enqueuing the appropriate operation.
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < count; ++i) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(
+ z_entry, "iree_io_parameter_index_provider_scatter_entry");
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, i);
+
+ // Fetch the next parameter to process.
+ const iree_io_parameter_index_entry_t* target_entry = NULL;
+ iree_io_parameter_span_t span;
+ iree_hal_file_t* target_file = NULL; // retained, NULL if splat
+ status = iree_io_parameter_op_batch_resolve_entry(
+ &batch, target_scope, enumerator, i, IREE_HAL_MEMORY_ACCESS_WRITE,
+ &target_entry, &span, &target_file);
+ if (iree_status_is_ok(status)) {
+ IREE_TRACE_ZONE_APPEND_TEXT(z_entry, target_entry->key.data,
+ target_entry->key.size);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z_entry, span.length);
+ }
+
+ // Enqueue the transfer/file operation.
+ if (iree_status_is_ok(status)) {
+ switch (target_entry->type) {
+ case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE: {
+ IREE_ASSERT(target_file);
+ status = iree_io_parameter_op_batch_enqueue_file_write(
+ &batch, source_buffer, span.buffer_offset, target_file,
+ target_entry->storage.file.offset + span.parameter_offset,
+ span.length, 0);
+ break;
+ }
+ default: {
+ status = iree_make_status(
+ IREE_STATUS_FAILED_PRECONDITION,
+ "scatter not supported with parameters of type %d",
+ (int)target_entry->type);
+ break;
+ }
+ }
+ }
+
+ iree_hal_file_release(target_file);
+
+ IREE_TRACE_ZONE_END(z_entry);
+ if (!iree_status_is_ok(status)) break;
+ }
+
+ // Flush any outstanding batch operations and end the batch.
+ status = iree_io_parameter_op_batch_end(&batch, status);
IREE_TRACE_ZONE_END(z0);
return status;
@@ -763,8 +987,6 @@
.notify = iree_io_parameter_index_provider_notify,
.query_support = iree_io_parameter_index_provider_query_support,
.load = iree_io_parameter_index_provider_load,
- .read = iree_io_parameter_index_provider_read,
- .write = iree_io_parameter_index_provider_write,
.gather = iree_io_parameter_index_provider_gather,
.scatter = iree_io_parameter_index_provider_scatter,
};
diff --git a/runtime/src/iree/io/parameter_provider.c b/runtime/src/iree/io/parameter_provider.c
index 064c3aa..8014685 100644
--- a/runtime/src/iree/io/parameter_provider.c
+++ b/runtime/src/iree/io/parameter_provider.c
@@ -58,20 +58,34 @@
iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_string_view_t source_scope, iree_string_view_t source_key,
- uint64_t source_offset, iree_hal_buffer_params_t target_params,
- iree_device_size_t length,
- iree_hal_buffer_t** IREE_RESTRICT out_target_buffer) {
+ iree_string_view_t source_scope, iree_hal_buffer_params_t target_params,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator,
+ iree_io_parameter_emitter_t emitter) {
IREE_ASSERT_ARGUMENT(provider);
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status = provider->vtable->load(
provider, device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, source_scope, source_key, source_offset,
- target_params, length, out_target_buffer);
+ signal_semaphore_list, source_scope, target_params, count, enumerator,
+ emitter);
IREE_TRACE_ZONE_END(z0);
return status;
}
+typedef struct {
+ iree_string_view_t key;
+ iree_io_parameter_span_t span;
+} iree_io_parameter_provider_single_enumerator_state_t;
+static iree_status_t iree_io_parameter_provider_single_enumerator(
+ void* user_data, iree_host_size_t i, iree_string_view_t* out_key,
+ iree_io_parameter_span_t* out_span) {
+ IREE_ASSERT_EQ(i, 0);
+ iree_io_parameter_provider_single_enumerator_state_t* state =
+ (iree_io_parameter_provider_single_enumerator_state_t*)user_data;
+ *out_key = state->key;
+ *out_span = state->span;
+ return iree_ok_status();
+}
+
IREE_API_EXPORT iree_status_t iree_io_parameter_provider_read(
iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
iree_hal_queue_affinity_t queue_affinity,
@@ -81,11 +95,24 @@
uint64_t source_offset, iree_hal_buffer_t* target_buffer,
iree_device_size_t target_offset, iree_device_size_t length) {
IREE_ASSERT_ARGUMENT(provider);
+ IREE_ASSERT_ARGUMENT(target_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
- iree_status_t status = provider->vtable->read(
+ iree_io_parameter_provider_single_enumerator_state_t enumerator_state = {
+ .key = source_key,
+ .span =
+ {
+ .parameter_offset = source_offset,
+ .buffer_offset = target_offset,
+ .length = length,
+ },
+ };
+ iree_io_parameter_enumerator_t enumerator = {
+ .fn = iree_io_parameter_provider_single_enumerator,
+ .user_data = &enumerator_state,
+ };
+ iree_status_t status = provider->vtable->gather(
provider, device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, source_scope, source_key, source_offset,
- target_buffer, target_offset, length);
+ signal_semaphore_list, source_scope, target_buffer, 1, enumerator);
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -99,11 +126,24 @@
iree_string_view_t target_scope, iree_string_view_t target_key,
uint64_t target_offset, iree_device_size_t length) {
IREE_ASSERT_ARGUMENT(provider);
+ IREE_ASSERT_ARGUMENT(source_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
- iree_status_t status = provider->vtable->write(
+ iree_io_parameter_provider_single_enumerator_state_t enumerator_state = {
+ .key = target_key,
+ .span =
+ {
+ .parameter_offset = target_offset,
+ .buffer_offset = source_offset,
+ .length = length,
+ },
+ };
+ iree_io_parameter_enumerator_t enumerator = {
+ .fn = iree_io_parameter_provider_single_enumerator,
+ .user_data = &enumerator_state,
+ };
+ iree_status_t status = provider->vtable->scatter(
provider, device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, source_buffer, source_offset, target_scope,
- target_key, target_offset, length);
+ signal_semaphore_list, source_buffer, target_scope, 1, enumerator);
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -116,6 +156,7 @@
iree_string_view_t source_scope, iree_hal_buffer_t* target_buffer,
iree_host_size_t count, iree_io_parameter_enumerator_t enumerator) {
IREE_ASSERT_ARGUMENT(provider);
+ IREE_ASSERT_ARGUMENT(target_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
if (count == 0) {
@@ -124,19 +165,7 @@
z0, iree_hal_device_queue_barrier(device, queue_affinity,
wait_semaphore_list,
signal_semaphore_list));
- } else if (count == 1) {
- // One span is just a read.
- iree_string_view_t key = iree_string_view_empty();
- iree_io_parameter_span_t span = {0};
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, enumerator.fn(enumerator.user_data, 0, &key, &span));
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_io_parameter_provider_read(
- provider, device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, source_scope, key, span.parameter_offset,
- target_buffer, span.buffer_offset, span.length));
} else {
- // Full multi-span gather.
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, provider->vtable->gather(provider, device, queue_affinity,
wait_semaphore_list, signal_semaphore_list,
@@ -155,6 +184,7 @@
iree_hal_buffer_t* source_buffer, iree_string_view_t target_scope,
iree_host_size_t count, iree_io_parameter_enumerator_t enumerator) {
IREE_ASSERT_ARGUMENT(provider);
+ IREE_ASSERT_ARGUMENT(source_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
if (count == 0) {
@@ -163,19 +193,7 @@
z0, iree_hal_device_queue_barrier(device, queue_affinity,
wait_semaphore_list,
signal_semaphore_list));
- } else if (count == 1) {
- // One span is just a write.
- iree_string_view_t key = iree_string_view_empty();
- iree_io_parameter_span_t span = {0};
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, enumerator.fn(enumerator.user_data, 0, &key, &span));
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_io_parameter_provider_write(
- provider, device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, source_buffer, span.buffer_offset,
- target_scope, key, span.parameter_offset, span.length));
} else {
- // Full multi-span scatter.
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, provider->vtable->scatter(provider, device, queue_affinity,
wait_semaphore_list,
diff --git a/runtime/src/iree/io/parameter_provider.h b/runtime/src/iree/io/parameter_provider.h
index 76ce5da..7d964a4 100644
--- a/runtime/src/iree/io/parameter_provider.h
+++ b/runtime/src/iree/io/parameter_provider.h
@@ -74,9 +74,34 @@
IREE_API_EXPORT bool iree_io_parameter_provider_query_support(
iree_io_parameter_provider_t* provider, iree_string_view_t scope);
-// Loads a parameter from |provider| for use on |device|.
-// |source_scope| and |source_key| define the parameter and |target_params|
-// defines how the buffer is to be allocated.
+typedef iree_status_t(IREE_API_PTR* iree_io_parameter_enumerator_fn_t)(
+ void* user_data, iree_host_size_t i, iree_string_view_t* out_key,
+ iree_io_parameter_span_t* out_span);
+
+typedef struct iree_io_parameter_enumerator_t {
+ // Callback function pointer.
+ iree_io_parameter_enumerator_fn_t fn;
+ // User data passed to the callback function. Unowned.
+ void* user_data;
+} iree_io_parameter_enumerator_t;
+
+typedef iree_status_t(IREE_API_PTR* iree_io_parameter_emitter_fn_t)(
+ void* user_data, iree_host_size_t i, iree_hal_buffer_t* buffer);
+
+typedef struct iree_io_parameter_emitter_t {
+ // Callback function pointer.
+ iree_io_parameter_emitter_fn_t fn;
+ // User data passed to the callback function. Unowned.
+ void* user_data;
+} iree_io_parameter_emitter_t;
+
+// Loads zero or more spans from |provider| into buffers for use on |device|.
+// The |enumerator| defines the source keys in |source_scope| and the offset and
+// length in the resulting buffer of each span. Multiple spans may reference the
+// same source parameter but behavior is undefined if multiple span target
+// ranges overlap. The provided |emitter| will be called for each parameter
+// with the buffer backing that parameter (possibly a subspan of a larger
+// shared allocation).
//
// If the implementation is able to meet the expected |target_params| with an
// existing buffer it may be returned without a new allocation. If access allows
@@ -84,18 +109,19 @@
// other users within the same process or across processes.
//
// Implementations that have no optimized load/import path can implement this
-// with iree_hal_device_queue_alloca and iree_io_parameter_provider_read.
+// with a series of iree_hal_device_queue_alloca and
+// iree_io_parameter_provider_gather ops. Note that in such a case multiple
+// results may have the same underlying storage buffer.
//
-// Returns IREE_STATUS_NOT_FOUND if the parameter is not found.
+// Returns IREE_STATUS_NOT_FOUND if any parameter is not found.
IREE_API_EXPORT iree_status_t iree_io_parameter_provider_load(
iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_string_view_t source_scope, iree_string_view_t source_key,
- uint64_t source_offset, iree_hal_buffer_params_t target_params,
- iree_device_size_t length,
- iree_hal_buffer_t** IREE_RESTRICT out_target_buffer);
+ iree_string_view_t source_scope, iree_hal_buffer_params_t target_params,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator,
+ iree_io_parameter_emitter_t emitter);
// Reads a parameter from |provider| for use on |device|.
// |source_scope| and |source_key| define the parameter to be read into
@@ -125,17 +151,6 @@
iree_string_view_t target_scope, iree_string_view_t target_key,
uint64_t target_offset, iree_device_size_t length);
-typedef iree_status_t(IREE_API_PTR* iree_io_parameter_enumerator_fn_t)(
- void* user_data, iree_host_size_t i, iree_string_view_t* out_key,
- iree_io_parameter_span_t* out_span);
-
-typedef struct iree_io_parameter_enumerator_t {
- // Callback function pointer.
- iree_io_parameter_enumerator_fn_t fn;
- // User data passed to the callback function. Unowned.
- void* user_data;
-} iree_io_parameter_enumerator_t;
-
// Gathers zero or more spans from |provider| into the given |target_buffer|.
// The |enumerator| defines the source keys in |source_scope| and the offset and
// length in the |target_buffer| of each span. Multiple spans may reference the
@@ -186,28 +201,9 @@
iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_string_view_t source_scope, iree_string_view_t source_key,
- uint64_t source_offset, iree_hal_buffer_params_t target_params,
- iree_device_size_t length,
- iree_hal_buffer_t** IREE_RESTRICT out_target_buffer);
-
- iree_status_t(IREE_API_PTR* read)(
- iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
- iree_hal_queue_affinity_t queue_affinity,
- const iree_hal_semaphore_list_t wait_semaphore_list,
- const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_string_view_t source_scope, iree_string_view_t source_key,
- uint64_t source_offset, iree_hal_buffer_t* target_buffer,
- iree_device_size_t target_offset, iree_device_size_t length);
-
- iree_status_t(IREE_API_PTR* write)(
- iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
- iree_hal_queue_affinity_t queue_affinity,
- const iree_hal_semaphore_list_t wait_semaphore_list,
- const iree_hal_semaphore_list_t signal_semaphore_list,
- iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
- iree_string_view_t target_scope, iree_string_view_t target_key,
- uint64_t target_offset, iree_device_size_t length);
+ iree_string_view_t source_scope, iree_hal_buffer_params_t target_params,
+ iree_host_size_t count, iree_io_parameter_enumerator_t enumerator,
+ iree_io_parameter_emitter_t emitter);
iree_status_t(IREE_API_PTR* gather)(
iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
diff --git a/runtime/src/iree/modules/io/parameters/exports.inl b/runtime/src/iree/modules/io/parameters/exports.inl
index 6fb0125..6275faf 100644
--- a/runtime/src/iree/modules/io/parameters/exports.inl
+++ b/runtime/src/iree/modules/io/parameters/exports.inl
@@ -25,7 +25,7 @@
// clang-format off
EXPORT_FN("gather", iree_io_parameters_module_gather, rIrrrrrrr, v)
-EXPORT_FN("load", iree_io_parameters_module_load, rIrrrrIIiiI, r)
+EXPORT_FN("load", iree_io_parameters_module_load, rIrrrIiirrr, r)
EXPORT_FN("scatter", iree_io_parameters_module_scatter, rIrrrrrrr, v)
// clang-format on
diff --git a/runtime/src/iree/modules/io/parameters/module.c b/runtime/src/iree/modules/io/parameters/module.c
index 4db143d..e2eb025 100644
--- a/runtime/src/iree/modules/io/parameters/module.c
+++ b/runtime/src/iree/modules/io/parameters/module.c
@@ -127,55 +127,6 @@
return (iree_device_size_t)value;
}
-//===----------------------------------------------------------------------===//
-// Exported functions
-//===----------------------------------------------------------------------===//
-
-IREE_VM_ABI_EXPORT(iree_io_parameters_module_load, //
- iree_io_parameters_module_state_t, //
- rIrrrrIIiiI, r) {
- iree_hal_device_t* device = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
- iree_hal_queue_affinity_t queue_affinity =
- (iree_hal_queue_affinity_t)args->i1;
- iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
- iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
- iree_vm_buffer_t* source_scope = NULL;
- IREE_RETURN_IF_ERROR(
- iree_vm_buffer_check_deref_or_null(args->r4, &source_scope));
- iree_vm_buffer_t* source_key = NULL;
- IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r5, &source_key));
- uint64_t source_offset = args->i6;
- iree_hal_queue_affinity_t target_queue_affinity =
- (iree_hal_queue_affinity_t)args->i7;
- iree_hal_memory_type_t target_memory_types = (iree_hal_memory_type_t)args->i8;
- iree_hal_buffer_usage_t target_buffer_usage =
- (iree_hal_buffer_usage_t)args->i9;
- iree_device_size_t length = iree_hal_cast_device_size(args->i10);
-
- iree_io_parameter_provider_t* provider = NULL;
- IREE_RETURN_IF_ERROR(iree_io_parameters_module_resolve_provider(
- IREE_IO_PARAMETERS_MODULE_CAST(module),
- iree_vm_buffer_as_string(source_scope), &provider));
-
- const iree_hal_buffer_params_t target_params = {
- .type = target_memory_types,
- .usage = target_buffer_usage,
- .queue_affinity = target_queue_affinity,
- };
- iree_hal_buffer_t* target_buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_io_parameter_provider_load(
- provider, device, queue_affinity,
- iree_hal_fence_semaphore_list(wait_fence),
- iree_hal_fence_semaphore_list(signal_fence),
- iree_vm_buffer_as_string(source_scope),
- iree_vm_buffer_as_string(source_key), source_offset, target_params,
- length, &target_buffer));
-
- rets->r0 = iree_hal_buffer_move_ref(target_buffer);
- return iree_ok_status();
-}
-
typedef struct iree_io_parameters_string_entry_t {
uint32_t offset;
uint32_t length;
@@ -292,6 +243,86 @@
return iree_ok_status();
}
+static iree_status_t iree_io_parameters_vm_list_emitter(
+ void* user_data, iree_host_size_t i, iree_hal_buffer_t* buffer) {
+ iree_vm_list_t* list = (iree_vm_list_t*)user_data;
+ return iree_vm_list_set_buffer_retain(list, i, buffer);
+}
+
+//===----------------------------------------------------------------------===//
+// Exported functions
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_io_parameters_module_load, //
+ iree_io_parameters_module_state_t, //
+ rIrrrIiirrr, r) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_queue_affinity_t queue_affinity =
+ (iree_hal_queue_affinity_t)args->i1;
+ iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
+ iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
+ iree_vm_buffer_t* source_scope = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_vm_buffer_check_deref_or_null(args->r4, &source_scope));
+ iree_hal_queue_affinity_t target_queue_affinity =
+ (iree_hal_queue_affinity_t)args->i5;
+ iree_hal_memory_type_t target_memory_types = (iree_hal_memory_type_t)args->i6;
+ iree_hal_buffer_usage_t target_buffer_usage =
+ (iree_hal_buffer_usage_t)args->i7;
+ iree_vm_buffer_t* key_table = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r8, &key_table));
+ iree_vm_buffer_t* key_data = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r9, &key_data));
+ iree_vm_buffer_t* spans = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r10, &spans));
+
+ iree_io_parameter_provider_t* provider = NULL;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_module_resolve_provider(
+ IREE_IO_PARAMETERS_MODULE_CAST(module),
+ iree_vm_buffer_as_string(source_scope), &provider));
+
+ iree_io_parameters_indirect_args_t enumerator_args;
+ IREE_RETURN_IF_ERROR(iree_io_parameters_prepare_indirect_args(
+ key_table, key_data, spans, &enumerator_args));
+ iree_io_parameter_enumerator_t enumerator = {
+ .fn = iree_io_parameters_indirect_enumerator,
+ .user_data = &enumerator_args,
+ };
+
+ iree_vm_list_t* target_buffers = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_list_create(
+ iree_vm_make_ref_type_def(iree_hal_buffer_type()), enumerator_args.count,
+ state->host_allocator, &target_buffers));
+ iree_status_t status =
+ iree_vm_list_resize(target_buffers, enumerator_args.count);
+ iree_io_parameter_emitter_t emitter = {
+ .fn = iree_io_parameters_vm_list_emitter,
+ .user_data = target_buffers,
+ };
+
+ if (iree_status_is_ok(status)) {
+ const iree_hal_buffer_params_t target_params = {
+ .type = target_memory_types,
+ .usage = target_buffer_usage,
+ .queue_affinity = target_queue_affinity,
+ };
+ status = iree_io_parameter_provider_load(
+ provider, device, queue_affinity,
+ iree_hal_fence_semaphore_list(wait_fence),
+ iree_hal_fence_semaphore_list(signal_fence),
+ iree_vm_buffer_as_string(source_scope), target_params,
+ enumerator_args.count, enumerator, emitter);
+ }
+
+ if (iree_status_is_ok(status)) {
+ rets->r0 = iree_vm_list_move_ref(target_buffers);
+ } else {
+ iree_vm_list_release(target_buffers);
+ }
+ return status;
+}
+
IREE_VM_ABI_EXPORT(iree_io_parameters_module_gather, //
iree_io_parameters_module_state_t, //
rIrrrrrrr, v) {
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index 35d1544..83245cb 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -67,7 +67,7 @@
IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r);
IREE_VM_ABI_DEFINE_SHIM(rIrrrIrIIi, v);
IREE_VM_ABI_DEFINE_SHIM(rIrrrrrrr, v);
-IREE_VM_ABI_DEFINE_SHIM(rIrrrrIIiiI, r);
+IREE_VM_ABI_DEFINE_SHIM(rIrrrIiirrr, r);
IREE_VM_ABI_DEFINE_SHIM(rIrrr, v);
IREE_VM_ABI_DEFINE_SHIM(rIrrCrD, v);
IREE_VM_ABI_DEFINE_SHIM(CrID, r);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index 5519543..af75456 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -449,18 +449,18 @@
iree_vm_ref_t r8;
});
-IREE_VM_ABI_FIXED_STRUCT(rIrrrrIIiiI, {
+IREE_VM_ABI_FIXED_STRUCT(rIrrrIiirrr, {
iree_vm_ref_t r0;
int64_t i1;
iree_vm_ref_t r2;
iree_vm_ref_t r3;
iree_vm_ref_t r4;
- iree_vm_ref_t r5;
- int64_t i6;
- int64_t i7;
- int32_t i8;
- int32_t i9;
- int64_t i10;
+ int64_t i5;
+ int32_t i6;
+ int32_t i7;
+ iree_vm_ref_t r8;
+ iree_vm_ref_t r9;
+ iree_vm_ref_t r10;
});
IREE_VM_ABI_FIXED_STRUCT(rIrrr, {
@@ -674,7 +674,7 @@
IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r);
IREE_VM_ABI_DECLARE_SHIM(rIrrrIrIIi, v);
IREE_VM_ABI_DECLARE_SHIM(rIrrrrrrr, v);
-IREE_VM_ABI_DECLARE_SHIM(rIrrrrIIiiI, r);
+IREE_VM_ABI_DECLARE_SHIM(rIrrrIiirrr, r);
IREE_VM_ABI_DECLARE_SHIM(rIrrr, v);
IREE_VM_ABI_DECLARE_SHIM(rIrrCrD, v);
IREE_VM_ABI_DECLARE_SHIM(CrID, r);