Switching external resources to be device-local only. (#14016)

Previously all external resources (results returned by an invocation)
were made host-visible and mappable. This prevented the use of
queue-ordered allocations in CUDA, as memory pools cannot service memory
with associated host pointers. Depending on the device, host-visible
memory could also be much slower to access (or have more potential
pitfalls with page management) than pinned device-local memory, and this
got worse once we started doing more dispatches in-place on the results.

Now all external buffers are allocated as device-local by default. Users
who need the contents on the host must manually stage the buffers;
otherwise they remain on-device. For externalized state this is a good
thing, as it means we'll keep state on device automatically. A temporary
flag has been added to revert to the old mappable behavior:
`--iree-stream-external-resources-mappable=true`. Note that some devices
(like CPU) will always allow mapping even if not requested, and users can
avoid the copies by checking whether a buffer is already host-visible and
mappable before performing the transfers.

GPT2 CUDA post-change with alloca and no caching allocator enabled
(~5us/invocation allocation overhead):

![image](https://github.com/openxla/iree/assets/75337/5f7f589d-b602-49b3-96c6-5c9dfa6578fe)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
index 02785f5..c9fd8df 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
@@ -9,8 +9,6 @@
 #include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 
@@ -43,39 +41,5 @@
   });
 }
 
-// static
-LogicalResult HALConversionTarget::applyDefaultBufferRewrite(
-    Operation *srcOp, ValueRange operands, StringRef dstOpName,
-    TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
-  OperationState state{srcOp->getLoc(), dstOpName};
-  state.addAttributes(srcOp->getAttrs());
-
-  for (auto [srcOperand, dstOperand] :
-       llvm::zip_equal(srcOp->getOperands(), operands)) {
-    // Check that any type that should have been mapped to buffer view was.
-    // This is just to catch conflicts in type conversions that may sneak in
-    // during development.
-    assert(
-        (!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) ||
-         dstOperand.getType().isa<IREE::HAL::BufferViewType>()) &&
-        "expect that tensors have been mapped to buffer views");
-    state.addOperands({dstOperand});
-  }
-  for (auto resultType : srcOp->getResultTypes()) {
-    if (HALTypeConverter::shouldConvertToBufferView(resultType)) {
-      state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext()));
-    } else {
-      // Normal pass-through result.
-      if (failed(typeConverter.convertType(resultType, state.types))) {
-        return failure();
-      }
-    }
-  }
-
-  auto *dstOp = rewriter.create(state);
-  rewriter.replaceOp(srcOp, dstOp->getResults());
-  return success();
-}
-
 } // namespace iree_compiler
 } // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
index b41dd1f..fd3d489 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
@@ -8,7 +8,6 @@
 #define IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERSIONTARGET_H_
 
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -22,47 +21,6 @@
 class HALConversionTarget : public ConversionTarget {
 public:
   HALConversionTarget(MLIRContext *context, TypeConverter &typeConverter);
-
-  // Attempts to rewrite an op that may use tensor values into an op using HAL
-  // buffers. See HALOpConversion for more information.
-  static LogicalResult
-  applyDefaultBufferRewrite(Operation *srcOp, ValueRange operands,
-                            StringRef dstOpName, TypeConverter &typeConverter,
-                            ConversionPatternRewriter &rewriter);
-};
-
-// HAL tensor-to-buffer conversion utility.
-// This can be used by dialects to model custom op conversion from a dialect
-// that uses the MLIR tensor type to the IREE HAL buffer type. At this point
-// during conversion the source values will be TensorType and the target values
-// will be IREE::HAL::BufferTypes. Any static information available about the
-// tensor (such as static dimensions, element type, layout, etc) are extracted
-// here and lowered as expanded values.
-//
-// The ABI is currently very basic and will change with the introduction of more
-// dynamic shape logic.
-//
-// Source:
-//   my.tensor_op(%arg0 : tensor<2x4xf32>)
-// Target:
-//   %arg0_view = hal.buffer_view.create %arg0, ...
-//   my.buffer_op(%arg0_view : !hal.buffer_view)
-template <typename SRC, typename DST>
-class HALOpConversion : public OpConversionPattern<SRC> {
-public:
-  HALOpConversion(MLIRContext *context, TypeConverter &typeConverter)
-      : OpConversionPattern<SRC>(context), typeConverter(typeConverter) {}
-
-  LogicalResult
-  matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    return HALConversionTarget::applyDefaultBufferRewrite(
-        srcOp, adaptor.getOperands(), DST::getOperationName(), typeConverter,
-        rewriter);
-  }
-
-protected:
-  TypeConverter &typeConverter;
 };
 
 } // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 9c7ebe0..d0a809e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
 #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -23,6 +24,14 @@
 namespace mlir {
 namespace iree_compiler {
 
+static llvm::cl::opt<bool> clExternalResourcesMappable(
+    "iree-stream-external-resources-mappable",
+    llvm::cl::desc("Allocates external resources as host-visible and mappable. "
+                   "This can degrade performance and introduce allocation "
+                   "overhead and staging buffers for readback on the host "
+                   "should be managed by the calling application instead."),
+    llvm::cl::init(false));
+
 namespace {
 
 static Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
@@ -263,17 +272,21 @@
   default:
     break;
   case IREE::Stream::Lifetime::External:
-    // #yolo; these come from/go to outside the program.
-    // Today we assume they are device-local|host-visible just for
-    // practical purposes but that does not have to be true. We really
-    // want this to be something we analyze and handle on the edges
-    // (transferring devices/etc if needed).
-    memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
-                  IREE::HAL::MemoryTypeBitfield::HostVisible;
-    // NOTE: we may not map it but users may after they get them back.
-    // Another reason we should annotate this - having a buffer be
-    // mappable is potentially expensive (may get a 2nd copy in memory!).
-    bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
+    if (clExternalResourcesMappable) {
+      // #yolo; these come from/go to outside the program.
+      // Today we assume they are device-local|host-visible just for
+      // practical purposes but that does not have to be true. We really
+      // want this to be something we analyze and handle on the edges
+      // (transferring devices/etc if needed).
+      memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
+                    IREE::HAL::MemoryTypeBitfield::HostVisible;
+      // NOTE: we may not map it but users may after they get them back.
+      // Another reason we should annotate this - having a buffer be
+      // mappable is potentially expensive (may get a 2nd copy in memory!).
+      bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
+    } else {
+      memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
+    }
     break;
   }
   return success();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index 2aca6c5..d45978d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -80,8 +80,8 @@
     %arg1_resource = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%c16}
 
     // CHECK: %[[RESULT_BUFFER:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator>
-    // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal")
-    // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}Mapping{{.+}}")
+    // CHECK-SAME: type("DeviceVisible|DeviceLocal")
+    // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}")
     // CHECK-SAME: : !hal.buffer{%c16}
     %result_resource = stream.resource.alloc uninitialized : !stream.resource<external>{%c16}
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
index e036755..1fa81ac 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
@@ -307,7 +307,27 @@
           getState() ^= targetUsage.getState();
         })
         .Case([&](IREE::Stream::TensorImportOp op) {
-          removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+          auto targetType =
+              llvm::cast<IREE::Stream::ResourceType>(op.getResult().getType());
+          switch (targetType.getLifetime()) {
+          default:
+          case IREE::Stream::Lifetime::External:
+            removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+            break;
+          case IREE::Stream::Lifetime::Staging:
+            removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ |
+                              NOT_STAGING_WRITE);
+            break;
+          case IREE::Stream::Lifetime::Transient:
+            removeAssumedBits(NOT_MUTATED);
+            break;
+          case IREE::Stream::Lifetime::Variable:
+            removeAssumedBits(NOT_MUTATED | NOT_GLOBAL_READ | NOT_GLOBAL_WRITE);
+            break;
+          case IREE::Stream::Lifetime::Constant:
+            removeAssumedBits(NOT_CONSTANT);
+            break;
+          }
           auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
               *this, Position::forValue(op.getResult()),
               DFX::Resolution::REQUIRED);
