Updating various tests to the latest changes.
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
index 2939218..8f51c61 100644
--- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
@@ -56,6 +56,7 @@
       TorchInput::createConvertTMTensorToLinalgExtPass());
   pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTensorPass());
   pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
+  pm.addNestedPass<func::FuncOp>(createCSEPass());
   pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
   pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToArithPass());
   pm.addPass(torch::createConvertTorchConversionToMLProgramPass());
diff --git a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
index 14b13f9..2332f86 100644
--- a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
+++ b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
@@ -16,6 +16,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "materialize_homogeneous_encodings.mlir",
             "smoketest_embedded.mlir",
             "smoketest_system.mlir",
         ],
diff --git a/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt b/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
index dde5618..5eee1f4 100644
--- a/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
+++ b/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "materialize_homogeneous_encodings.mlir"
     "smoketest_embedded.mlir"
     "smoketest_system.mlir"
   TOOLS
diff --git a/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir b/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir
new file mode 100644
index 0000000..5d5b591
--- /dev/null
+++ b/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir
@@ -0,0 +1,30 @@
+// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
+
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
+#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
+module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
+  util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+    %0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
+    %1 = affine.apply #map()[%0#0, %dim]
+    %2 = affine.apply #map()[%0#1, %dim_0]
+    %padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %cst : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+    %3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
+    %4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
+    util.return %4 : tensor<?x?xf32>
+  }
+}
+// CHECK-LABEL: util.func public @lhs_encoding
+// CHECK:         tensor.pack
+// CHECK:         tensor.unpack
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 5f97a97..75e4bbd 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -226,7 +226,7 @@
     targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets(
         context, "rocm", configAttr, executableTargetAttrs);
 
-    return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("rocm"),
+    return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"),
                                             configAttr, executableTargetAttrs);
   }
 
@@ -238,7 +238,7 @@
 public:
   ROCMTargetBackend(const ROCmOptions &options) : options(options) {}
 
-  std::string getLegacyDefaultDeviceID() const override { return "rocm"; }
+  std::string getLegacyDefaultDeviceID() const override { return "hip"; }
 
   void getDefaultExecutableTargets(
       MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
@@ -702,8 +702,8 @@
     : PluginSession<ROCMSession, ROCmOptions,
                     PluginActivationPolicy::DefaultActivated> {
   void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
-    // #hal.device.target<"rocm", ...
-    targets.add("rocm",
+    // #hal.device.target<"hip", ...
+    targets.add("hip",
                 [&]() { return std::make_shared<ROCMTargetDevice>(options); });
   }
   void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
diff --git a/compiler/plugins/target/ROCM/test/smoketest.mlir b/compiler/plugins/target/ROCM/test/smoketest.mlir
index b446ea2..1afe688 100644
--- a/compiler/plugins/target/ROCM/test/smoketest.mlir
+++ b/compiler/plugins/target/ROCM/test/smoketest.mlir
@@ -2,7 +2,7 @@
 
 module attributes {
   hal.device.targets = [
-    #hal.device.target<"rocm", [
+    #hal.device.target<"hip", [
       #hal.executable.target<"rocm", "rocm-hsaco-fb">
     ]> : !hal.device
   ]
@@ -46,7 +46,7 @@
 #loc = loc(unknown)
 module attributes {
   hal.device.targets = [
-    #hal.device.target<"rocm", [
+    #hal.device.target<"hip", [
       #hal.executable.target<"rocm", "rocm-hsaco-fb">
     ]> : !hal.device
   ]
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index ae7676e..15240f9 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -1,7 +1,7 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941
 
 // GFX942: target = #iree_gpu.target<arch = "gfx942",
 // GFX942-SAME: wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8,
diff --git a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
index 32f2485..b839443 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
+++ b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
@@ -16,6 +16,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "materialize_homogeneous_encodings.mlir",
             "smoketest.mlir",
         ],
         include = ["*.mlir"],
diff --git a/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt b/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
index ec5576e..300499d 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
+++ b/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "materialize_homogeneous_encodings.mlir"
     "smoketest.mlir"
   TOOLS
     FileCheck
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
similarity index 69%
rename from compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
rename to compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
index b11839b..037cda0 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
+++ b/compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
@@ -1,36 +1,5 @@
 // RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
 
-#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
-#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
-module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
-  util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
-    %cst = arith.constant 0.000000e+00 : f32
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-    %0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
-    %1 = affine.apply #map()[%0#0, %dim]
-    %2 = affine.apply #map()[%0#1, %dim_0]
-    %padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
-    ^bb0(%arg1: index, %arg2: index):
-      tensor.yield %cst : f32
-    } : tensor<?x?xf32> to tensor<?x?xf32>
-    %3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
-    %4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
-    util.return %4 : tensor<?x?xf32>
-  }
-}
-// CHECK-LABEL: util.func public @lhs_encoding
-// CHECK:         tensor.pack
-// CHECK:         tensor.unpack
-
-// -----
-
 #executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb">
 #map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
