Updating various tests to the latest changes.
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
index 2939218..8f51c61 100644
--- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
@@ -56,6 +56,7 @@
TorchInput::createConvertTMTensorToLinalgExtPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTensorPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
+ pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToArithPass());
pm.addPass(torch::createConvertTorchConversionToMLProgramPass());
diff --git a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
index 14b13f9..2332f86 100644
--- a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
+++ b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
@@ -16,6 +16,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "materialize_homogeneous_encodings.mlir",
"smoketest_embedded.mlir",
"smoketest_system.mlir",
],
diff --git a/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt b/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
index dde5618..5eee1f4 100644
--- a/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
+++ b/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "materialize_homogeneous_encodings.mlir"
"smoketest_embedded.mlir"
"smoketest_system.mlir"
TOOLS
diff --git a/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir b/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir
new file mode 100644
index 0000000..5d5b591
--- /dev/null
+++ b/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir
@@ -0,0 +1,30 @@
+// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
+
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
+#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
+module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
+ util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
+ %1 = affine.apply #map()[%0#0, %dim]
+ %2 = affine.apply #map()[%0#1, %dim_0]
+ %padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<?x?xf32>
+ %3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
+ %4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
+ util.return %4 : tensor<?x?xf32>
+ }
+}
+// CHECK-LABEL: util.func public @lhs_encoding
+// CHECK: tensor.pack
+// CHECK: tensor.unpack
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 5f97a97..75e4bbd 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -226,7 +226,7 @@
targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets(
context, "rocm", configAttr, executableTargetAttrs);
- return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("rocm"),
+ return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"),
configAttr, executableTargetAttrs);
}
@@ -238,7 +238,7 @@
public:
ROCMTargetBackend(const ROCmOptions &options) : options(options) {}
- std::string getLegacyDefaultDeviceID() const override { return "rocm"; }
+ std::string getLegacyDefaultDeviceID() const override { return "hip"; }
void getDefaultExecutableTargets(
MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
@@ -702,8 +702,8 @@
: PluginSession<ROCMSession, ROCmOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
- // #hal.device.target<"rocm", ...
- targets.add("rocm",
+ // #hal.device.target<"hip", ...
+ targets.add("hip",
[&]() { return std::make_shared<ROCMTargetDevice>(options); });
}
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
diff --git a/compiler/plugins/target/ROCM/test/smoketest.mlir b/compiler/plugins/target/ROCM/test/smoketest.mlir
index b446ea2..1afe688 100644
--- a/compiler/plugins/target/ROCM/test/smoketest.mlir
+++ b/compiler/plugins/target/ROCM/test/smoketest.mlir
@@ -2,7 +2,7 @@
module attributes {
hal.device.targets = [
- #hal.device.target<"rocm", [
+ #hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
@@ -46,7 +46,7 @@
#loc = loc(unknown)
module attributes {
hal.device.targets = [
- #hal.device.target<"rocm", [
+ #hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index ae7676e..15240f9 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -1,7 +1,7 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
diff --git a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
index 32f2485..b839443 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
+++ b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
@@ -16,6 +16,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "materialize_homogeneous_encodings.mlir",
"smoketest.mlir",
],
include = ["*.mlir"],
diff --git a/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt b/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
index ec5576e..300499d 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
+++ b/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "materialize_homogeneous_encodings.mlir"
"smoketest.mlir"
TOOLS
FileCheck
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
similarity index 69%
rename from compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
rename to compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
index b11839b..037cda0 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
+++ b/compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
@@ -1,36 +1,5 @@
// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
-#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
-#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
-module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
- util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %cst = arith.constant 0.000000e+00 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
- %1 = affine.apply #map()[%0#0, %dim]
- %2 = affine.apply #map()[%0#1, %dim_0]
- %padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
- ^bb0(%arg1: index, %arg2: index):
- tensor.yield %cst : f32
- } : tensor<?x?xf32> to tensor<?x?xf32>
- %3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
- %4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
- util.return %4 : tensor<?x?xf32>
- }
-}
-// CHECK-LABEL: util.func public @lhs_encoding
-// CHECK: tensor.pack
-// CHECK: tensor.unpack
-
-// -----
-
#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb">
#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
index ac3c166..8351c91 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
@@ -565,9 +565,10 @@
if (auto affinityOp =
dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
result.getDefiningOp())) {
- auto &opPVS = solver.getElementFor<OpAffinityPVS>(
- *this, Position::forOperation(result.getOwner()),
- DFX::Resolution::OPTIONAL);
+ auto &opPVS = solver.getOrCreateElementFor<OpAffinityPVS>(
+ Position::forOperation(result.getOwner()), *this,
+ DFX::Resolution::OPTIONAL, /*forceUpdate=*/false,
+ /*updateAfterInit=*/false);
LLVM_DEBUG({
llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
value.printAsOperand(llvm::dbgs(), solver.getAsmState());
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
index bbb7867..027c626 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
@@ -29,7 +29,6 @@
"global_loop_invariant_code_motion.mlir",
"hoist_into_globals.mlir",
"infer_numeric_narrowing.mlir",
- "materialize_homogeneous_encodings.mlir",
"optimize_numerics.mlir",
"propagate_linalg_transpose.mlir",
"raise_special_ops.mlir",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
index 79c75b3..b6823fc 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
@@ -27,7 +27,6 @@
"global_loop_invariant_code_motion.mlir"
"hoist_into_globals.mlir"
"infer_numeric_narrowing.mlir"
- "materialize_homogeneous_encodings.mlir"
"optimize_numerics.mlir"
"propagate_linalg_transpose.mlir"
"raise_special_ops.mlir"
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
index 2725322..c08b434 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
@@ -113,7 +113,7 @@
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",
]
###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
index f328211..2e5b189 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
@@ -97,7 +97,7 @@
"--iree-codegen-llvmgpu-use-vector-distribution",
"--iree-rocm-waves-per-eu=2",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
]
###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
index 6d9ab66..881d93d 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
@@ -68,7 +68,7 @@
"--iree-flow-enable-aggressive-fusion=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
]
###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
index 41b2e61..207ddaf 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
@@ -99,7 +99,7 @@
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
index 4e1bc70..9d7f942 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
@@ -103,7 +103,7 @@
"--iree-codegen-llvmgpu-use-vector-distribution",
"--iree-rocm-waves-per-eu=2",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
index 49e49d3..5b9ab15 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
@@ -68,7 +68,7 @@
"--iree-flow-enable-aggressive-fusion=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir
index 7c5012e..c2a310e 100644
--- a/runtime/src/iree/modules/check/test/success.mlir
+++ b/runtime/src/iree/modules/check/test/success.mlir
@@ -73,7 +73,6 @@
%p8 = arith.addf %p7, %cp1 : tensor<f32>
%p9 = arith.addf %p8, %cp1 : tensor<f32>
%approximately_1 = arith.addf %p9, %cp1 : tensor<f32>
-
check.expect_almost_eq(%approximately_1, %c1) : tensor<f32>
return
}
diff --git a/samples/simple_embedding/device_vmvx_sync.c b/samples/simple_embedding/device_vmvx_sync.c
index fa5981c..f1f633f 100644
--- a/samples/simple_embedding/device_vmvx_sync.c
+++ b/samples/simple_embedding/device_vmvx_sync.c
@@ -34,7 +34,7 @@
iree_vm_instance_release(instance);
// Use the default host allocator for buffer allocations.
- iree_string_view_t identifier = iree_make_cstring_view("vmvx");
+ iree_string_view_t identifier = iree_make_cstring_view("local-sync");
iree_hal_allocator_t* device_allocator = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_allocator_create_heap(identifier, host_allocator,
diff --git a/samples/static_library/static_library_demo.c b/samples/static_library/static_library_demo.c
index 76a0b6c..e8670c5 100644
--- a/samples/static_library/static_library_demo.c
+++ b/samples/static_library/static_library_demo.c
@@ -42,7 +42,7 @@
&library_loader);
// Use the default host allocator for buffer allocations.
- iree_string_view_t identifier = iree_make_cstring_view("sync");
+ iree_string_view_t identifier = iree_make_cstring_view("local-sync");
iree_hal_allocator_t* device_allocator = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_allocator_create_heap(identifier, host_allocator,
diff --git a/tools/testing/e2e/iree-e2e-conv2d-test.cc b/tools/testing/e2e/iree-e2e-conv2d-test.cc
index 31d02e9..c4158fd 100644
--- a/tools/testing/e2e/iree-e2e-conv2d-test.cc
+++ b/tools/testing/e2e/iree-e2e-conv2d-test.cc
@@ -549,14 +549,17 @@
return EXIT_FAILURE;
}
+ // Run the tests. Note that some modules may be compiled for other platforms
+ // and not have the required architectures for execution within them - to keep
+ // the test runner dumber we gracefully fail those cases by returning success.
iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
iree_allocator_system(), conv2d_test_module_create);
int exit_code = EXIT_SUCCESS;
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
- bool is_unavailable = iree_status_is_unavailable(status);
+ bool is_device_unavailable = iree_status_is_not_found(status);
iree_status_free(status);
- exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
+ exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
}
IREE_TRACE_APP_EXIT(exit_code);
diff --git a/tools/testing/e2e/iree-e2e-matmul-test.cc b/tools/testing/e2e/iree-e2e-matmul-test.cc
index c9c82f9..f2773f0 100644
--- a/tools/testing/e2e/iree-e2e-matmul-test.cc
+++ b/tools/testing/e2e/iree-e2e-matmul-test.cc
@@ -725,14 +725,17 @@
return EXIT_FAILURE;
}
+ // Run the tests. Note that some modules may be compiled for other platforms
+ // and not have the required architectures for execution within them - to keep
+ // the test runner dumber we gracefully fail those cases by returning success.
iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
iree_allocator_system(), matmul_test_module_create);
int exit_code = EXIT_SUCCESS;
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
- bool is_unavailable = iree_status_is_unavailable(status);
+ bool is_device_unavailable = iree_status_is_not_found(status);
iree_status_free(status);
- exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
+ exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
}
IREE_TRACE_APP_EXIT(exit_code);
diff --git a/tools/testing/e2e/test_utils.c b/tools/testing/e2e/test_utils.c
index 926b0ea..2981148 100644
--- a/tools/testing/e2e/test_utils.c
+++ b/tools/testing/e2e/test_utils.c
@@ -413,7 +413,7 @@
return iree_make_status(
// The error status matters. We distinguish "feature not supported"
// which is a normal thing to happen from actual errors.
- IREE_STATUS_UNAVAILABLE,
+ IREE_STATUS_NOT_FOUND,
"target device does not have the required feature '%.*s'",
(int)required_feature.size, required_feature.data);
}
diff --git a/tools/testing/e2e/test_utils.h b/tools/testing/e2e/test_utils.h
index f3a18d2..f095537 100644
--- a/tools/testing/e2e/test_utils.h
+++ b/tools/testing/e2e/test_utils.h
@@ -133,7 +133,7 @@
iree_allocator_t host_allocator);
// Returns OK if there are declared requirements on |module| and they are all
-// met and otherwise UNAVAILABLE indicating that the module should not be run.
+// met and otherwise NOT_FOUND indicating that the module should not be run.
iree_status_t iree_test_utils_check_module_requirements(
iree_vm_module_t* module);