@@ -497,7 +517,6 @@
               *this, Position::forValue(op->getOperand(operandIdx)),
               DFX::Resolution::REQUIRED);
           getState() ^= operandUsage.getState();
-
           auto &beforeUsage = solver.getElementFor<ValueResourceUsage>(
               *this,
               Position::forValue(op.getBeforeBody()->getArgument(operandIdx)),
@@ -510,13 +529,11 @@
               *this, Position::forValue(op->getOperand(operandIdx)),
               DFX::Resolution::REQUIRED);
           getState() ^= operandUsage.getState();
-
           auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
               *this,
               Position::forValue(op->getParentOp()->getResult(operandIdx - 1)),
               DFX::Resolution::REQUIRED);
           getState() ^= parentUsage.getState();
-
           if (auto whileOp =
                   dyn_cast_or_null<scf::WhileOp>(op->getParentOp())) {
             auto value = Position::forValue(
@@ -532,14 +549,12 @@
                 *this, Position::forValue(op->getOperand(operandIdx)),
                 DFX::Resolution::REQUIRED);
             getState() ^= operandUsage.getState();
-
             auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
                 *this,
                 Position::forValue(op->getParentOp()->getResult(operandIdx)),
                 DFX::Resolution::REQUIRED);
             getState() ^= parentUsage.getState();
           }
-
           if (auto whileOp =
                   dyn_cast_or_null<scf::WhileOp>(op->getParentOp())) {
             auto value =
@@ -589,7 +604,33 @@
           removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_WRITE);
         })
         .Case([&](IREE::Stream::TensorExportOp op) {
-          removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+          auto sourceType =
+              llvm::cast<IREE::Stream::ResourceType>(op.getSource().getType());
+          switch (sourceType.getLifetime()) {
+          default:
+          case IREE::Stream::Lifetime::External:
+            removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+            break;
+          case IREE::Stream::Lifetime::Staging:
+            removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ |
+                              NOT_STAGING_WRITE | NOT_TRANSFER_READ |
+                              NOT_TRANSFER_WRITE);
+            break;
+          case IREE::Stream::Lifetime::Transient:
+            removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ |
+                              NOT_TRANSFER_WRITE | NOT_DISPATCH_READ |
+                              NOT_DISPATCH_WRITE);
+            break;
+          case IREE::Stream::Lifetime::Variable:
+            removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ |
+                              NOT_TRANSFER_WRITE | NOT_DISPATCH_READ |
+                              NOT_DISPATCH_WRITE);
+            break;
+          case IREE::Stream::Lifetime::Constant:
+            removeAssumedBits(NOT_CONSTANT | NOT_TRANSFER_READ |
+                              NOT_DISPATCH_READ);
+            break;
+          }
         })
         .Case([&](IREE::Stream::TensorTraceOp op) {
           removeAssumedBits(NOT_STAGING_READ);
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
index 0644bda..4dcde8c 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
@@ -22,6 +22,7 @@
     ],
     deps = [
         "//compiler/src/iree/compiler/Dialect/HAL/Conversion",
+        "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/VM/Conversion",
         "//compiler/src/iree/compiler/Modules/Check/IR",
         "@llvm-project//mlir:Pass",
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
index 582a6ad..c55d771 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
@@ -21,6 +21,7 @@
     MLIRPass
     MLIRTransforms
     iree::compiler::Dialect::HAL::Conversion
+    iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::VM::Conversion
     iree::compiler::Modules::Check::IR
   PUBLIC
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
index 82da66b..10cdbb3 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
@@ -7,6 +7,8 @@
 #include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h"
 
 #include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
 #include "iree/compiler/Modules/Check/IR/CheckOps.h"
 #include "mlir/Pass/Pass.h"
@@ -60,17 +62,90 @@
       context, importSymbols, typeConverter, "check.expect_almost_eq");
 }
 