index ac3c166..8351c91 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
@@ -565,9 +565,10 @@
         if (auto affinityOp =
                 dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
                     result.getDefiningOp())) {
-          auto &opPVS = solver.getElementFor<OpAffinityPVS>(
-              *this, Position::forOperation(result.getOwner()),
-              DFX::Resolution::OPTIONAL);
+          auto &opPVS = solver.getOrCreateElementFor<OpAffinityPVS>(
+              Position::forOperation(result.getOwner()), *this,
+              DFX::Resolution::OPTIONAL, /*forceUpdate=*/false,
+              /*updateAfterInit=*/false);
           LLVM_DEBUG({
             llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
             value.printAsOperand(llvm::dbgs(), solver.getAsmState());
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
index bbb7867..027c626 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
@@ -29,7 +29,6 @@
             "global_loop_invariant_code_motion.mlir",
             "hoist_into_globals.mlir",
             "infer_numeric_narrowing.mlir",
-            "materialize_homogeneous_encodings.mlir",
             "optimize_numerics.mlir",
             "propagate_linalg_transpose.mlir",
             "raise_special_ops.mlir",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
index 79c75b3..b6823fc 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
@@ -27,7 +27,6 @@
     "global_loop_invariant_code_motion.mlir"
     "hoist_into_globals.mlir"
     "infer_numeric_narrowing.mlir"
-    "materialize_homogeneous_encodings.mlir"
     "optimize_numerics.mlir"
     "propagate_linalg_transpose.mlir"
     "raise_special_ops.mlir"
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
index 2725322..c08b434 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
@@ -113,7 +113,7 @@
     "--iree-opt-aggressively-propagate-transposes=true",
     "--iree-codegen-llvmgpu-use-vector-distribution=true",
     "--iree-execution-model=async-external",
-    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",
 ]
 
 ###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
index f328211..2e5b189 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
@@ -97,7 +97,7 @@
     "--iree-codegen-llvmgpu-use-vector-distribution",
     "--iree-rocm-waves-per-eu=2",
     "--iree-execution-model=async-external",
-    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
 ]
 
 ###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
index 6d9ab66..881d93d 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
@@ -68,7 +68,7 @@
     "--iree-flow-enable-aggressive-fusion=true",
     "--iree-codegen-llvmgpu-use-vector-distribution=true",
     "--iree-execution-model=async-external",
-    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
 ]
 
 ###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
index 41b2e61..207ddaf 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
@@ -99,7 +99,7 @@
     "--iree-opt-aggressively-propagate-transposes=true",
     "--iree-codegen-llvmgpu-use-vector-distribution=true",
     "--iree-execution-model=async-external",
-    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",
     "--iree-scheduling-dump-statistics-format=json",
     "--iree-scheduling-dump-statistics-file=compilation_info.json",
 ]
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
index 4e1bc70..9d7f942 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
@@ -103,7 +103,7 @@
     "--iree-codegen-llvmgpu-use-vector-distribution",
     "--iree-rocm-waves-per-eu=2",
     "--iree-execution-model=async-external",
-    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
     "--iree-scheduling-dump-statistics-format=json",
     "--iree-scheduling-dump-statistics-file=compilation_info.json",
 ]
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
index 49e49d3..5b9ab15 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
@@ -68,7 +68,7 @@
     "--iree-flow-enable-aggressive-fusion=true",
     "--iree-codegen-llvmgpu-use-vector-distribution=true",
     "--iree-execution-model=async-external",
-    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
     "--iree-scheduling-dump-statistics-format=json",
     "--iree-scheduling-dump-statistics-file=compilation_info.json",
 ]
diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir
index 7c5012e..c2a310e 100644
--- a/runtime/src/iree/modules/check/test/success.mlir
+++ b/runtime/src/iree/modules/check/test/success.mlir
@@ -73,7 +73,6 @@
   %p8 = arith.addf %p7, %cp1 : tensor<f32>
   %p9 = arith.addf %p8, %cp1 : tensor<f32>
   %approximately_1 = arith.addf %p9, %cp1 : tensor<f32>
-
   check.expect_almost_eq(%approximately_1, %c1) : tensor<f32>
   return
 }
