Switching external resources to be device-local only. (#14016)
Previously all external resources (results returned by an invocation)
were made host-visible and mappable. This prevented the use of
queue-ordered allocations in CUDA, as memory pools cannot service memory
with associated host pointers. Depending on the device, host-visible
memory could also be much slower to access (or have more potential
pitfalls with page management) than pinned device-local memory, and this
got worse once we started doing more dispatches in-place on the results.
Now all external buffers are allocated as device-local by default. Users
will need to manually stage the buffers to read them back on the host;
otherwise they'll remain on-device. For externalized state this is a good
thing, as it means we'll keep state on device automatically. A temporary
flag has been added to revert to the old mappable behavior:
`--iree-stream-external-resources-mappable=true`. Note that some devices
(like CPU) will always allow mapping even if not requested, and users can
avoid the copies by checking before performing the transfers.
GPT2 CUDA post-change with alloca and no caching allocator enabled
(~5us/invocation allocation overhead):

diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
index 02785f5..c9fd8df 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
@@ -9,8 +9,6 @@
#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -43,39 +41,5 @@
});
}
-// static
-LogicalResult HALConversionTarget::applyDefaultBufferRewrite(
- Operation *srcOp, ValueRange operands, StringRef dstOpName,
- TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
- OperationState state{srcOp->getLoc(), dstOpName};
- state.addAttributes(srcOp->getAttrs());
-
- for (auto [srcOperand, dstOperand] :
- llvm::zip_equal(srcOp->getOperands(), operands)) {
- // Check that any type that should have been mapped to buffer view was.
- // This is just to catch conflicts in type conversions that may sneak in
- // during development.
- assert(
- (!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) ||
- dstOperand.getType().isa<IREE::HAL::BufferViewType>()) &&
- "expect that tensors have been mapped to buffer views");
- state.addOperands({dstOperand});
- }
- for (auto resultType : srcOp->getResultTypes()) {
- if (HALTypeConverter::shouldConvertToBufferView(resultType)) {
- state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext()));
- } else {
- // Normal pass-through result.
- if (failed(typeConverter.convertType(resultType, state.types))) {
- return failure();
- }
- }
- }
-
- auto *dstOp = rewriter.create(state);
- rewriter.replaceOp(srcOp, dstOp->getResults());
- return success();
-}
-
} // namespace iree_compiler
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
index b41dd1f..fd3d489 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
@@ -8,7 +8,6 @@
#define IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERSIONTARGET_H_
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -22,47 +21,6 @@
class HALConversionTarget : public ConversionTarget {
public:
HALConversionTarget(MLIRContext *context, TypeConverter &typeConverter);
-
- // Attempts to rewrite an op that may use tensor values into an op using HAL
- // buffers. See HALOpConversion for more information.
- static LogicalResult
- applyDefaultBufferRewrite(Operation *srcOp, ValueRange operands,
- StringRef dstOpName, TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter);
-};
-
-// HAL tensor-to-buffer conversion utility.
-// This can be used by dialects to model custom op conversion from a dialect
-// that uses the MLIR tensor type to the IREE HAL buffer type. At this point
-// during conversion the source values will be TensorType and the target values
-// will be IREE::HAL::BufferTypes. Any static information available about the
-// tensor (such as static dimensions, element type, layout, etc) are extracted
-// here and lowered as expanded values.
-//
-// The ABI is currently very basic and will change with the introduction of more
-// dynamic shape logic.
-//
-// Source:
-// my.tensor_op(%arg0 : tensor<2x4xf32>)
-// Target:
-// %arg0_view = hal.buffer_view.create %arg0, ...
-// my.buffer_op(%arg0_view : !hal.buffer_view)
-template <typename SRC, typename DST>
-class HALOpConversion : public OpConversionPattern<SRC> {
-public:
- HALOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern<SRC>(context), typeConverter(typeConverter) {}
-
- LogicalResult
- matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- return HALConversionTarget::applyDefaultBufferRewrite(
- srcOp, adaptor.getOperands(), DST::getOperationName(), typeConverter,
- rewriter);
- }
-
-protected:
- TypeConverter &typeConverter;
};
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 9c7ebe0..d0a809e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -23,6 +24,14 @@
namespace mlir {
namespace iree_compiler {
+static llvm::cl::opt<bool> clExternalResourcesMappable(
+ "iree-stream-external-resources-mappable",
+ llvm::cl::desc("Allocates external resources as host-visible and mappable. "
+ "This can degrade performance and introduce allocation "
+ "overhead and staging buffers for readback on the host "
+ "should be managed by the calling application instead."),
+ llvm::cl::init(false));
+
namespace {
static Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
@@ -263,17 +272,21 @@
default:
break;
case IREE::Stream::Lifetime::External:
- // #yolo; these come from/go to outside the program.
- // Today we assume they are device-local|host-visible just for
- // practical purposes but that does not have to be true. We really
- // want this to be something we analyze and handle on the edges
- // (transferring devices/etc if needed).
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
- IREE::HAL::MemoryTypeBitfield::HostVisible;
- // NOTE: we may not map it but users may after they get them back.
- // Another reason we should annotate this - having a buffer be
- // mappable is potentially expensive (may get a 2nd copy in memory!).
- bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
+ if (clExternalResourcesMappable) {
+ // #yolo; these come from/go to outside the program.
+ // Today we assume they are device-local|host-visible just for
+ // practical purposes but that does not have to be true. We really
+ // want this to be something we analyze and handle on the edges
+ // (transferring devices/etc if needed).
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
+ IREE::HAL::MemoryTypeBitfield::HostVisible;
+ // NOTE: we may not map it but users may after they get them back.
+ // Another reason we should annotate this - having a buffer be
+ // mappable is potentially expensive (may get a 2nd copy in memory!).
+ bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
+ } else {
+ memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
+ }
break;
}
return success();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index 2aca6c5..d45978d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -80,8 +80,8 @@
%arg1_resource = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%c16}
// CHECK: %[[RESULT_BUFFER:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator>
- // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal")
- // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}Mapping{{.+}}")
+ // CHECK-SAME: type("DeviceVisible|DeviceLocal")
+ // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}")
// CHECK-SAME: : !hal.buffer{%c16}
%result_resource = stream.resource.alloc uninitialized : !stream.resource<external>{%c16}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
index e036755..1fa81ac 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
@@ -307,7 +307,27 @@
getState() ^= targetUsage.getState();
})
.Case([&](IREE::Stream::TensorImportOp op) {
- removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+ auto targetType =
+ llvm::cast<IREE::Stream::ResourceType>(op.getResult().getType());
+ switch (targetType.getLifetime()) {
+ default:
+ case IREE::Stream::Lifetime::External:
+ removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+ break;
+ case IREE::Stream::Lifetime::Staging:
+ removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ |
+ NOT_STAGING_WRITE);
+ break;
+ case IREE::Stream::Lifetime::Transient:
+ removeAssumedBits(NOT_MUTATED);
+ break;
+ case IREE::Stream::Lifetime::Variable:
+ removeAssumedBits(NOT_MUTATED | NOT_GLOBAL_READ | NOT_GLOBAL_WRITE);
+ break;
+ case IREE::Stream::Lifetime::Constant:
+ removeAssumedBits(NOT_CONSTANT);
+ break;
+ }
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getResult()),
DFX::Resolution::REQUIRED);
@@ -497,7 +517,6 @@
*this, Position::forValue(op->getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();
-
auto &beforeUsage = solver.getElementFor<ValueResourceUsage>(
*this,
Position::forValue(op.getBeforeBody()->getArgument(operandIdx)),
@@ -510,13 +529,11 @@
*this, Position::forValue(op->getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();
-
auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
*this,
Position::forValue(op->getParentOp()->getResult(operandIdx - 1)),
DFX::Resolution::REQUIRED);
getState() ^= parentUsage.getState();
-
if (auto whileOp =
dyn_cast_or_null<scf::WhileOp>(op->getParentOp())) {
auto value = Position::forValue(
@@ -532,14 +549,12 @@
*this, Position::forValue(op->getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();
-
auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
*this,
Position::forValue(op->getParentOp()->getResult(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= parentUsage.getState();
}
-
if (auto whileOp =
dyn_cast_or_null<scf::WhileOp>(op->getParentOp())) {
auto value =
@@ -589,7 +604,33 @@
removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_WRITE);
})
.Case([&](IREE::Stream::TensorExportOp op) {
- removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+ auto sourceType =
+ llvm::cast<IREE::Stream::ResourceType>(op.getSource().getType());
+ switch (sourceType.getLifetime()) {
+ default:
+ case IREE::Stream::Lifetime::External:
+ removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+ break;
+ case IREE::Stream::Lifetime::Staging:
+ removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ |
+ NOT_STAGING_WRITE | NOT_TRANSFER_READ |
+ NOT_TRANSFER_WRITE);
+ break;
+ case IREE::Stream::Lifetime::Transient:
+ removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ |
+ NOT_TRANSFER_WRITE | NOT_DISPATCH_READ |
+ NOT_DISPATCH_WRITE);
+ break;
+ case IREE::Stream::Lifetime::Variable:
+ removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ |
+ NOT_TRANSFER_WRITE | NOT_DISPATCH_READ |
+ NOT_DISPATCH_WRITE);
+ break;
+ case IREE::Stream::Lifetime::Constant:
+ removeAssumedBits(NOT_CONSTANT | NOT_TRANSFER_READ |
+ NOT_DISPATCH_READ);
+ break;
+ }
})
.Case([&](IREE::Stream::TensorTraceOp op) {
removeAssumedBits(NOT_STAGING_READ);
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
index 0644bda..4dcde8c 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
@@ -22,6 +22,7 @@
],
deps = [
"//compiler/src/iree/compiler/Dialect/HAL/Conversion",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Modules/Check/IR",
"@llvm-project//mlir:Pass",
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
index 582a6ad..c55d771 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
@@ -21,6 +21,7 @@
MLIRPass
MLIRTransforms
iree::compiler::Dialect::HAL::Conversion
+ iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::VM::Conversion
iree::compiler::Modules::Check::IR
PUBLIC
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
index 82da66b..10cdbb3 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
@@ -7,6 +7,8 @@
#include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h"
#include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Modules/Check/IR/CheckOps.h"
#include "mlir/Pass/Pass.h"
@@ -60,17 +62,90 @@
context, importSymbols, typeConverter, "check.expect_almost_eq");
}
+// Attempts to rewrite an op that may use tensor values into an op using HAL
+// buffers.
+static LogicalResult applyDefaultCheckBufferRewrite(
+ Operation *srcOp, ValueRange operands, StringRef dstOpName,
+ TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
+ OperationState state{srcOp->getLoc(), dstOpName};
+ state.addAttributes(srcOp->getAttrs());
+
+ // Add device argument.
+ Value device = rewriter.create<IREE::HAL::ExSharedDeviceOp>(srcOp->getLoc());
+ state.addOperands({device});
+
+ for (auto [srcOperand, dstOperand] :
+ llvm::zip_equal(srcOp->getOperands(), operands)) {
+ // Check that any type that should have been mapped to buffer view was.
+ // This is just to catch conflicts in type conversions that may sneak in
+ // during development.
+ assert(
+ (!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) ||
+ dstOperand.getType().isa<IREE::HAL::BufferViewType>()) &&
+ "expect that tensors have been mapped to buffer views");
+ state.addOperands({dstOperand});
+ }
+ for (auto resultType : srcOp->getResultTypes()) {
+ if (HALTypeConverter::shouldConvertToBufferView(resultType)) {
+ state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext()));
+ } else {
+ // Normal pass-through result.
+ if (failed(typeConverter.convertType(resultType, state.types))) {
+ return failure();
+ }
+ }
+ }
+
+ auto *dstOp = rewriter.create(state);
+ rewriter.replaceOp(srcOp, dstOp->getResults());
+ return success();
+}
+
+// HAL tensor-to-buffer conversion utility.
+// This can be used by dialects to model custom op conversion from a dialect
+// that uses the MLIR tensor type to the IREE HAL buffer type. At this point
+// during conversion the source values will be TensorType and the target values
+// will be IREE::HAL::BufferTypes. Any static information available about the
+// tensor (such as static dimensions, element type, layout, etc) are extracted
+// here and lowered as expanded values.
+//
+// The ABI is currently very basic and will change with the introduction of more
+// dynamic shape logic.
+//
+// Source:
+// my.tensor_op(%arg0 : tensor<2x4xf32>)
+// Target:
+// %arg0_view = hal.buffer_view.create %arg0, ...
+// my.buffer_op(%arg0_view : !hal.buffer_view)
+template <typename SRC, typename DST>
+class HALCheckOpConversion : public OpConversionPattern<SRC> {
+public:
+ HALCheckOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern<SRC>(context), typeConverter(typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ return applyDefaultCheckBufferRewrite(srcOp, adaptor.getOperands(),
+ DST::getOperationName(),
+ typeConverter, rewriter);
+ }
+
+protected:
+ TypeConverter &typeConverter;
+};
+
void populateCheckToHALPatterns(MLIRContext *context,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
// The same op handles both tensors and buffer views.
- patterns
- .insert<HALOpConversion<IREE::Check::ExpectAllTrueOp,
- IREE::Check::ExpectAllTrueOp>,
- HALOpConversion<IREE::Check::ExpectEqOp, IREE::Check::ExpectEqOp>,
- HALOpConversion<IREE::Check::ExpectAlmostEqOp,
- IREE::Check::ExpectAlmostEqOp>>(context,
- typeConverter);
+ patterns.insert<
+ HALCheckOpConversion<IREE::Check::ExpectAllTrueOp,
+ IREE::Check::ExpectAllTrueOp>,
+ HALCheckOpConversion<IREE::Check::ExpectEqOp, IREE::Check::ExpectEqOp>,
+ HALCheckOpConversion<IREE::Check::ExpectAlmostEqOp,
+ IREE::Check::ExpectAlmostEqOp>>(context,
+ typeConverter);
}
} // namespace Check
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel
index dff0294..e55f3d2 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel
@@ -57,6 +57,7 @@
":IR",
":check_ops_gen",
"//compiler/src/iree/compiler/Dialect/HAL/Conversion",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Modules/Check:check_imports",
"//compiler/src/iree/compiler/Modules/Check/Conversion",
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt
index c3a8574..b0928ce 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt
@@ -42,6 +42,7 @@
MLIRParser
MLIRTransforms
iree::compiler::Dialect::HAL::Conversion
+ iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::VM::Conversion
iree::compiler::Modules::Check::Conversion
iree::compiler::Modules::Check::check_imports
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp b/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp
index dbdb4e1..554baa6 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Modules/Check/IR/CheckDialect.h"
#include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
#include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h"
#include "iree/compiler/Modules/Check/IR/CheckOps.h"
@@ -57,6 +58,8 @@
CheckDialect::CheckDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<CheckDialect>()) {
+ context->loadDialect<IREE::HAL::HALDialect>();
+
addInterfaces<CheckToVmConversionInterface>();
addInterfaces<CheckToHalConversionInterface>();
#define GET_OP_LIST
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp
index a651bfe..69cfda7 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp
@@ -24,7 +24,7 @@
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
auto rhs = rewriter.create<arith::ConstantOp>(op.getLoc(), op.getValue());
- rewriter.replaceOpWithNewOp<DstOp>(op, op.getLhs(), rhs);
+ rewriter.replaceOpWithNewOp<DstOp>(op, op.getDevice(), op.getLhs(), rhs);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td
index 9d0b1b3..59c2236 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td
+++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td
@@ -36,7 +36,6 @@
let assemblyFormat = "`(` $operand `)` attr-dict `:` type($operand)";
}
-
def CHECK_ExpectFalseOp : Op<CHECK_Dialect, "expect_false"> {
let summary = [{Checks that the operand is false}];
let description = [{
@@ -64,18 +63,24 @@
Issues a non-fatal failure if the verification fails.
```mlir
- check.expect_all_true(%arg0) : !hal.buffer_view
+ check.expect_all_true<%device>(%arg0) : !hal.buffer_view
check.expect_all_true(%arg1) : tensor<2x2xi32>
```
}];
- let arguments =
- (ins AnyTypeOf<[HAL_BufferView, TensorOf<[AnySignlessInteger]>]>:$operand);
+ let arguments = (ins
+ Optional<HAL_Device>:$device,
+ AnyTypeOf<[HAL_BufferView, TensorOf<[AnySignlessInteger]>]>:$operand
+ );
- let assemblyFormat = "`(` $operand `)` attr-dict `:` type($operand)";
+ let assemblyFormat = [{
+ (`` `<` $device^ `>`)?
+ `` `(` $operand `)` attr-dict `:` type($operand)
+ }];
}
-def CHECK_ExpectEqOp : Op<CHECK_Dialect, "expect_eq", [SameTypeOperands]> {
+def CHECK_ExpectEqOp :
+ Op<CHECK_Dialect, "expect_eq", [AllTypesMatch<["lhs", "rhs"]>]> {
let summary = [{Checks that the tensor or buffer view operands are equal}];
let description = [{
Verifies that the operands are exactly equal.
@@ -88,11 +93,15 @@
}];
let arguments = (ins
- AnyTypeOf<[HAL_BufferView, AnyTensor]>:$lhs,
- AnyTypeOf<[HAL_BufferView, AnyTensor]>:$rhs
+ Optional<HAL_Device>:$device,
+ AnyTypeOf<[HAL_BufferView, AnyTensor]>:$lhs,
+ AnyTypeOf<[HAL_BufferView, AnyTensor]>:$rhs
);
- let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)";
+ let assemblyFormat = [{
+ (`` `<` $device^ `>`)?
+ `` `(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)
+ }];
}
def CHECK_ExpectEqConstOp :
@@ -111,17 +120,21 @@
}];
let arguments = (ins
+ Optional<HAL_Device>:$device,
AnyTensor:$lhs,
ElementsAttr:$value
);
let hasCanonicalizer = 1;
- let assemblyFormat = "`(` $lhs `,` $value `)` attr-dict `:` type($lhs)";
+ let assemblyFormat = [{
+ (`` `<` $device^ `>`)?
+ `` `(` $lhs `,` $value `)` attr-dict `:` type($lhs)
+ }];
}
def CHECK_ExpectAlmostEqOp :
- Op<CHECK_Dialect, "expect_almost_eq", [SameTypeOperands]> {
+ Op<CHECK_Dialect, "expect_almost_eq", [AllTypesMatch<["lhs", "rhs"]>]> {
let summary = [{Checks that the operands are almost equal}];
let description = [{
Verifies that the buffer view or tensor operands with float elements are
@@ -135,11 +148,15 @@
}];
let arguments = (ins
- AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs,
- AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs
+ Optional<HAL_Device>:$device,
+ AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs,
+ AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs
);
- let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)";
+ let assemblyFormat = [{
+ (`` `<` $device^ `>`)?
+ `` `(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)
+ }];
}
def CHECK_ExpectAlmostEqConstOp :
@@ -160,13 +177,17 @@
}];
let arguments = (ins
+ Optional<HAL_Device>:$device,
TensorOf<[AnyFloat]>:$lhs,
ElementsAttr:$value
);
let hasCanonicalizer = 1;
- let assemblyFormat = "`(` $lhs `,` $value `)` attr-dict `:` type($lhs)";
+ let assemblyFormat = [{
+ (`` `<` $device^ `>`)?
+ `` `(` $lhs `,` $value `)` attr-dict `:` type($lhs)
+ }];
}
#endif // IREE_MODULES_CHECK_DIALECT_CHECK_OPS
diff --git a/compiler/src/iree/compiler/Modules/Check/check.imports.mlir b/compiler/src/iree/compiler/Modules/Check/check.imports.mlir
index 67bae93..63b9d72 100644
--- a/compiler/src/iree/compiler/Modules/Check/check.imports.mlir
+++ b/compiler/src/iree/compiler/Modules/Check/check.imports.mlir
@@ -15,15 +15,18 @@
)
vm.import private optional @expect_all_true(
+ %device : !vm.ref<!hal.device>,
%operand : !vm.ref<!hal.buffer_view>,
)
vm.import private optional @expect_eq(
+ %device : !vm.ref<!hal.device>,
%lhs : !vm.ref<!hal.buffer_view>,
%rhs : !vm.ref<!hal.buffer_view>
)
vm.import private optional @expect_almost_eq(
+ %device : !vm.ref<!hal.device>,
%lhs : !vm.ref<!hal.buffer_view>,
%rhs : !vm.ref<!hal.buffer_view>
)
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c
index b53bcd0..a5e8788 100644
--- a/experimental/cuda2/cuda_device.c
+++ b/experimental/cuda2/cuda_device.c
@@ -623,7 +623,7 @@
// allocator is set on the device.
iree_status_t status = iree_ok_status();
if (device->supports_memory_pools &&
- !iree_any_bit_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+ !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
status = iree_hal_cuda2_memory_pools_alloca(
&device->memory_pools, device->dispatch_cu_stream, pool, params,
allocation_size, out_buffer);
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 4aaba55..cf3bd7f 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -560,7 +560,7 @@
// allocator is set on the device.
iree_status_t status = iree_ok_status();
if (device->supports_memory_pools &&
- !iree_any_bit_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+ !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
status = iree_hal_cuda_memory_pools_alloca(&device->memory_pools,
device->stream, pool, params,
allocation_size, out_buffer);
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index 67f1947..7623fb5 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -197,6 +197,9 @@
IREE_RETURN_IF_ERROR(
iree_vm_list_create(iree_vm_make_undefined_type_def(), args.size(),
iree_allocator_system(), &inputs_));
+ iree_vm_ref_t device_ref = iree_hal_device_retain_ref(device_);
+ IREE_RETURN_IF_ERROR(
+ iree_vm_list_push_ref_move(inputs_.get(), &device_ref));
for (auto& arg : args) {
iree_vm_ref_t arg_ref = iree_hal_buffer_view_move_ref(arg.get());
IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(inputs_.get(), &arg_ref));
diff --git a/runtime/src/iree/modules/check/module.cc b/runtime/src/iree/modules/check/module.cc
index b417eef..edbb9fe 100644
--- a/runtime/src/iree/modules/check/module.cc
+++ b/runtime/src/iree/modules/check/module.cc
@@ -155,6 +155,100 @@
"unsupported element type %s", element_type_str);
}
+static StatusOr<std::vector<vm::ref<iree_hal_buffer_view_t>>>
+TransferBuffersToHost(
+ iree_hal_device_t* device,
+ const iree::span<const vm::ref<iree_hal_buffer_view_t>> source_views) {
+ IREE_TRACE_SCOPE();
+
+ // If all buffers are already host-accessible we can skip the transfer.
+ std::vector<vm::ref<iree_hal_buffer_view_t>> target_views;
+ bool requires_transfer = false;
+ for (auto& source_view : source_views) {
+ iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(source_view.get());
+ if (!iree_all_bits_set(iree_hal_buffer_memory_type(buffer),
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) ||
+ !iree_all_bits_set(iree_hal_buffer_allowed_usage(buffer),
+ IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)) {
+ requires_transfer = true;
+ }
+ }
+ if (!requires_transfer) {
+ for (auto& source_view : source_views) target_views.push_back(source_view);
+ return std::move(target_views);
+ }
+
+ vm::ref<iree_hal_command_buffer_t> command_buffer;
+ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
+ device,
+ IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+ IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
+ IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY, 0,
+ &command_buffer));
+ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_begin(command_buffer.get()));
+
+ iree_hal_buffer_params_t target_params = {
+ /*.usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+ /*.access=*/IREE_HAL_MEMORY_ACCESS_ALL,
+ /*.type=*/
+ IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+ /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*.min_alignment=*/0,
+ };
+ for (size_t i = 0; i < source_views.size(); ++i) {
+ iree_hal_buffer_t* source_buffer =
+ iree_hal_buffer_view_buffer(source_views[i].get());
+ iree_device_size_t buffer_length =
+ iree_hal_buffer_byte_length(source_buffer);
+ vm::ref<iree_hal_buffer_t> target_buffer;
+ IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+ iree_hal_device_allocator(device), target_params, buffer_length,
+ &target_buffer));
+ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_copy_buffer(
+ command_buffer.get(), source_buffer, 0, target_buffer.get(), 0,
+ buffer_length));
+ vm::ref<iree_hal_buffer_view_t> target_view;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create_like(
+ target_buffer.get(), source_views[i].get(),
+ iree_hal_device_host_allocator(device), &target_view));
+ target_views.push_back(std::move(target_view));
+ }
+
+ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_end(command_buffer.get()));
+ vm::ref<iree_hal_semaphore_t> semaphore;
+ IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &semaphore));
+ vm::ref<iree_hal_fence_t> fence;
+ IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(
+ semaphore.get(), 1ull, iree_hal_device_host_allocator(device), &fence));
+ IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
+ device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+ iree_hal_fence_semaphore_list(fence.get()), 1, &command_buffer));
+ IREE_RETURN_IF_ERROR(
+ iree_hal_fence_wait(fence.get(), iree_infinite_timeout()));
+ return std::move(target_views);
+}
+
+static Status TransferToHost(iree_hal_device_t* device,
+ vm::ref<iree_hal_buffer_view_t>& buffer_view) {
+ IREE_TRACE_SCOPE();
+ IREE_ASSIGN_OR_RETURN(auto target_views,
+ TransferBuffersToHost(device, {buffer_view}));
+ buffer_view = std::move(target_views[0]);
+ return OkStatus();
+}
+
+static Status TransferToHost(iree_hal_device_t* device,
+ vm::ref<iree_hal_buffer_view_t>& buffer_view_a,
+ vm::ref<iree_hal_buffer_view_t>& buffer_view_b) {
+ IREE_TRACE_SCOPE();
+ IREE_ASSIGN_OR_RETURN(
+ auto target_views,
+ TransferBuffersToHost(device, {buffer_view_a, buffer_view_b}));
+ buffer_view_a = std::move(target_views[0]);
+ buffer_view_b = std::move(target_views[1]);
+ return OkStatus();
+}
+
// Per-context module state.
// This can contain "globals" and other arbitrary state.
//
@@ -177,7 +271,9 @@
return OkStatus();
}
- Status ExpectAllTrue(vm::ref<iree_hal_buffer_view_t> operand) {
+ Status ExpectAllTrue(vm::ref<iree_hal_device_t> device,
+ vm::ref<iree_hal_buffer_view_t> operand) {
+ IREE_RETURN_IF_ERROR(TransferToHost(device.get(), operand));
auto* view = operand.get();
iree_hal_element_type_t element_type =
iree_hal_buffer_view_element_type(view);
@@ -193,8 +289,10 @@
return OkStatus();
}
- Status ExpectEq(vm::ref<iree_hal_buffer_view_t> lhs_ref,
+ Status ExpectEq(vm::ref<iree_hal_device_t> device,
+ vm::ref<iree_hal_buffer_view_t> lhs_ref,
vm::ref<iree_hal_buffer_view_t> rhs_ref) {
+ IREE_RETURN_IF_ERROR(TransferToHost(device.get(), lhs_ref, rhs_ref));
auto* lhs = lhs_ref.get();
auto* rhs = rhs_ref.get();
@@ -272,8 +370,10 @@
return OkStatus();
}
- Status ExpectAlmostEq(vm::ref<iree_hal_buffer_view_t> lhs_ref,
+ Status ExpectAlmostEq(vm::ref<iree_hal_device_t> device,
+ vm::ref<iree_hal_buffer_view_t> lhs_ref,
vm::ref<iree_hal_buffer_view_t> rhs_ref) {
+ IREE_RETURN_IF_ERROR(TransferToHost(device.get(), lhs_ref, rhs_ref));
auto* lhs = lhs_ref.get();
auto* rhs = rhs_ref.get();
diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir
index ff5aa8e..40d8bc3 100644
--- a/runtime/src/iree/modules/check/test/success.mlir
+++ b/runtime/src/iree/modules/check/test/success.mlir
@@ -14,9 +14,10 @@
}
func.func @expect_all_true() {
+ %device = hal.ex.shared_device : !hal.device
%all_true = util.unfoldable_constant dense<1> : tensor<2x2xi32>
%all_true_view = hal.tensor.export %all_true : tensor<2x2xi32> -> !hal.buffer_view
- check.expect_all_true(%all_true_view) : !hal.buffer_view
+ check.expect_all_true<%device>(%all_true_view) : !hal.buffer_view
return
}
diff --git a/runtime/src/iree/modules/hal/types.c b/runtime/src/iree/modules/hal/types.c
index 0c7e0d7..52ce5a2 100644
--- a/runtime/src/iree/modules/hal/types.c
+++ b/runtime/src/iree/modules/hal/types.c
@@ -205,7 +205,7 @@
IREE_API_EXPORT iree_status_t iree_vm_list_set_buffer_retain(
iree_vm_list_t* list, iree_host_size_t i, iree_hal_buffer_t* value) {
- iree_vm_ref_t value_ref;
+ iree_vm_ref_t value_ref = iree_vm_ref_null();
IREE_RETURN_IF_ERROR(
iree_vm_ref_wrap_assign(value, iree_hal_buffer_type(), &value_ref));
return iree_vm_list_set_ref_retain(list, i, &value_ref);
@@ -226,7 +226,7 @@
IREE_API_EXPORT iree_status_t iree_vm_list_set_buffer_view_retain(
iree_vm_list_t* list, iree_host_size_t i, iree_hal_buffer_view_t* value) {
- iree_vm_ref_t value_ref;
+ iree_vm_ref_t value_ref = iree_vm_ref_null();
IREE_RETURN_IF_ERROR(
iree_vm_ref_wrap_assign(value, iree_hal_buffer_view_type(), &value_ref));
return iree_vm_list_set_ref_retain(list, i, &value_ref);
@@ -247,7 +247,7 @@
IREE_API_EXPORT iree_status_t iree_vm_list_set_fence_retain(
iree_vm_list_t* list, iree_host_size_t i, iree_hal_fence_t* value) {
- iree_vm_ref_t value_ref;
+ iree_vm_ref_t value_ref = iree_vm_ref_null();
IREE_RETURN_IF_ERROR(
iree_vm_ref_wrap_assign(value, iree_hal_fence_type(), &value_ref));
return iree_vm_list_set_ref_retain(list, i, &value_ref);
diff --git a/runtime/src/iree/tooling/run_module.c b/runtime/src/iree/tooling/run_module.c
index ad5e674..2af3db4 100644
--- a/runtime/src/iree/tooling/run_module.c
+++ b/runtime/src/iree/tooling/run_module.c
@@ -246,6 +246,22 @@
"processing instrument data");
}
+ // Transfer outputs to the host so they can be processed. Only required when
+ // using full HAL device-based execution.
+ if (iree_status_is_ok(status) && device != NULL) {
+ iree_hal_buffer_params_t target_params = {
+ .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+ .access = IREE_HAL_MEMORY_ACCESS_ALL,
+ .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+ IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+ .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+ .min_alignment = 0,
+ };
+ status = iree_tooling_transfer_variant_list(
+ device, outputs, device_allocator, target_params,
+ /*wait_fence=*/NULL, /*signal_fence=*/NULL);
+ }
+
// Handle either printing/writing the outputs or checking them against
// expected values (basic pass/fail testing).
if (iree_status_is_ok(status)) {
diff --git a/runtime/src/iree/tooling/vm_util.c b/runtime/src/iree/tooling/vm_util.c
index b21eada..70e2e77 100644
--- a/runtime/src/iree/tooling/vm_util.c
+++ b/runtime/src/iree/tooling/vm_util.c
@@ -324,6 +324,187 @@
return status;
}
+// Returns true if |source_buffer| is missing any of the memory type bits or
+// any of the allowed usage bits requested by |target_params|.
+static bool iree_tooling_requires_buffer_transfer(
+    iree_hal_buffer_t* source_buffer, iree_hal_buffer_params_t target_params) {
+  // Memory type mismatch alone is enough to require a transfer.
+  if (!iree_all_bits_set(iree_hal_buffer_memory_type(source_buffer),
+                         target_params.type)) {
+    return true;
+  }
+  // Memory type is fine; the decision now rests on the allowed usage bits.
+  if (!iree_all_bits_set(iree_hal_buffer_allowed_usage(source_buffer),
+                         target_params.usage)) {
+    return true;
+  }
+  return false;
+}
+
+// Allocates a buffer from |target_allocator| using |target_params| and records
+// a copy of |source_buffer| into it on |command_buffer|.
+//
+// The new buffer is sized to the allocation size of |source_buffer| while the
+// recorded copy covers its byte length starting at offset 0 in both buffers.
+// On success the caller receives ownership of the new buffer in
+// |out_target_buffer|; on failure |out_target_buffer| is NULL.
+static iree_status_t iree_tooling_setup_buffer_transfer(
+    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer,
+    iree_hal_allocator_t* target_allocator,
+    iree_hal_buffer_params_t target_params,
+    iree_hal_buffer_t** out_target_buffer) {
+  IREE_ASSERT_ARGUMENT(command_buffer);
+  IREE_ASSERT_ARGUMENT(source_buffer);
+  IREE_ASSERT_ARGUMENT(target_allocator);
+  IREE_ASSERT_ARGUMENT(out_target_buffer);
+  *out_target_buffer = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_buffer_t* staging_buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_allocator_allocate_buffer(
+              target_allocator, target_params,
+              iree_hal_buffer_allocation_size(source_buffer), &staging_buffer));
+
+  iree_status_t status = iree_hal_command_buffer_copy_buffer(
+      command_buffer, source_buffer, 0, staging_buffer, 0,
+      iree_hal_buffer_byte_length(source_buffer));
+  if (!iree_status_is_ok(status)) {
+    // Recording failed: drop the allocation so nothing is handed back.
+    iree_hal_buffer_release(staging_buffer);
+    staging_buffer = NULL;
+  }
+
+  *out_target_buffer = staging_buffer;
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Submits |command_buffer| to |device| with |queue_affinity|, waiting on
+// |wait_fence| and signaling |signal_fence|.
+// When |signal_fence| is NULL a temporary fence is created for the submission
+// and the call blocks until that fence is reached.
+static iree_status_t iree_tooling_submit_transfer(
+    iree_hal_device_t* device, iree_hal_fence_t* wait_fence,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_command_buffer_t* command_buffer, iree_hal_fence_t* signal_fence) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // With no caller-provided fence we own the completion fence and must block.
+  const bool block_until_complete = !signal_fence;
+  iree_hal_fence_t* completion_fence = NULL;
+  iree_status_t status = iree_ok_status();
+  if (!block_until_complete) {
+    // Hold our own reference so the release below is balanced on both paths.
+    completion_fence = signal_fence;
+    iree_hal_fence_retain(completion_fence);
+  } else {
+    iree_hal_semaphore_t* semaphore = NULL;
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_semaphore_create(device, 0ull, &semaphore));
+    status = iree_hal_fence_create_at(semaphore, 1ull,
+                                      iree_hal_device_host_allocator(device),
+                                      &completion_fence);
+    iree_hal_semaphore_release(semaphore);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_device_queue_execute(
+        device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
+        iree_hal_fence_semaphore_list(completion_fence), 1, &command_buffer);
+    if (iree_status_is_ok(status) && block_until_complete) {
+      status = iree_hal_fence_wait(completion_fence, iree_infinite_timeout());
+    }
+  }
+
+  iree_hal_fence_release(completion_fence);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Transfers every HAL buffer and buffer view in |list| that does not already
+// satisfy |target_params| into a newly allocated buffer from |target_allocator|
+// and replaces the list element in-place. Non-buffer elements and elements
+// that already satisfy |target_params| are left untouched.
+//
+// All copies are recorded into one transfer command buffer that waits on
+// |wait_fence| (if any). If |signal_fence| is provided it is signaled once the
+// copies complete (or, when nothing needs copying, once |wait_fence| is
+// reached); otherwise the call blocks until the copies complete.
+//
+// NOTE: if recording fails partway through, elements visited earlier have
+// already been replaced with target buffers whose contents were never
+// populated; callers should treat |list| as invalid on failure.
+iree_status_t iree_tooling_transfer_variant_list(
+    iree_hal_device_t* device, iree_vm_list_t* list,
+    iree_hal_allocator_t* target_allocator,
+    iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence,
+    iree_hal_fence_t* signal_fence) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(list);
+  IREE_ASSERT_ARGUMENT(target_allocator);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // If all buffers already satisfy |target_params| we can skip the transfer.
+  bool requires_transfer = false;
+  for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) {
+    // Non-ref elements fail the lookup and leave |value| null, which matches
+    // neither type check below.
+    iree_vm_ref_t value = iree_vm_ref_null();
+    IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value));
+    if (iree_hal_buffer_isa(value)) {
+      iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value);
+      if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) {
+        requires_transfer = true;
+        break;
+      }
+    } else if (iree_hal_buffer_view_isa(value)) {
+      iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value);
+      iree_hal_buffer_t* source_buffer =
+          iree_hal_buffer_view_buffer(source_view);
+      if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) {
+        requires_transfer = true;
+        break;
+      }
+    }
+  }
+  if (!requires_transfer) {
+    // Nothing to copy. A caller that passed |signal_fence| will wait on it, so
+    // it must still be signaled: enqueue an empty submission that chains
+    // |wait_fence| to |signal_fence|. Without this the caller would hang.
+    iree_status_t barrier_status = iree_ok_status();
+    if (signal_fence) {
+      barrier_status = iree_hal_device_queue_execute(
+          device, target_params.queue_affinity,
+          iree_hal_fence_semaphore_list(wait_fence),
+          iree_hal_fence_semaphore_list(signal_fence), 0, NULL);
+    }
+    IREE_TRACE_ZONE_END(z0);
+    return barrier_status;
+  }
+
+  // Inline execution lets the implementation run commands while recording,
+  // which would not honor |wait_fence|; only allow it when there is nothing to
+  // wait on.
+  iree_hal_command_buffer_mode_t command_buffer_mode =
+      IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT;
+  if (!wait_fence) {
+    command_buffer_mode |= IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION;
+  }
+
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_command_buffer_create(
+              device, command_buffer_mode, IREE_HAL_COMMAND_CATEGORY_TRANSFER,
+              target_params.queue_affinity,
+              /*binding_capacity=*/0, &command_buffer));
+
+  iree_status_t status = iree_hal_command_buffer_begin(command_buffer);
+  if (iree_status_is_ok(status)) {
+    for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) {
+      iree_vm_ref_t value = iree_vm_ref_null();
+      IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value));
+      if (iree_hal_buffer_isa(value)) {
+        iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value);
+        if (!iree_tooling_requires_buffer_transfer(source_buffer,
+                                                   target_params)) {
+          // Already ok.
+          continue;
+        }
+        iree_hal_buffer_t* target_buffer = NULL;
+        status = iree_tooling_setup_buffer_transfer(
+            command_buffer, source_buffer, target_allocator, target_params,
+            &target_buffer);
+        if (!iree_status_is_ok(status)) break;
+        // The list retains the target; drop our reference right after.
+        status = iree_vm_list_set_buffer_retain(list, i, target_buffer);
+        iree_hal_buffer_release(target_buffer);
+        if (!iree_status_is_ok(status)) break;
+      } else if (iree_hal_buffer_view_isa(value)) {
+        iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value);
+        iree_hal_buffer_t* source_buffer =
+            iree_hal_buffer_view_buffer(source_view);
+        if (!iree_tooling_requires_buffer_transfer(source_buffer,
+                                                   target_params)) {
+          // Already ok.
+          continue;
+        }
+        iree_hal_buffer_t* target_buffer = NULL;
+        status = iree_tooling_setup_buffer_transfer(
+            command_buffer, source_buffer, target_allocator, target_params,
+            &target_buffer);
+        if (!iree_status_is_ok(status)) break;
+        // Wrap the target in a view mirroring |source_view| before the source
+        // element is replaced in the list.
+        iree_hal_buffer_view_t* target_view = NULL;
+        status = iree_hal_buffer_view_create_like(
+            target_buffer, source_view,
+            iree_hal_allocator_host_allocator(target_allocator), &target_view);
+        iree_hal_buffer_release(target_buffer);
+        if (!iree_status_is_ok(status)) break;
+        status = iree_vm_list_set_buffer_view_retain(list, i, target_view);
+        iree_hal_buffer_view_release(target_view);
+        if (!iree_status_is_ok(status)) break;
+      }
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_command_buffer_end(command_buffer);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_tooling_submit_transfer(device, wait_fence,
+                                          target_params.queue_affinity,
+                                          command_buffer, signal_fence);
+  }
+
+  iree_hal_command_buffer_release(command_buffer);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
#define IREE_PRINTVARIANT_CASE_I(SIZE, B, V) \
case IREE_VM_VALUE_TYPE_I##SIZE: \
return iree_string_builder_append_format( \
diff --git a/runtime/src/iree/tooling/vm_util.h b/runtime/src/iree/tooling/vm_util.h
index e2a0311..bc9ca00 100644
--- a/runtime/src/iree/tooling/vm_util.h
+++ b/runtime/src/iree/tooling/vm_util.h
@@ -54,6 +54,16 @@
iree_hal_device_t* device, iree_hal_fence_t* wait_fence,
iree_hal_fence_t** out_signal_fence);
+// Transfers all buffers in |list| to ones using |target_params|.
+//
+// Both HAL buffers and HAL buffer views stored in |list| are considered; any
+// other element is ignored. Elements whose backing buffer already has all of
+// the memory type and usage bits in |target_params| are left as-is. For every
+// other element a new buffer is allocated from |target_allocator|, a copy into
+// it is recorded, and the list element is replaced in-place (buffer views are
+// recreated like the source view over the new buffer).
+//
+// If no |wait_fence| is provided then the transfer will begin immediately.
+// If no |signal_fence| is provided then the call will block until the transfer
+// completes.
+iree_status_t iree_tooling_transfer_variant_list(
+    iree_hal_device_t* device, iree_vm_list_t* list,
+    iree_hal_allocator_t* target_allocator,
+    iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence,
+    iree_hal_fence_t* signal_fence);
+
// Appends a variant list of VM scalars and buffers to |builder|.
// |list_name| will be printed alongside each element ordinal.
//
diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel
index 75c4563..11af40c 100644
--- a/tools/BUILD.bazel
+++ b/tools/BUILD.bazel
@@ -210,6 +210,7 @@
"//runtime/src/iree/modules/hal",
"//runtime/src/iree/tooling:device_util",
"//runtime/src/iree/tooling:trace_replay",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/tooling:yaml_util",
"//runtime/src/iree/vm",
"@com_github_yaml_libyaml//:yaml",
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 3cf3a0a..2445774 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -215,6 +215,7 @@
iree::modules::hal
iree::tooling::device_util
iree::tooling::trace_replay
+ iree::tooling::vm_util
iree::tooling::yaml_util
iree::vm
yaml
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index 83b7343..758ae35 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -19,6 +19,7 @@
#include "iree/modules/hal/module.h"
#include "iree/tooling/device_util.h"
#include "iree/tooling/trace_replay.h"
+#include "iree/tooling/vm_util.h"
#include "iree/tooling/yaml_util.h"
#include "iree/vm/api.h"
@@ -200,10 +201,8 @@
iree_hal_buffer_view_t* buffer_view,
enum iree_hal_memory_access_bits_t access,
iree_hal_buffer_mapping_t* mapping) {
- // Really validate host-local, not just host-visible: callers may rely on
- // host-coherency.
IREE_RETURN_IF_ERROR(
- validate_memory_type(buffer_view, IREE_HAL_MEMORY_TYPE_HOST_LOCAL));
+ validate_memory_type(buffer_view, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE));
if (iree_hal_buffer_view_encoding_type(buffer_view) !=
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
@@ -1055,39 +1054,43 @@
replay->device, device_allocator, device_inputs, &host_inputs));
// Invoke the function to produce the actual result.
- iree_vm_list_t* device_outputs = NULL;
+ iree_vm_list_t* outputs = NULL;
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(),
/*initial_capacity=*/8,
- replay->host_allocator, &device_outputs));
+ replay->host_allocator, &outputs));
IREE_CHECK_OK(iree_vm_invoke(
replay->context, function, IREE_VM_INVOCATION_FLAG_NONE,
- /*policy=*/NULL, device_inputs, device_outputs, replay->host_allocator));
+ /*policy=*/NULL, device_inputs, outputs, replay->host_allocator));
iree_vm_list_release(device_inputs);
- // Get the device_actual_result from the device_outputs.
- iree_hal_buffer_view_t* device_actual_result;
- IREE_CHECK_OK(
- get_item_as_buffer_view(device_outputs, 0, &device_actual_result));
+ // Transfer device buffers to host buffers.
+ iree_hal_buffer_params_t host_params = {
+ .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+ .access = IREE_HAL_MEMORY_ACCESS_ALL,
+ .type =
+ IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+ .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+ .min_alignment = 0,
+ };
+ IREE_CHECK_OK(iree_tooling_transfer_variant_list(
+ replay->device, outputs, device_allocator, host_params,
+ /*wait_fence=*/NULL, /*signal_fence=*/NULL));
- // Copy the results to a host local buffer to be able to map it.
- iree_hal_buffer_view_t* host_actual_result = NULL;
- IREE_CHECK_OK(copy_device_buffer_view_to_host(
- replay->device, device_allocator, device_actual_result,
- &host_actual_result));
+ // Get the actual result computed by the program.
+ iree_hal_buffer_view_t* actual_result;
+ IREE_CHECK_OK(get_item_as_buffer_view(outputs, 0, &actual_result));
- // Allocate host_expected_result with same shape as host_actual_result.
+ // Allocate host_expected_result with same shape as actual_result.
iree_hal_buffer_view_t* host_expected_result = NULL;
- IREE_CHECK_OK(allocate_host_buffer_view_like(replay->device, device_allocator,
- host_actual_result,
- &host_expected_result));
+ IREE_CHECK_OK(allocate_host_buffer_view_like(
+ replay->device, device_allocator, actual_result, &host_expected_result));
- // Check that host_actual_result and host_expected_result agree.
- iree_status_t status = check_matmul_results(
- file, host_inputs, host_actual_result, host_expected_result);
+ // Check that actual_result and host_expected_result agree.
+ iree_status_t status = check_matmul_results(file, host_inputs, actual_result,
+ host_expected_result);
- iree_vm_list_release(device_outputs); // releases device_actual_result
+ iree_vm_list_release(outputs); // releases actual_result
iree_vm_list_release(host_inputs);
- iree_hal_buffer_view_release(host_actual_result);
iree_hal_buffer_view_release(host_expected_result);
return status;
}
diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c
index fa46810..b1b39dc 100644
--- a/tools/iree-run-trace-main.c
+++ b/tools/iree-run-trace-main.c
@@ -197,6 +197,21 @@
yaml_parser_delete(&parser);
+ // Transfer outputs to the host so they can be processed.
+ if (iree_status_is_ok(status) && replay.device != NULL) {
+ iree_hal_buffer_params_t target_params = {
+ .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+ .access = IREE_HAL_MEMORY_ACCESS_ALL,
+ .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+ IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+ .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+ .min_alignment = 0,
+ };
+ status = iree_tooling_transfer_variant_list(
+ replay.device, replay.outputs, iree_hal_device_allocator(replay.device),
+ target_params, /*wait_fence=*/NULL, /*signal_fence=*/NULL);
+ }
+
// Optionally process outputs from the replay session.
if (iree_status_is_ok(status)) {
if (FLAG_output_list().count == 0) {