+// Attempts to rewrite an op that may use tensor values into an op using HAL
+// buffers.
+static LogicalResult applyDefaultCheckBufferRewrite(
+    Operation *srcOp, ValueRange operands, StringRef dstOpName,
+    TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
+  OperationState state{srcOp->getLoc(), dstOpName};
+  state.addAttributes(srcOp->getAttrs());
+
+  // Add device argument.
+  Value device = rewriter.create<IREE::HAL::ExSharedDeviceOp>(srcOp->getLoc());
+  state.addOperands({device});
+
+  for (auto [srcOperand, dstOperand] :
+       llvm::zip_equal(srcOp->getOperands(), operands)) {
+    // Check that any type that should have been mapped to buffer view was.
+    // This is just to catch conflicts in type conversions that may sneak in
+    // during development.
+    assert(
+        (!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) ||
+         dstOperand.getType().isa<IREE::HAL::BufferViewType>()) &&
+        "expect that tensors have been mapped to buffer views");
+    state.addOperands({dstOperand});
+  }
+  for (auto resultType : srcOp->getResultTypes()) {
+    if (HALTypeConverter::shouldConvertToBufferView(resultType)) {
+      state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext()));
+    } else {
+      // Normal pass-through result.
+      if (failed(typeConverter.convertType(resultType, state.types))) {
+        return failure();
+      }
+    }
+  }
+
+  auto *dstOp = rewriter.create(state);
+  rewriter.replaceOp(srcOp, dstOp->getResults());
+  return success();
+}
+
+// HAL tensor-to-buffer conversion utility.
+// This can be used by dialects to model custom op conversion from a dialect
+// that uses the MLIR tensor type to the IREE HAL buffer type. At this point
+// during conversion the source values will be TensorType and the target values
+// will be IREE::HAL::BufferTypes. Any static information available about the
+// tensor (such as static dimensions, element type, layout, etc) are extracted
+// here and lowered as expanded values.
+//
+// The ABI is currently very basic and will change with the introduction of more
+// dynamic shape logic.
+//
+// Source:
+//   my.tensor_op(%arg0 : tensor<2x4xf32>)
+// Target:
+//   %arg0_view = hal.buffer_view.create %arg0, ...
+//   my.buffer_op(%arg0_view : !hal.buffer_view)
+template <typename SRC, typename DST>
+class HALCheckOpConversion : public OpConversionPattern<SRC> {
+public:
+  HALCheckOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern<SRC>(context), typeConverter(typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return applyDefaultCheckBufferRewrite(srcOp, adaptor.getOperands(),
+                                          DST::getOperationName(),
+                                          typeConverter, rewriter);
+  }
+
+protected:
+  TypeConverter &typeConverter;
+};
+
 void populateCheckToHALPatterns(MLIRContext *context,
                                 RewritePatternSet &patterns,
                                 TypeConverter &typeConverter) {
   // The same op handles both tensors and buffer views.
-  patterns
-      .insert<HALOpConversion<IREE::Check::ExpectAllTrueOp,
-                              IREE::Check::ExpectAllTrueOp>,
-              HALOpConversion<IREE::Check::ExpectEqOp, IREE::Check::ExpectEqOp>,
-              HALOpConversion<IREE::Check::ExpectAlmostEqOp,
-                              IREE::Check::ExpectAlmostEqOp>>(context,
-                                                              typeConverter);
+  patterns.insert<
+      HALCheckOpConversion<IREE::Check::ExpectAllTrueOp,
+                           IREE::Check::ExpectAllTrueOp>,
+      HALCheckOpConversion<IREE::Check::ExpectEqOp, IREE::Check::ExpectEqOp>,
+      HALCheckOpConversion<IREE::Check::ExpectAlmostEqOp,
+                           IREE::Check::ExpectAlmostEqOp>>(context,
+                                                           typeConverter);
 }
 
 } // namespace Check
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel
index dff0294..e55f3d2 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel
@@ -57,6 +57,7 @@
         ":IR",
         ":check_ops_gen",
         "//compiler/src/iree/compiler/Dialect/HAL/Conversion",
+        "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
         "//compiler/src/iree/compiler/Dialect/VM/Conversion",
         "//compiler/src/iree/compiler/Modules/Check:check_imports",
         "//compiler/src/iree/compiler/Modules/Check/Conversion",
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt
index c3a8574..b0928ce 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt
@@ -42,6 +42,7 @@
     MLIRParser
     MLIRTransforms
     iree::compiler::Dialect::HAL::Conversion
+    iree::compiler::Dialect::HAL::IR::HALDialect
     iree::compiler::Dialect::VM::Conversion
     iree::compiler::Modules::Check::Conversion
     iree::compiler::Modules::Check::check_imports
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp b/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp
index dbdb4e1..554baa6 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Modules/Check/IR/CheckDialect.h"
 
 #include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
 #include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h"
 #include "iree/compiler/Modules/Check/IR/CheckOps.h"