diff --git a/samples/simple_embedding/device_vmvx_sync.c b/samples/simple_embedding/device_vmvx_sync.c
index fa5981c..f1f633f 100644
--- a/samples/simple_embedding/device_vmvx_sync.c
+++ b/samples/simple_embedding/device_vmvx_sync.c
@@ -34,7 +34,7 @@
   iree_vm_instance_release(instance);
 
   // Use the default host allocator for buffer allocations.
-  iree_string_view_t identifier = iree_make_cstring_view("vmvx");
+  iree_string_view_t identifier = iree_make_cstring_view("local-sync");
   iree_hal_allocator_t* device_allocator = NULL;
   if (iree_status_is_ok(status)) {
     status = iree_hal_allocator_create_heap(identifier, host_allocator,
diff --git a/samples/static_library/static_library_demo.c b/samples/static_library/static_library_demo.c
index 76a0b6c..e8670c5 100644
--- a/samples/static_library/static_library_demo.c
+++ b/samples/static_library/static_library_demo.c
@@ -42,7 +42,7 @@
       &library_loader);
 
   // Use the default host allocator for buffer allocations.
-  iree_string_view_t identifier = iree_make_cstring_view("sync");
+  iree_string_view_t identifier = iree_make_cstring_view("local-sync");
   iree_hal_allocator_t* device_allocator = NULL;
   if (iree_status_is_ok(status)) {
     status = iree_hal_allocator_create_heap(identifier, host_allocator,
diff --git a/tools/testing/e2e/iree-e2e-conv2d-test.cc b/tools/testing/e2e/iree-e2e-conv2d-test.cc
index 31d02e9..c4158fd 100644
--- a/tools/testing/e2e/iree-e2e-conv2d-test.cc
+++ b/tools/testing/e2e/iree-e2e-conv2d-test.cc
@@ -549,14 +549,17 @@
     return EXIT_FAILURE;
   }
 
+  // Run the tests. Note that some modules may be compiled for other platforms
+  // and not have the required architectures for execution within them - to keep
+  // the test runner dumber we gracefully fail those cases by returning success.
   iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
       iree_allocator_system(), conv2d_test_module_create);
   int exit_code = EXIT_SUCCESS;
   if (!iree_status_is_ok(status)) {
     iree_status_fprint(stderr, status);
-    bool is_unavailable = iree_status_is_unavailable(status);
+    bool is_device_unavailable = iree_status_is_not_found(status);
     iree_status_free(status);
-    exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
+    exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
   }
 
   IREE_TRACE_APP_EXIT(exit_code);
diff --git a/tools/testing/e2e/iree-e2e-matmul-test.cc b/tools/testing/e2e/iree-e2e-matmul-test.cc
index c9c82f9..f2773f0 100644
--- a/tools/testing/e2e/iree-e2e-matmul-test.cc
+++ b/tools/testing/e2e/iree-e2e-matmul-test.cc
@@ -725,14 +725,17 @@
     return EXIT_FAILURE;
   }
 
+  // Run the tests. Note that some modules may be compiled for other platforms
+  // and not have the required architectures for execution within them - to keep
+  // the test runner dumber we gracefully fail those cases by returning success.
   iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
       iree_allocator_system(), matmul_test_module_create);
   int exit_code = EXIT_SUCCESS;
   if (!iree_status_is_ok(status)) {
     iree_status_fprint(stderr, status);
-    bool is_unavailable = iree_status_is_unavailable(status);
+    bool is_device_unavailable = iree_status_is_not_found(status);
     iree_status_free(status);
-    exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
+    exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
   }
 
   IREE_TRACE_APP_EXIT(exit_code);
diff --git a/tools/testing/e2e/test_utils.c b/tools/testing/e2e/test_utils.c
index 926b0ea..2981148 100644
--- a/tools/testing/e2e/test_utils.c
+++ b/tools/testing/e2e/test_utils.c
@@ -413,7 +413,7 @@
       return iree_make_status(
           // The error status matters. We distinguish "feature not supported"
           // which is a normal thing to happen from actual errors.
-          IREE_STATUS_UNAVAILABLE,
+          IREE_STATUS_NOT_FOUND,
           "target device does not have the required feature '%.*s'",
           (int)required_feature.size, required_feature.data);
     }
diff --git a/tools/testing/e2e/test_utils.h b/tools/testing/e2e/test_utils.h
index f3a18d2..f095537 100644
--- a/tools/testing/e2e/test_utils.h
+++ b/tools/testing/e2e/test_utils.h
@@ -133,7 +133,7 @@
     iree_allocator_t host_allocator);
 
 // Returns OK if there are declared requirements on |module| and they are all
-// met and otherwise UNAVAILABLE indicating that the module should not be run.
+// met and otherwise NOT_FOUND indicating that the module should not be run.
 iree_status_t iree_test_utils_check_module_requirements(
     iree_vm_module_t* module);