@@ -57,6 +58,8 @@
 
 CheckDialect::CheckDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context, TypeID::get<CheckDialect>()) {
+  context->loadDialect<IREE::HAL::HALDialect>();
+
   addInterfaces<CheckToVmConversionInterface>();
   addInterfaces<CheckToHalConversionInterface>();
 #define GET_OP_LIST
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp
index a651bfe..69cfda7 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp
@@ -24,7 +24,7 @@
   LogicalResult matchAndRewrite(SrcOp op,
                                 PatternRewriter &rewriter) const override {
     auto rhs = rewriter.create<arith::ConstantOp>(op.getLoc(), op.getValue());
-    rewriter.replaceOpWithNewOp<DstOp>(op, op.getLhs(), rhs);
+    rewriter.replaceOpWithNewOp<DstOp>(op, op.getDevice(), op.getLhs(), rhs);
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td
index 9d0b1b3..59c2236 100644
--- a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td
+++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td
@@ -36,7 +36,6 @@
   let assemblyFormat = "`(` $operand `)` attr-dict `:` type($operand)";
 }
 
-
 def CHECK_ExpectFalseOp : Op<CHECK_Dialect, "expect_false"> {
   let summary = [{Checks that the operand is false}];
   let description = [{
@@ -64,18 +63,24 @@
     Issues a non-fatal failure if the verification fails.
 
     ```mlir
-    check.expect_all_true(%arg0) : !hal.buffer_view
+    check.expect_all_true<%device>(%arg0) : !hal.buffer_view
     check.expect_all_true(%arg1) : tensor<2x2xi32>
     ```
   }];
 
-  let arguments =
-    (ins AnyTypeOf<[HAL_BufferView, TensorOf<[AnySignlessInteger]>]>:$operand);
+  let arguments = (ins
+    Optional<HAL_Device>:$device,
+    AnyTypeOf<[HAL_BufferView, TensorOf<[AnySignlessInteger]>]>:$operand
+  );
 
-  let assemblyFormat = "`(` $operand `)` attr-dict `:` type($operand)";
+  let assemblyFormat = [{
+    (`` `<` $device^ `>`)?
+    `` `(` $operand `)` attr-dict `:` type($operand)
+  }];
 }
 
-def CHECK_ExpectEqOp : Op<CHECK_Dialect, "expect_eq", [SameTypeOperands]> {
+def CHECK_ExpectEqOp :
+    Op<CHECK_Dialect, "expect_eq", [AllTypesMatch<["lhs", "rhs"]>]> {
   let summary = [{Checks that the tensor or buffer view operands are equal}];
   let description = [{
     Verifies that the operands are exactly equal.
@@ -88,11 +93,15 @@
   }];
 
   let arguments = (ins
-      AnyTypeOf<[HAL_BufferView, AnyTensor]>:$lhs,
-      AnyTypeOf<[HAL_BufferView, AnyTensor]>:$rhs
+    Optional<HAL_Device>:$device,
+    AnyTypeOf<[HAL_BufferView, AnyTensor]>:$lhs,
+    AnyTypeOf<[HAL_BufferView, AnyTensor]>:$rhs
   );
 
-  let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)";
+  let assemblyFormat = [{
+    (`` `<` $device^ `>`)?
+    `` `(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)
+  }];
 }
 
 def CHECK_ExpectEqConstOp :
@@ -111,17 +120,21 @@
   }];
 
   let arguments = (ins
+    Optional<HAL_Device>:$device,
     AnyTensor:$lhs,
     ElementsAttr:$value
   );
 
   let hasCanonicalizer = 1;
 
-  let assemblyFormat = "`(` $lhs `,` $value `)` attr-dict `:` type($lhs)";
+  let assemblyFormat = [{
+    (`` `<` $device^ `>`)?
+    `` `(` $lhs `,` $value `)` attr-dict `:` type($lhs)
+  }];
 }
 
 def CHECK_ExpectAlmostEqOp :
-    Op<CHECK_Dialect, "expect_almost_eq", [SameTypeOperands]> {
+    Op<CHECK_Dialect, "expect_almost_eq", [AllTypesMatch<["lhs", "rhs"]>]> {
   let summary = [{Checks that the operands are almost equal}];
   let description = [{
     Verifies that the buffer view or tensor operands with float elements are
@@ -135,11 +148,15 @@
   }];
 
   let arguments = (ins
-      AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs,
-      AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs
+    Optional<HAL_Device>:$device,
+    AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs,
+    AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs
   );
 
-  let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)";
+  let assemblyFormat = [{
+    (`` `<` $device^ `>`)?
+    `` `(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)
+  }];
 }
 
 def CHECK_ExpectAlmostEqConstOp :
@@ -160,13 +177,17 @@
   }];
 
   let arguments = (ins
+    Optional<HAL_Device>:$device,
     TensorOf<[AnyFloat]>:$lhs,
     ElementsAttr:$value
   );
 
   let hasCanonicalizer = 1;
 
-  let assemblyFormat = "`(` $lhs `,` $value `)` attr-dict `:` type($lhs)";
+  let assemblyFormat = [{
+    (`` `<` $device^ `>`)?
+    `` `(` $lhs `,` $value `)` attr-dict `:` type($lhs)
+  }];
 }
 
 #endif  // IREE_MODULES_CHECK_DIALECT_CHECK_OPS
diff --git a/compiler/src/iree/compiler/Modules/Check/check.imports.mlir b/compiler/src/iree/compiler/Modules/Check/check.imports.mlir
index 67bae93..63b9d72 100644
--- a/compiler/src/iree/compiler/Modules/Check/check.imports.mlir
+++ b/compiler/src/iree/compiler/Modules/Check/check.imports.mlir
@@ -15,15 +15,18 @@
 )
 
 vm.import private optional @expect_all_true(
+  %device : !vm.ref<!hal.device>,
   %operand : !vm.ref<!hal.buffer_view>,
 )
 
 vm.import private optional @expect_eq(
+  %device : !vm.ref<!hal.device>,
   %lhs : !vm.ref<!hal.buffer_view>,
   %rhs : !vm.ref<!hal.buffer_view>
 )
 
 vm.import private optional @expect_almost_eq(
+  %device : !vm.ref<!hal.device>,
   %lhs : !vm.ref<!hal.buffer_view>,
   %rhs : !vm.ref<!hal.buffer_view>
 )
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c
index b53bcd0..a5e8788 100644
--- a/experimental/cuda2/cuda_device.c
+++ b/experimental/cuda2/cuda_device.c
@@ -623,7 +623,7 @@
   // allocator is set on the device.
   iree_status_t status = iree_ok_status();
   if (device->supports_memory_pools &&
-      !iree_any_bit_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+      !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
     status = iree_hal_cuda2_memory_pools_alloca(
         &device->memory_pools, device->dispatch_cu_stream, pool, params,
         allocation_size, out_buffer);
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 4aaba55..cf3bd7f 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -560,7 +560,7 @@
   // allocator is set on the device.
   iree_status_t status = iree_ok_status();
   if (device->supports_memory_pools &&
-      !iree_any_bit_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+      !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
     status = iree_hal_cuda_memory_pools_alloca(&device->memory_pools,
                                                device->stream, pool, params,
                                                allocation_size, out_buffer);
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index 67f1947..7623fb5 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -197,6 +197,9 @@
     IREE_RETURN_IF_ERROR(
         iree_vm_list_create(iree_vm_make_undefined_type_def(), args.size(),
                             iree_allocator_system(), &inputs_));
+    iree_vm_ref_t device_ref = iree_hal_device_retain_ref(device_);
+    IREE_RETURN_IF_ERROR(
+        iree_vm_list_push_ref_move(inputs_.get(), &device_ref));
     for (auto& arg : args) {
       iree_vm_ref_t arg_ref = iree_hal_buffer_view_move_ref(arg.get());
       IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(inputs_.get(), &arg_ref));
diff --git a/runtime/src/iree/modules/check/module.cc b/runtime/src/iree/modules/check/module.cc
index b417eef..edbb9fe 100644
--- a/runtime/src/iree/modules/check/module.cc
+++ b/runtime/src/iree/modules/check/module.cc
@@ -155,6 +155,100 @@
                           "unsupported element type %s", element_type_str);
 }
 
+static StatusOr<std::vector<vm::ref<iree_hal_buffer_view_t>>>
+TransferBuffersToHost(
+    iree_hal_device_t* device,
+    const iree::span<const vm::ref<iree_hal_buffer_view_t>> source_views) {
+  IREE_TRACE_SCOPE();
+
+  // If all buffers are already host-accessible we can skip the transfer.
+  std::vector<vm::ref<iree_hal_buffer_view_t>> target_views;
+  bool requires_transfer = false;
+  for (auto& source_view : source_views) {
+    iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(source_view.get());
+    if (!iree_all_bits_set(iree_hal_buffer_memory_type(buffer),
+                           IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) ||
+        !iree_all_bits_set(iree_hal_buffer_allowed_usage(buffer),
+                           IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)) {
+      requires_transfer = true;
+    }
+  }
+  if (!requires_transfer) {
+    for (auto& source_view : source_views) target_views.push_back(source_view);
+    return std::move(target_views);
+  }
+
+  vm::ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
+      device,
+      IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+          IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY, 0,
+      &command_buffer));
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_begin(command_buffer.get()));
+
+  iree_hal_buffer_params_t target_params = {
+      /*.usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+      /*.access=*/IREE_HAL_MEMORY_ACCESS_ALL,
+      /*.type=*/
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+      /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*.min_alignment=*/0,
+  };
+  for (size_t i = 0; i < source_views.size(); ++i) {
+    iree_hal_buffer_t* source_buffer =
+        iree_hal_buffer_view_buffer(source_views[i].get());
+    iree_device_size_t buffer_length =
+        iree_hal_buffer_byte_length(source_buffer);
+    vm::ref<iree_hal_buffer_t> target_buffer;
+    IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+        iree_hal_device_allocator(device), target_params, buffer_length,
+        &target_buffer));
+    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_copy_buffer(
+        command_buffer.get(), source_buffer, 0, target_buffer.get(), 0,
+        buffer_length));
+    vm::ref<iree_hal_buffer_view_t> target_view;
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create_like(
+        target_buffer.get(), source_views[i].get(),
+        iree_hal_device_host_allocator(device), &target_view));
+    target_views.push_back(std::move(target_view));
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_end(command_buffer.get()));
+  vm::ref<iree_hal_semaphore_t> semaphore;
+  IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &semaphore));
+  vm::ref<iree_hal_fence_t> fence;
+  IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(
+      semaphore.get(), 1ull, iree_hal_device_host_allocator(device), &fence));
+  IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
+      device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+      iree_hal_fence_semaphore_list(fence.get()), 1, &command_buffer));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_fence_wait(fence.get(), iree_infinite_timeout()));
+  return std::move(target_views);
+}
+
+static Status TransferToHost(iree_hal_device_t* device,
+                             vm::ref<iree_hal_buffer_view_t>& buffer_view) {
+  IREE_TRACE_SCOPE();
+  IREE_ASSIGN_OR_RETURN(auto target_views,
+                        TransferBuffersToHost(device, {buffer_view}));
+  buffer_view = std::move(target_views[0]);
+  return OkStatus();
+}
+
+static Status TransferToHost(iree_hal_device_t* device,
+                             vm::ref<iree_hal_buffer_view_t>& buffer_view_a,
+                             vm::ref<iree_hal_buffer_view_t>& buffer_view_b) {
+  IREE_TRACE_SCOPE();
+  IREE_ASSIGN_OR_RETURN(
+      auto target_views,
+      TransferBuffersToHost(device, {buffer_view_a, buffer_view_b}));
+  buffer_view_a = std::move(target_views[0]);
+  buffer_view_b = std::move(target_views[1]);
+  return OkStatus();
+}
+
 // Per-context module state.
 // This can contain "globals" and other arbitrary state.
 //
@@ -177,7 +271,9 @@
     return OkStatus();
   }
 
-  Status ExpectAllTrue(vm::ref<iree_hal_buffer_view_t> operand) {
+  Status ExpectAllTrue(vm::ref<iree_hal_device_t> device,
+                       vm::ref<iree_hal_buffer_view_t> operand) {
+    IREE_RETURN_IF_ERROR(TransferToHost(device.get(), operand));
     auto* view = operand.get();
     iree_hal_element_type_t element_type =
         iree_hal_buffer_view_element_type(view);
@@ -193,8 +289,10 @@
     return OkStatus();
   }
 
-  Status ExpectEq(vm::ref<iree_hal_buffer_view_t> lhs_ref,
+  Status ExpectEq(vm::ref<iree_hal_device_t> device,
+                  vm::ref<iree_hal_buffer_view_t> lhs_ref,
                   vm::ref<iree_hal_buffer_view_t> rhs_ref) {
+    IREE_RETURN_IF_ERROR(TransferToHost(device.get(), lhs_ref, rhs_ref));
     auto* lhs = lhs_ref.get();
     auto* rhs = rhs_ref.get();
 
@@ -272,8 +370,10 @@
     return OkStatus();
   }
 
-  Status ExpectAlmostEq(vm::ref<iree_hal_buffer_view_t> lhs_ref,
+  Status ExpectAlmostEq(vm::ref<iree_hal_device_t> device,
+                        vm::ref<iree_hal_buffer_view_t> lhs_ref,
                         vm::ref<iree_hal_buffer_view_t> rhs_ref) {
+    IREE_RETURN_IF_ERROR(TransferToHost(device.get(), lhs_ref, rhs_ref));
     auto* lhs = lhs_ref.get();
     auto* rhs = rhs_ref.get();
 
diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir
index ff5aa8e..40d8bc3 100644
--- a/runtime/src/iree/modules/check/test/success.mlir
+++ b/runtime/src/iree/modules/check/test/success.mlir
@@ -14,9 +14,10 @@
 }
 
 func.func @expect_all_true() {
+  %device = hal.ex.shared_device : !hal.device
   %all_true = util.unfoldable_constant dense<1> : tensor<2x2xi32>
   %all_true_view = hal.tensor.export %all_true : tensor<2x2xi32> -> !hal.buffer_view
-  check.expect_all_true(%all_true_view) : !hal.buffer_view
+  check.expect_all_true<%device>(%all_true_view) : !hal.buffer_view
   return
 }
 
diff --git a/runtime/src/iree/modules/hal/types.c b/runtime/src/iree/modules/hal/types.c
index 0c7e0d7..52ce5a2 100644
--- a/runtime/src/iree/modules/hal/types.c
+++ b/runtime/src/iree/modules/hal/types.c
@@ -205,7 +205,7 @@
 
 IREE_API_EXPORT iree_status_t iree_vm_list_set_buffer_retain(
     iree_vm_list_t* list, iree_host_size_t i, iree_hal_buffer_t* value) {
-  iree_vm_ref_t value_ref;
+  iree_vm_ref_t value_ref = iree_vm_ref_null();
   IREE_RETURN_IF_ERROR(
       iree_vm_ref_wrap_assign(value, iree_hal_buffer_type(), &value_ref));
   return iree_vm_list_set_ref_retain(list, i, &value_ref);
@@ -226,7 +226,7 @@
 
 IREE_API_EXPORT iree_status_t iree_vm_list_set_buffer_view_retain(
     iree_vm_list_t* list, iree_host_size_t i, iree_hal_buffer_view_t* value) {
-  iree_vm_ref_t value_ref;
+  iree_vm_ref_t value_ref = iree_vm_ref_null();
   IREE_RETURN_IF_ERROR(
       iree_vm_ref_wrap_assign(value, iree_hal_buffer_view_type(), &value_ref));
   return iree_vm_list_set_ref_retain(list, i, &value_ref);
@@ -247,7 +247,7 @@
 
 IREE_API_EXPORT iree_status_t iree_vm_list_set_fence_retain(
     iree_vm_list_t* list, iree_host_size_t i, iree_hal_fence_t* value) {
-  iree_vm_ref_t value_ref;
+  iree_vm_ref_t value_ref = iree_vm_ref_null();
   IREE_RETURN_IF_ERROR(
       iree_vm_ref_wrap_assign(value, iree_hal_fence_type(), &value_ref));
   return iree_vm_list_set_ref_retain(list, i, &value_ref);
diff --git a/runtime/src/iree/tooling/run_module.c b/runtime/src/iree/tooling/run_module.c
index ad5e674..2af3db4 100644
--- a/runtime/src/iree/tooling/run_module.c
+++ b/runtime/src/iree/tooling/run_module.c
@@ -246,6 +246,22 @@
         "processing instrument data");
   }
 
+  // Transfer outputs to the host so they can be processed. Only required when
+  // using full HAL device-based execution.
+  if (iree_status_is_ok(status) && device != NULL) {
+    iree_hal_buffer_params_t target_params = {
+        .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+        .access = IREE_HAL_MEMORY_ACCESS_ALL,
+        .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+                IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+        .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+        .min_alignment = 0,
+    };
+    status = iree_tooling_transfer_variant_list(
+        device, outputs, device_allocator, target_params,
+        /*wait_fence=*/NULL, /*signal_fence=*/NULL);
+  }
+
   // Handle either printing/writing the outputs or checking them against
   // expected values (basic pass/fail testing).
   if (iree_status_is_ok(status)) {
diff --git a/runtime/src/iree/tooling/vm_util.c b/runtime/src/iree/tooling/vm_util.c
index b21eada..70e2e77 100644
--- a/runtime/src/iree/tooling/vm_util.c
+++ b/runtime/src/iree/tooling/vm_util.c
@@ -324,6 +324,187 @@
   return status;
 }
 
+static bool iree_tooling_requires_buffer_transfer(
+    iree_hal_buffer_t* source_buffer, iree_hal_buffer_params_t target_params) {
+  return !iree_all_bits_set(iree_hal_buffer_memory_type(source_buffer),
+                            target_params.type) ||
+         !iree_all_bits_set(iree_hal_buffer_allowed_usage(source_buffer),
+                            target_params.usage);
+}
+
+static iree_status_t iree_tooling_setup_buffer_transfer(
+    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer,
+    iree_hal_allocator_t* target_allocator,
+    iree_hal_buffer_params_t target_params,
+    iree_hal_buffer_t** out_target_buffer) {
+  IREE_ASSERT_ARGUMENT(command_buffer);
+  IREE_ASSERT_ARGUMENT(source_buffer);
+  IREE_ASSERT_ARGUMENT(target_allocator);
+  IREE_ASSERT_ARGUMENT(out_target_buffer);
+  *out_target_buffer = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_buffer_t* target_buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_allocator_allocate_buffer(
+              target_allocator, target_params,
+              iree_hal_buffer_allocation_size(source_buffer), &target_buffer));
+
+  iree_status_t status = iree_hal_command_buffer_copy_buffer(
+      command_buffer, source_buffer, 0, target_buffer, 0,
+      iree_hal_buffer_byte_length(source_buffer));
+
+  if (iree_status_is_ok(status)) {
+    *out_target_buffer = target_buffer;
+  } else {
+    iree_hal_buffer_release(target_buffer);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_tooling_submit_transfer(
+    iree_hal_device_t* device, iree_hal_fence_t* wait_fence,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_command_buffer_t* command_buffer, iree_hal_fence_t* signal_fence) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_status_t status = iree_ok_status();
+
+  bool needs_wait = signal_fence == NULL;
+  if (needs_wait) {
+    iree_hal_semaphore_t* semaphore = NULL;
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_semaphore_create(device, 0ull, &semaphore));
+    status = iree_hal_fence_create_at(
+        semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence);
+    iree_hal_semaphore_release(semaphore);
+  } else {
+    iree_hal_fence_retain(signal_fence);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_device_queue_execute(
+        device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
+        iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer);
+  }
+
+  if (iree_status_is_ok(status) && needs_wait) {
+    status = iree_hal_fence_wait(signal_fence, iree_infinite_timeout());
+  }
+
+  iree_hal_fence_release(signal_fence);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_tooling_transfer_variant_list(
+    iree_hal_device_t* device, iree_vm_list_t* list,
+    iree_hal_allocator_t* target_allocator,
+    iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence,
+    iree_hal_fence_t* signal_fence) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(list);
+  IREE_ASSERT_ARGUMENT(target_allocator);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // If all buffers already satisfy |target_params| we can skip the transfer.
+  bool requires_transfer = false;
+  for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) {
+    iree_vm_ref_t value = iree_vm_ref_null();
+    IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value));
+    if (iree_hal_buffer_isa(value)) {
+      iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value);
+      if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) {
+        requires_transfer = true;
+        break;
+      }
+    } else if (iree_hal_buffer_view_isa(value)) {
+      iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value);
+      iree_hal_buffer_t* source_buffer =
+          iree_hal_buffer_view_buffer(source_view);
+      if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) {
+        requires_transfer = true;
+        break;
+      }
+    }
+  }
+  if (!requires_transfer) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
+  }
+
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_command_buffer_create(
+              device,
+              IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+                  IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
+              IREE_HAL_COMMAND_CATEGORY_TRANSFER, target_params.queue_affinity,
+              /*binding_capacity=*/0, &command_buffer));
+
+  iree_status_t status = iree_hal_command_buffer_begin(command_buffer);
+  if (iree_status_is_ok(status)) {
+    for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) {
+      iree_vm_ref_t value = iree_vm_ref_null();
+      IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value));
+      if (iree_hal_buffer_isa(value)) {
+        iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value);
+        if (!iree_tooling_requires_buffer_transfer(source_buffer,
+                                                   target_params)) {
+          // Already ok.
+          continue;
+        }
+        iree_hal_buffer_t* target_buffer = NULL;
+        status = iree_tooling_setup_buffer_transfer(
+            command_buffer, source_buffer, target_allocator, target_params,
+            &target_buffer);
+        if (!iree_status_is_ok(status)) break;
+        status = iree_vm_list_set_buffer_retain(list, i, target_buffer);
+        iree_hal_buffer_release(target_buffer);
+        if (!iree_status_is_ok(status)) break;
+      } else if (iree_hal_buffer_view_isa(value)) {
+        iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value);
+        iree_hal_buffer_t* source_buffer =
+            iree_hal_buffer_view_buffer(source_view);
+        if (!iree_tooling_requires_buffer_transfer(source_buffer,
+                                                   target_params)) {
+          // Already ok.
+          continue;
+        }
+        iree_hal_buffer_t* target_buffer = NULL;
+        status = iree_tooling_setup_buffer_transfer(
+            command_buffer, source_buffer, target_allocator, target_params,
+            &target_buffer);
+        if (!iree_status_is_ok(status)) break;
+        iree_hal_buffer_view_t* target_view = NULL;
+        status = iree_hal_buffer_view_create_like(
+            target_buffer, source_view,
+            iree_hal_allocator_host_allocator(target_allocator), &target_view);
+        iree_hal_buffer_release(target_buffer);
+        if (!iree_status_is_ok(status)) break;
+        status = iree_vm_list_set_buffer_view_retain(list, i, target_view);
+        iree_hal_buffer_view_release(target_view);
+        if (!iree_status_is_ok(status)) break;
+      }
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_command_buffer_end(command_buffer);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_tooling_submit_transfer(device, wait_fence,
+                                          target_params.queue_affinity,
+                                          command_buffer, signal_fence);
+  }
+
+  iree_hal_command_buffer_release(command_buffer);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
 #define IREE_PRINTVARIANT_CASE_I(SIZE, B, V)  \
   case IREE_VM_VALUE_TYPE_I##SIZE:            \
     return iree_string_builder_append_format( \
diff --git a/runtime/src/iree/tooling/vm_util.h b/runtime/src/iree/tooling/vm_util.h
index e2a0311..bc9ca00 100644
--- a/runtime/src/iree/tooling/vm_util.h
+++ b/runtime/src/iree/tooling/vm_util.h
@@ -54,6 +54,16 @@
     iree_hal_device_t* device, iree_hal_fence_t* wait_fence,
     iree_hal_fence_t** out_signal_fence);
 
+// Transfers buffers and buffer views in |list| that do not already satisfy
+// |target_params| to newly allocated ones, replacing the entries in-place.
+// If no |wait_fence| is provided then the transfer will begin immediately.
+// If no |signal_fence| is provided the call blocks until transfer completion.
+iree_status_t iree_tooling_transfer_variant_list(
+    iree_hal_device_t* device, iree_vm_list_t* list,
+    iree_hal_allocator_t* target_allocator,
+    iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence,
+    iree_hal_fence_t* signal_fence);
+
 // Appends a variant list of VM scalars and buffers to |builder|.
 // |list_name| will be printed alongside each element ordinal.
 //
diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel
index 75c4563..11af40c 100644
--- a/tools/BUILD.bazel
+++ b/tools/BUILD.bazel
@@ -210,6 +210,7 @@
         "//runtime/src/iree/modules/hal",
         "//runtime/src/iree/tooling:device_util",
         "//runtime/src/iree/tooling:trace_replay",
+        "//runtime/src/iree/tooling:vm_util",
         "//runtime/src/iree/tooling:yaml_util",
         "//runtime/src/iree/vm",
         "@com_github_yaml_libyaml//:yaml",
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 3cf3a0a..2445774 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -215,6 +215,7 @@
     iree::modules::hal
     iree::tooling::device_util
     iree::tooling::trace_replay
+    iree::tooling::vm_util
     iree::tooling::yaml_util
     iree::vm
     yaml
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index 83b7343..758ae35 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -19,6 +19,7 @@
 #include "iree/modules/hal/module.h"
 #include "iree/tooling/device_util.h"
 #include "iree/tooling/trace_replay.h"
+#include "iree/tooling/vm_util.h"
 #include "iree/tooling/yaml_util.h"
 #include "iree/vm/api.h"
 
@@ -200,10 +201,8 @@
     iree_hal_buffer_view_t* buffer_view,
     enum iree_hal_memory_access_bits_t access,
     iree_hal_buffer_mapping_t* mapping) {
-  // Really validate host-local, not just host-visible: callers may rely on
-  // host-coherency.
   IREE_RETURN_IF_ERROR(
-      validate_memory_type(buffer_view, IREE_HAL_MEMORY_TYPE_HOST_LOCAL));
+      validate_memory_type(buffer_view, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE));
   if (iree_hal_buffer_view_encoding_type(buffer_view) !=
       IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) {
     return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
@@ -1055,39 +1054,43 @@
       replay->device, device_allocator, device_inputs, &host_inputs));
 
   // Invoke the function to produce the actual result.
-  iree_vm_list_t* device_outputs = NULL;
+  iree_vm_list_t* outputs = NULL;
   IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(),
                                     /*initial_capacity=*/8,
-                                    replay->host_allocator, &device_outputs));
+                                    replay->host_allocator, &outputs));
   IREE_CHECK_OK(iree_vm_invoke(
       replay->context, function, IREE_VM_INVOCATION_FLAG_NONE,
-      /*policy=*/NULL, device_inputs, device_outputs, replay->host_allocator));
+      /*policy=*/NULL, device_inputs, outputs, replay->host_allocator));
   iree_vm_list_release(device_inputs);
 
-  // Get the device_actual_result from the device_outputs.
-  iree_hal_buffer_view_t* device_actual_result;
-  IREE_CHECK_OK(
-      get_item_as_buffer_view(device_outputs, 0, &device_actual_result));
+  // Transfer device buffers to host buffers.
+  iree_hal_buffer_params_t host_params = {
+      .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+      .access = IREE_HAL_MEMORY_ACCESS_ALL,
+      .type =
+          IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+      .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+      .min_alignment = 0,
+  };
+  IREE_CHECK_OK(iree_tooling_transfer_variant_list(
+      replay->device, outputs, device_allocator, host_params,
+      /*wait_fence=*/NULL, /*signal_fence=*/NULL));
 
-  // Copy the results to a host local buffer to be able to map it.
-  iree_hal_buffer_view_t* host_actual_result = NULL;
-  IREE_CHECK_OK(copy_device_buffer_view_to_host(
-      replay->device, device_allocator, device_actual_result,
-      &host_actual_result));
+  // Get the actual result computed by the program.
+  iree_hal_buffer_view_t* actual_result;
+  IREE_CHECK_OK(get_item_as_buffer_view(outputs, 0, &actual_result));
 
-  // Allocate host_expected_result with same shape as host_actual_result.
+  // Allocate host_expected_result with same shape as actual_result.
   iree_hal_buffer_view_t* host_expected_result = NULL;
-  IREE_CHECK_OK(allocate_host_buffer_view_like(replay->device, device_allocator,
-                                               host_actual_result,
-                                               &host_expected_result));
+  IREE_CHECK_OK(allocate_host_buffer_view_like(
+      replay->device, device_allocator, actual_result, &host_expected_result));
 
-  // Check that host_actual_result and host_expected_result agree.
-  iree_status_t status = check_matmul_results(
-      file, host_inputs, host_actual_result, host_expected_result);
+  // Check that actual_result and host_expected_result agree.
+  iree_status_t status = check_matmul_results(file, host_inputs, actual_result,
+                                              host_expected_result);
 
-  iree_vm_list_release(device_outputs);  // releases device_actual_result
+  iree_vm_list_release(outputs);  // releases actual_result
   iree_vm_list_release(host_inputs);
-  iree_hal_buffer_view_release(host_actual_result);
   iree_hal_buffer_view_release(host_expected_result);
   return status;
 }
diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c
index fa46810..b1b39dc 100644
--- a/tools/iree-run-trace-main.c
+++ b/tools/iree-run-trace-main.c
@@ -197,6 +197,21 @@
 
   yaml_parser_delete(&parser);
 
+  // Transfer outputs to the host so they can be processed.
+  if (iree_status_is_ok(status) && replay.device != NULL) {
+    iree_hal_buffer_params_t target_params = {
+        .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+        .access = IREE_HAL_MEMORY_ACCESS_ALL,
+        .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+                IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+        .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+        .min_alignment = 0,
+    };
+    status = iree_tooling_transfer_variant_list(
+        replay.device, replay.outputs, iree_hal_device_allocator(replay.device),
+        target_params, /*wait_fence=*/NULL, /*signal_fence=*/NULL);
+  }
+
   // Optionally process outputs from the replay session.
   if (iree_status_is_ok(status)) {
     if (FLAG_output_list().count == 0) {