[DT] Unify encoding materialization pass into a single pass. (#19454)

The revision creates a generic materialization pass and uses it for
backends that implement data-tiling. After months of development, we
identify that the needs of GPU is a superset of the needs of CPU. To be
more specific, it has the additional "swizzle" field in terms of layout.
It means that the GPU set_encoding/unset_encoding lowering patterns
cover the needs of CPU path. The lowering of contraction ops is
different. CPU lowers it to mmt4d op, while GPU lowers it to multi_mma
op. However, the lowering of contraction is implemented through
attribute interface. Thus, we can have a generic pattern to lower
contraction ops.

To make the review process much easier, the revision is created by 5
commits.

1. It directly creates the MaterializeEncoding pass and copy-paste the
GPU patterns: SetEncodingOpLoweringConversion,
UnSetEncodingOpLoweringConversion, and MaterializeContractionOp. In the
first commit, it also updates the GPU tests to use the new pass.
2. The GPU data-tiling does not support element-wise generic op lowering
atm. The second commit moves the pattern to shared pattern set and bail
out when swizzle is present. This is an NFC for both pipelines.
3. The third commit replaces the existing materialization pass with the
generic pass, and deletes all the legacy passes.
4. The four commit moves the lit tests from `Common/[CPU|GPU]/test` to
`Common/test`.
5. Now there are duplicate patterns for set_encoding, unset_encoding,
and contraction ops lowering. The last commit deletes the legacy
patterns, and move the patterns from MaterializeEncoding.cpp to where
the legacy patterns locate. Furthermore, it renames the file as
`MaterializeEncodingPatterns.cpp`.

The revision retains the MaterializeEncodingIntoNop pass, and add a TODO
item. Because it is still used by MaterializeHomogeneousEncoding pass.
It can be deleted once we deprecate the early materialization path.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index e3513ba..f95b0fa 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -125,8 +125,9 @@
         "LinkTuningSpecsPass.cpp",
         "LowerExecutableUsingTransformDialect.cpp",
         "LowerUKernelsToCalls.cpp",
+        "MaterializeEncoding.cpp",
         "MaterializeEncodingIntoNop.cpp",
-        "MaterializeEncodingIntoPackUnPack.cpp",
+        "MaterializeEncodingPatterns.cpp",
         "MaterializeTuningSpecsPass.cpp",
         "MemrefCopyToLinalg.cpp",
         "NormalizeLoopBounds.cpp",
@@ -173,8 +174,10 @@
         ":PassHeaders",
         ":PassesIncGen",
         "//compiler/src/iree/compiler/Codegen/Common:FoldTensorExtractOpIncGen",
+        "//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
         "//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
         "//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
@@ -183,9 +186,11 @@
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Dialect/Encoding/IR",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
+        "//compiler/src/iree/compiler/Dialect/HAL/Analysis",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
+        "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
         "//compiler/src/iree/compiler/Dialect/Util/Analysis",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "//compiler/src/iree/compiler/Utils",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index adec8aa..af3c557 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -117,8 +117,9 @@
     "LinkTuningSpecsPass.cpp"
     "LowerExecutableUsingTransformDialect.cpp"
     "LowerUKernelsToCalls.cpp"
+    "MaterializeEncoding.cpp"
     "MaterializeEncodingIntoNop.cpp"
-    "MaterializeEncodingIntoPackUnPack.cpp"
+    "MaterializeEncodingPatterns.cpp"
     "MaterializeTuningSpecsPass.cpp"
     "MemrefCopyToLinalg.cpp"
     "NormalizeLoopBounds.cpp"
@@ -203,8 +204,10 @@
     MLIRVectorTransforms
     MLIRViewLikeInterface
     iree::compiler::Codegen::Common::FoldTensorExtractOpIncGen
+    iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
     iree::compiler::Codegen::Dialect::Codegen::Utils
+    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
     iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
     iree::compiler::Codegen::Interfaces::BufferizationInterfaces
     iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
@@ -213,9 +216,11 @@
     iree::compiler::Codegen::Utils
     iree::compiler::Dialect::Encoding::IR
     iree::compiler::Dialect::Flow::IR
+    iree::compiler::Dialect::HAL::Analysis
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::LinalgExt::Transforms
+    iree::compiler::Dialect::Stream::Analysis
     iree::compiler::Dialect::Util::Analysis
     iree::compiler::Dialect::Util::IR
     iree::compiler::Utils
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
index f1053da..05fb9be 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -45,7 +45,6 @@
     name = "CommonCPUPasses",
     srcs = [
         "CPULowerToUKernels.cpp",
-        "CPUMaterializeEncodings.cpp",
         "CPUPrepareUkernels.cpp",
         "Passes.cpp",
     ],
@@ -56,16 +55,13 @@
         ":PassHeaders",
         ":PassesIncGen",
         "//compiler/src/iree/compiler/Codegen/Common",
-        "//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils",
         "//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Dialect/Encoding/IR",
-        "//compiler/src/iree/compiler/Dialect/HAL/Analysis",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
-        "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
         "//runtime/src/iree/builtins/ukernel:exported_bits",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
index 75db95e..419c4b0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
@@ -42,7 +42,6 @@
     "Passes.h"
   SRCS
     "CPULowerToUKernels.cpp"
-    "CPUMaterializeEncodings.cpp"
     "CPUPrepareUkernels.cpp"
     "Passes.cpp"
   DEPS
@@ -78,16 +77,13 @@
     MLIRVectorTransforms
     iree::builtins::ukernel::exported_bits
     iree::compiler::Codegen::Common
-    iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
     iree::compiler::Codegen::Dialect::Codegen::Utils
     iree::compiler::Codegen::Interfaces::UKernelOpInterface
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
     iree::compiler::Dialect::Encoding::IR
-    iree::compiler::Dialect::HAL::Analysis
     iree::compiler::Dialect::HAL::IR
-    iree::compiler::Dialect::Stream::Analysis
   PUBLIC
 )
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td
index 8c73c5b..394de54 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td
@@ -13,26 +13,6 @@
 // Common Passes used for CPU-like backends (keep alphabetical)
 //===---------------------------------------------------------------------===//
 
-def CPUMaterializeHostEncodingPass :
-    Pass<"iree-codegen-cpu-materialize-host-encoding", "mlir::ModuleOp"> {
-  let summary = "Convert encoding-specific operations based on target attributes.";
-  let description = [{
-    Examples:
-      encoding.set_encoding   -> tensor.pack
-      encoding.unset_encoding -> tensor.unpack
-      linalg.matmul             -> linalg.mmt4d  "}];
-}
-
-def CPUMaterializeDeviceEncodingPass :
-    InterfacePass<"iree-codegen-cpu-materialize-device-encoding", "mlir::FunctionOpInterface"> {
-  let summary = "Convert encoding-specific operations based on target attributes.";
-  let description = [{
-    Examples:
-      encoding.set_encoding   -> tensor.pack
-      encoding.unset_encoding -> tensor.unpack
-      linalg.matmul             -> linalg.mmt4d  "}];
-}
-
 def CPULowerToUKernelsPass :
     Pass<"iree-codegen-cpu-lower-to-ukernels", ""> {
   let summary =
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/test/BUILD.bazel
index b2d6b91..fe5caa3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/BUILD.bazel
@@ -19,10 +19,8 @@
     srcs = enforce_glob(
         # keep sorted
         [
-            "llvmcpu_materialize_encoding.mlir",
             "lower_to_ukernel_ops.mlir",
             "prepare_ukernels.mlir",
-            "vmvx_materialize_encoding.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/test/CMakeLists.txt
index 3dd9de7..100058f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/CMakeLists.txt
@@ -14,10 +14,8 @@
   NAME
     lit
   SRCS
-    "llvmcpu_materialize_encoding.mlir"
     "lower_to_ukernel_ops.mlir"
     "prepare_ukernels.mlir"
-    "vmvx_materialize_encoding.mlir"
   TOOLS
     FileCheck
     iree-opt
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index da9aa8a..bf188b6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -93,17 +93,9 @@
     Value packedValue, const MaterializeEncodingTypeConverter &typeConverter,
     MaterializeEncodingValueFn materializeEncodingValueFn);
 
-/// Pouplates the set of patterns that lowers set_encoding, unset_encoding, and
-/// upstream dialect ops with encoding types to pack/unpack ops.
-void populateMaterializeEncodingIntoPackUnPackPatterns(
-    RewritePatternSet &patterns,
-    MaterializeEncodingTypeConverter &typeConverter,
-    MaterializeEncodingValueFn materializeEncodingValueFn);
-
-/// Pouplates the set of patterns that lowers shape-like operations (e.g., Flow
-/// ops, Hal ops, tensor.empty, linalg.fill, etc) with encoding types to the
-/// same op with materialized shapes.
-void populateShapeIndependentMaterializeEncodingPatterns(
+/// Pouplates the set of patterns that lowers operations with encoding types to
+/// operations without encodings.
+void populateMaterializeEncodingPatterns(
     RewritePatternSet &patterns, MaterializeEncodingConversionTarget &target,
     MaterializeEncodingTypeConverter &typeConverter,
     MaterializeEncodingValueFn materializeEncodingValueFn);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 128ffa9..6617777 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -65,7 +65,6 @@
         "GPUGreedilyDistributeToThreads.cpp",
         "GPUInferMemorySpace.cpp",
         "GPULowerToUKernels.cpp",
-        "GPUMaterializeEncoding.cpp",
         "GPUMultiBuffering.cpp",
         "GPUNestedLayoutDistributionPatterns.cpp",
         "GPUPackToIntrinsics.cpp",
@@ -107,10 +106,7 @@
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
-        "//compiler/src/iree/compiler/Dialect/Encoding/IR",
-        "//compiler/src/iree/compiler/Dialect/HAL/Analysis",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
-        "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
         "//compiler/src/iree/compiler/Utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AMDGPUDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 97d3240..2f065df 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -63,7 +63,6 @@
     "GPUGreedilyDistributeToThreads.cpp"
     "GPUInferMemorySpace.cpp"
     "GPULowerToUKernels.cpp"
-    "GPUMaterializeEncoding.cpp"
     "GPUMultiBuffering.cpp"
     "GPUNestedLayoutDistributionPatterns.cpp"
     "GPUPackToIntrinsics.cpp"
@@ -140,10 +139,7 @@
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
     iree::compiler::Codegen::Utils::VectorOpUtils
-    iree::compiler::Dialect::Encoding::IR
-    iree::compiler::Dialect::HAL::Analysis
     iree::compiler::Dialect::HAL::IR
-    iree::compiler::Dialect::Stream::Analysis
     iree::compiler::Utils
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
deleted file mode 100644
index 3253608..0000000
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ /dev/null
@@ -1,398 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/Common/EncodingUtils.h"
-#include "iree/compiler/Codegen/Common/GPU/Passes.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
-#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
-#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
-#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "iree-codegen-gpu-materialize-encoding"
-
-namespace mlir::iree_compiler {
-
-#define GEN_PASS_DEF_GPUMATERIALIZEDEVICEENCODINGPASS
-#define GEN_PASS_DEF_GPUMATERIALIZEHOSTENCODINGPASS
-#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
-
-using IREE::Codegen::MaterializeEncodingInfo;
-using IREE::Codegen::TileSwizzle;
-
-namespace {
-
-// TODO(hanchung): Delete this pass and rely on tensor-based analysis to
-// materialize encodings based on where tensors are used. This pass is not able
-// to handle that.
-struct GPUMaterializeHostEncodingPass
-    : public impl::GPUMaterializeHostEncodingPassBase<
-          GPUMaterializeHostEncodingPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, tensor::TensorDialect,
-                    linalg::LinalgDialect, IREE::Encoding::IREEEncodingDialect,
-                    IREE::GPU::IREEGPUDialect>();
-  }
-
-  void runOnOperation() override;
-};
-
-struct GPUMaterializeDeviceEncodingPass final
-    : impl::GPUMaterializeDeviceEncodingPassBase<
-          GPUMaterializeDeviceEncodingPass> {
-  using GPUMaterializeDeviceEncodingPassBase::
-      GPUMaterializeDeviceEncodingPassBase;
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, tensor::TensorDialect,
-                    linalg::LinalgDialect, IREE::Encoding::IREEEncodingDialect,
-                    IREE::GPU::IREEGPUDialect>();
-  }
-  void runOnOperation() override;
-};
-
-SmallVector<ReassociationIndices>
-getReassociationIndices(int outerDims,
-                        const TileSwizzle::ExpandShapeType &expandShape) {
-  SmallVector<ReassociationIndices> result;
-  int expandedIdx = 0;
-  for (int i = 0; i < outerDims; ++i) {
-    result.push_back({expandedIdx++});
-  }
-  for (auto expandShapeDim : expandShape) {
-    result.push_back({});
-    for (int i = 0, e = expandShapeDim.size(); i < e; ++i) {
-      result.back().push_back(expandedIdx++);
-    }
-  }
-  return result;
-}
-
-/// Convert iree_linalg_ext.set_encoding op to pack + tile swizzling ops. We use
-/// expand_shape + linalg.transpose to represent a tile swizzling op.
-struct GPUSetEncodingOpLoweringConversion
-    : public OpMaterializeEncodingPattern<IREE::Encoding::SetEncodingOp> {
-  using OpMaterializeEncodingPattern<
-      IREE::Encoding::SetEncodingOp>::OpMaterializeEncodingPattern;
-
-  LogicalResult
-  matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
-        getTypeConverter());
-    auto packedValue = lowerSetEncodingOpToPackOp(
-        rewriter, encodingOp, adaptor.getSource(), *converter,
-        this->materializeEncodingValueFn);
-    if (failed(packedValue)) {
-      Type targetType =
-          getTypeConverter()->convertType(encodingOp.getResultType());
-      Value result = rewriter.createOrFold<tensor::CastOp>(
-          encodingOp.getLoc(), targetType, adaptor.getSource());
-      rewriter.replaceOp(encodingOp, result);
-      return success();
-    }
-
-    MaterializeEncodingInfo encodingInfo =
-        converter->getEncodingInfo(encodingOp.getResultType());
-    if (!encodingInfo.swizzle) {
-      rewriter.replaceOp(encodingOp, packedValue.value());
-      return success();
-    }
-
-    Location loc = encodingOp.getLoc();
-
-    // Create expand_shape op to tile the innermost two dimensions.
-    int origRank = encodingOp.getSourceType().getRank();
-    SmallVector<int64_t> expandShapeShape(
-        cast<ShapedType>(packedValue->getType())
-            .getShape()
-            .take_front(origRank));
-    expandShapeShape.append(
-        getExpandedTileShape(encodingInfo.swizzle->expandShape));
-    RankedTensorType expandShapeType =
-        encodingOp.getSourceType().clone(expandShapeShape);
-
-    SmallVector<ReassociationIndices> reassociation =
-        getReassociationIndices(origRank, encodingInfo.swizzle->expandShape);
-    auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandShapeType, packedValue.value(), reassociation);
-
-    SmallVector<int64_t> transposePerm =
-        llvm::to_vector(llvm::seq<int64_t>(0, origRank));
-    for (auto perm : encodingInfo.swizzle->permutation) {
-      transposePerm.push_back(origRank + perm);
-    }
-    SmallVector<OpFoldResult> transposeResultDims =
-        tensor::getMixedSizes(rewriter, loc, expandShapeOp.getResult());
-    applyPermutationToVector(transposeResultDims, transposePerm);
-
-    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, transposeResultDims, encodingOp.getSourceType().getElementType());
-    auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, expandShapeOp, emptyTensor, transposePerm);
-    rewriter.replaceOp(encodingOp, transposeOp->getResult(0));
-
-    return success();
-  }
-};
-
-struct GPUUnsetEncodingOpLoweringConversion
-    : public OpMaterializeEncodingPattern<IREE::Encoding::UnsetEncodingOp> {
-  using OpMaterializeEncodingPattern<
-      IREE::Encoding::UnsetEncodingOp>::OpMaterializeEncodingPattern;
-
-  LogicalResult
-  matchAndRewrite(IREE::Encoding::UnsetEncodingOp unsetEncodingOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
-        getTypeConverter());
-
-    MaterializeEncodingInfo encodingInfo =
-        converter->getEncodingInfo(unsetEncodingOp.getSource().getType());
-    if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
-      Type targetType =
-          getTypeConverter()->convertType(unsetEncodingOp.getSourceType());
-      Value result = rewriter.createOrFold<tensor::CastOp>(
-          unsetEncodingOp.getLoc(), targetType, adaptor.getSource());
-      rewriter.replaceOp(unsetEncodingOp, result);
-      return success();
-    }
-
-    Location loc = unsetEncodingOp.getLoc();
-    Value unpackSrc = adaptor.getSource();
-    if (encodingInfo.swizzle) {
-      int targetRank = unsetEncodingOp.getResultType().getRank();
-      auto srcConvertedType =
-          cast<RankedTensorType>(adaptor.getSource().getType());
-      SmallVector<OpFoldResult> emptyShape =
-          tensor::getMixedSizes(rewriter, loc, adaptor.getSource());
-      emptyShape.resize(targetRank);
-      for (auto i : getExpandedTileShape(encodingInfo.swizzle->expandShape)) {
-        emptyShape.push_back(rewriter.getIndexAttr(i));
-      }
-      auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, emptyShape, unsetEncodingOp.getSourceType().getElementType());
-
-      SmallVector<int64_t> transposePerm =
-          llvm::to_vector(llvm::seq<int64_t>(0, targetRank));
-      for (auto perm : encodingInfo.swizzle->permutation) {
-        transposePerm.push_back(targetRank + perm);
-      }
-      auto invertedTransposePerm = invertPermutationVector(transposePerm);
-      auto transposeOp = rewriter.create<linalg::TransposeOp>(
-          loc, adaptor.getSource(), emptyTensor, invertedTransposePerm);
-
-      SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
-          targetRank, encodingInfo.swizzle->expandShape);
-      SmallVector<int64_t> unpackSrcShape(
-          srcConvertedType.getShape().take_front(targetRank));
-      unpackSrcShape.append(encodingInfo.innerTileSizes.begin(),
-                            encodingInfo.innerTileSizes.end());
-      RankedTensorType unpackSrcType =
-          unsetEncodingOp.getResultType().clone(unpackSrcShape);
-      unpackSrc = rewriter.create<tensor::CollapseShapeOp>(
-          loc, unpackSrcType, transposeOp->getResult(0), reassociation);
-    }
-
-    auto unpackedValue = lowerUnsetEncodingToUnpackOp(
-        rewriter, unsetEncodingOp, unpackSrc, *converter,
-        this->materializeEncodingValueFn);
-    if (failed(unpackedValue)) {
-      Type targetType =
-          getTypeConverter()->convertType(unsetEncodingOp.getResultType());
-      Value result = rewriter.createOrFold<tensor::CastOp>(loc, targetType,
-                                                           adaptor.getSource());
-      rewriter.replaceOp(unsetEncodingOp, result);
-      return success();
-    }
-    rewriter.replaceOp(unsetEncodingOp, unpackedValue.value());
-    return success();
-  }
-};
-
-class GPUConvertToMultiMma final
-    : public OpInterfaceConversionPattern<linalg::ContractionOpInterface> {
-public:
-  using OpInterfaceConversionPattern<
-      linalg::ContractionOpInterface>::OpInterfaceConversionPattern;
-
-  GPUConvertToMultiMma(
-      MLIRContext *context,
-      const MaterializeEncodingTypeConverter &typeConverter,
-      MaterializeEncodingValueFn materializeEncodingValueFn = {},
-      PatternBenefit benefit = 1)
-      : OpInterfaceConversionPattern<mlir::linalg::ContractionOpInterface>(
-            typeConverter, context, benefit),
-        materializeEncodingValueFn(materializeEncodingValueFn) {}
-
-  LogicalResult
-  matchAndRewrite(linalg::ContractionOpInterface op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
-        this->getTypeConverter());
-    auto layoutAttr = converter->getLayoutAttr();
-    assert(layoutAttr && "layoutAttr is not set, which is not expected. Are "
-                         "you adding new arch support?");
-    SmallVector<Type> convertedResTypes;
-    auto linalgOp = cast<linalg::LinalgOp>(op.getOperation());
-    for (auto init : linalgOp.getDpsInits()) {
-      convertedResTypes.push_back(converter->convertType(init.getType()));
-    }
-    Operation *newOp =
-        layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-
-protected:
-  const MaterializeEncodingValueFn materializeEncodingValueFn;
-};
-
-static LogicalResult
-materializeFuncOpEncodings(FunctionOpInterface funcOp,
-                           IREE::HAL::ExecutableTargetAttr targetAttr) {
-  MLIRContext *ctx = funcOp.getContext();
-  {
-    RewritePatternSet patterns(ctx);
-    IREE::GPU::TargetAttr gpuTargetAttr;
-    if (targetAttr) {
-      gpuTargetAttr = getGPUTargetAttr(targetAttr);
-    } else {
-      gpuTargetAttr = getCLGPUTarget(ctx);
-    }
-    MaterializeEncodingTypeConverter typeConverter(
-        cast<IREE::Codegen::LayoutAttrInterface>(
-            IREE::GPU::GPUEncodingLayoutAttr::get(ctx, gpuTargetAttr)));
-    MaterializeEncodingConversionTarget target(*ctx);
-    MaterializeEncodingValueFn materializeEncodingValueFn =
-        [](RankedTensorType, OpBuilder,
-           Location) -> FailureOr<MaterializeEncodingValueInfo> { return {}; };
-    populateShapeIndependentMaterializeEncodingPatterns(
-        patterns, target, typeConverter, materializeEncodingValueFn);
-
-    patterns.insert<GPUSetEncodingOpLoweringConversion,
-                    GPUUnsetEncodingOpLoweringConversion, GPUConvertToMultiMma>(
-        ctx, typeConverter, materializeEncodingValueFn);
-
-    memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
-    if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
-      funcOp.emitOpError("materialization failed");
-      return failure();
-    }
-  }
-
-  // Add patterns to fold pack/unpack ops with pad/extract_slice ops and
-  // resolve dims ops.
-  {
-    RewritePatternSet patterns(ctx);
-    tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
-    tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
-    memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
-    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-      funcOp.emitOpError("folding patterns failed");
-      return failure();
-    }
-  }
-
-  return success();
-}
-
-static std::optional<SetVector<IREE::HAL::ExecutableTargetAttr>>
-getFuncExecutableTargetAttrs(FunctionOpInterface funcOp,
-                             IREE::Stream::AffinityAnalysis &affinityAnalysis,
-                             IREE::HAL::DeviceAnalysis &deviceAnalysis) {
-  // Get a set of all unique affinities used by resources within the function.
-  SetVector<IREE::Stream::AffinityAttr> uniqueAffinityAttrs;
-  SmallVector<IREE::Stream::AffinityAttr> lookupAffinityAttrs;
-  funcOp.walk([&](Operation *op) {
-    if (affinityAnalysis.tryLookupExecutionAffinity(op, lookupAffinityAttrs)) {
-      uniqueAffinityAttrs.insert(lookupAffinityAttrs.begin(),
-                                 lookupAffinityAttrs.end());
-    }
-    lookupAffinityAttrs.clear();
-  });
-
-  // Resolve affinities to executable targets.
-  SetVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
-  for (auto affinityAttr : uniqueAffinityAttrs) {
-    deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, funcOp,
-                                                   executableTargetAttrs);
-  }
-  return executableTargetAttrs;
-}
-
-} // namespace
-
-void GPUMaterializeHostEncodingPass::runOnOperation() {
-  auto moduleOp = getOperation();
-
-  // Run required analysis passes.
-  IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
-  if (failed(affinityAnalysis.run())) {
-    return signalPassFailure();
-  }
-  IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
-  if (failed(deviceAnalysis.run())) {
-    return signalPassFailure();
-  }
-
-  for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
-    // Gather the required executable targets for the function. Note that it's
-    // possible there are more required for ops nested within the function but
-    // this pass is a hack and can't handle that :shrug:.
-    auto executableTargets =
-        getFuncExecutableTargetAttrs(funcOp, affinityAnalysis, deviceAnalysis);
-    if (!executableTargets) {
-      funcOp.emitOpError()
-          << "could not determine executable targets for the function";
-      return signalPassFailure();
-    } else if (executableTargets->empty()) {
-      // Probably no tensors.
-      continue;
-    }
-
-    // HACK: this pass is run on the host _but shouldn't be_. Because it's
-    // run on the host and IREE is a compiler capable of multi-targeting there
-    // may be multiple executable targets at any point in the host program.
-    // This pass can't handle that and assumes it's been checked earlier by
-    // spooky action at a distance. This needs to be fixed.
-    if (executableTargets->size() != 1) {
-      funcOp.emitOpError() << "has multiple executable targets and CPU data "
-                              "tiling isn't built to support that";
-      return signalPassFailure();
-    }
-
-    // Materialize encodings within the function.
-    if (failed(
-            materializeFuncOpEncodings(funcOp, executableTargets->front()))) {
-      return signalPassFailure();
-    }
-  }
-}
-
-void GPUMaterializeDeviceEncodingPass::runOnOperation() {
-  FunctionOpInterface funcOp = getOperation();
-  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
-  if (failed(materializeFuncOpEncodings(funcOp, targetAttr))) {
-    return signalPassFailure();
-  }
-}
-
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 2c25e02..ff2b2b9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -247,16 +247,6 @@
   ];
 }
 
-def GPUMaterializeHostEncodingPass :
-    Pass<"iree-codegen-gpu-materialize-host-encoding", "mlir::ModuleOp"> {
-  let summary = "Materialize the encoding for tensor as specified by the backend.";
-}
-
-def GPUMaterializeDeviceEncodingPass :
-    InterfacePass<"iree-codegen-gpu-materialize-device-encoding", "mlir::FunctionOpInterface"> {
-  let summary = "Materialize the encoding for tensor as specified by the backend.";
-}
-
 def GPUTensorTileToSerialLoopsPass :
     InterfacePass<"iree-codegen-gpu-tensor-tile-to-serial-loops", "mlir::FunctionOpInterface"> {
   let summary = "Pass to tile reduction dimensions for certain GPU ops";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 030e6f4..2f3b092 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -32,10 +32,6 @@
             "gpu_infer_memory_space.mlir",
             "gpu_combine_value_barriers.mlir",
             "gpu_lower_to_ukernels.mlir",
-            "gpu_materialize_encoding_gfx908.mlir",
-            "gpu_materialize_encoding_gfx90a.mlir",
-            "gpu_materialize_encoding_gfx942.mlir",
-            "gpu_materialize_encoding_gfx1100.mlir",
             "gpu_nested_layout_contract_amdgpu.mlir",
             "gpu_nested_layout_vector_distribution.mlir",
             "gpu_nested_layout_vector_distribution_step.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index 6d1f540..50be391 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -27,10 +27,6 @@
     "gpu_greedily_distribute_to_threads.mlir"
     "gpu_infer_memory_space.mlir"
     "gpu_lower_to_ukernels.mlir"
-    "gpu_materialize_encoding_gfx1100.mlir"
-    "gpu_materialize_encoding_gfx908.mlir"
-    "gpu_materialize_encoding_gfx90a.mlir"
-    "gpu_materialize_encoding_gfx942.mlir"
     "gpu_nested_layout_contract_amdgpu.mlir"
     "gpu_nested_layout_vector_distribution.mlir"
     "gpu_nested_layout_vector_distribution_step.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp
similarity index 63%
rename from compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
rename to compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp
index d182649..f1776b9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp
@@ -1,47 +1,46 @@
-// Copyright 2023 The IREE Authors
+// Copyright 2024 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Codegen/Common/CPU/Passes.h"
 #include "iree/compiler/Codegen/Common/EncodingUtils.h"
+#include "iree/compiler/Codegen/Common/PassUtils.h"
+#include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
 #include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
 #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/MathExtras.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
 
-#define DEBUG_TYPE "cpu-materialize-encoding"
+#define DEBUG_TYPE "iree-codegen--materialize-encoding"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 namespace mlir::iree_compiler {
 
-using IREE::Codegen::MaterializeEncodingInfo;
-using IREE::Codegen::TileMxNxK;
+#define GEN_PASS_DEF_MATERIALIZEDEVICEENCODINGPASS
+#define GEN_PASS_DEF_MATERIALIZEHOSTENCODINGPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
 
-#define GEN_PASS_DEF_CPUMATERIALIZEDEVICEENCODINGPASS
-#define GEN_PASS_DEF_CPUMATERIALIZEHOSTENCODINGPASS
-#include "iree/compiler/Codegen/Common/CPU/Passes.h.inc"
+using namespace IREE::Encoding;
+
+namespace {
 
 static FailureOr<MaterializeEncodingValueInfo>
 chooseDynamicEncodingInfoVMVXMicrokernels(RankedTensorType tensorType,
@@ -64,33 +63,46 @@
 
 static LogicalResult
 materializeFuncOpEncodings(FunctionOpInterface funcOp,
-                           IREE::HAL::ExecutableTargetAttr targetAttr) {
+                           IREE::HAL::ExecutableTargetAttr targetAttr,
+                           bool testCLGPUTarget = false) {
   MLIRContext *ctx = funcOp.getContext();
-  RewritePatternSet materializeEncodingPattern(ctx);
-  DictionaryAttr targetConfig = targetAttr.getConfiguration();
-  IREE::Codegen::LayoutAttrInterface layoutAttr;
-  if (isVMVXBackend(targetAttr)) {
-    LDBG("Select VMVXEncodingLayoutAttr attribute as the layout attribute.");
-    layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
-        IREE::CPU::VMVXEncodingLayoutAttr::get(ctx, targetConfig));
-  } else {
-    LDBG("Select CPUEncodingLayoutAttr attribute as the layout attribute.");
-    layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
-        IREE::CPU::CPUEncodingLayoutAttr::get(ctx, targetConfig));
-  }
-  MaterializeEncodingTypeConverter typeConverter(layoutAttr);
-  MaterializeEncodingConversionTarget target(*ctx);
-  auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr);
-  populateMaterializeEncodingIntoPackUnPackPatterns(
-      materializeEncodingPattern, typeConverter, materializeEncodingValueFn);
-  populateShapeIndependentMaterializeEncodingPatterns(
-      materializeEncodingPattern, target, typeConverter,
-      materializeEncodingValueFn);
+  {
+    RewritePatternSet patterns(ctx);
+    IREE::Codegen::LayoutAttrInterface layoutAttr;
+    if (isVMVXBackend(targetAttr)) {
+      LDBG("Select VMVXEncodingLayoutAttr attribute as the layout attribute.");
+      layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
+          IREE::CPU::VMVXEncodingLayoutAttr::get(
+              ctx, targetAttr.getConfiguration()));
+    } else if (isLLVMCPUBackend(targetAttr)) {
+      LDBG("Select CPUEncodingLayoutAttr attribute as the layout attribute.");
+      layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
+          IREE::CPU::CPUEncodingLayoutAttr::get(ctx,
+                                                targetAttr.getConfiguration()));
+    } else if (isROCMBackend(targetAttr)) {
+      LDBG("Select GPUEncodingLayoutAttr attribute as the layout attribute.");
+      layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
+          IREE::GPU::GPUEncodingLayoutAttr::get(ctx,
+                                                getGPUTargetAttr(targetAttr)));
+    } else if (testCLGPUTarget) {
+      LDBG("Select GPUEncodingLayoutAttr attribute as the layout attribute. "
+           "(testCLGPUTarget)");
+      layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
+          IREE::GPU::GPUEncodingLayoutAttr::get(ctx, getCLGPUTarget(ctx)));
+    } else {
+      LDBG("Select EncodingNopLayoutAttr attribute as the layout attribute.");
+      layoutAttr = IREE::Codegen::EncodingNopLayoutAttr::get(ctx);
+    }
+    MaterializeEncodingTypeConverter typeConverter(layoutAttr);
+    MaterializeEncodingConversionTarget target(*ctx);
+    auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr);
+    populateMaterializeEncodingPatterns(patterns, target, typeConverter,
+                                        materializeEncodingValueFn);
 
-  if (failed(applyPartialConversion(funcOp, target,
-                                    std::move(materializeEncodingPattern)))) {
-    funcOp.emitOpError("materialization failed");
-    return failure();
+    if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
+      funcOp.emitOpError("materialization failed");
+      return failure();
+    }
   }
 
   // Add patterns to fold pack/unpack ops with pad/extract_slice ops and
@@ -138,13 +150,13 @@
   return executableTargetAttrs;
 }
 
-struct CPUMaterializeHostEncodingPass
-    : public impl::CPUMaterializeHostEncodingPassBase<
-          CPUMaterializeHostEncodingPass> {
+struct MaterializeHostEncodingPass
+    : public impl::MaterializeHostEncodingPassBase<
+          MaterializeHostEncodingPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<arith::ArithDialect, tensor::TensorDialect,
-                IREE::Codegen::IREECodegenDialect, IREE::CPU::IREECPUDialect>();
+    registry.insert<arith::ArithDialect, tensor::TensorDialect,
+                    IREE::Codegen::IREECodegenDialect,
+                    IREE::CPU::IREECPUDialect, IREE::GPU::IREEGPUDialect>();
   }
 
   void runOnOperation() override {
@@ -199,22 +211,27 @@
 // that. It should _not_ be running on both - target-specific codegen passes
 // are not allowed on host programs and it's a big violation of layering that
 // this exists.
-struct CPUMaterializeDeviceEncodingPass
-    : public impl::CPUMaterializeDeviceEncodingPassBase<
-          CPUMaterializeDeviceEncodingPass> {
+struct MaterializeDeviceEncodingPass
+    : public impl::MaterializeDeviceEncodingPassBase<
+          MaterializeDeviceEncodingPass> {
+  using impl::MaterializeDeviceEncodingPassBase<
+      MaterializeDeviceEncodingPass>::MaterializeDeviceEncodingPassBase;
+
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<arith::ArithDialect, tensor::TensorDialect,
-                IREE::Codegen::IREECodegenDialect, IREE::CPU::IREECPUDialect>();
+    registry.insert<arith::ArithDialect, tensor::TensorDialect,
+                    IREE::Codegen::IREECodegenDialect,
+                    IREE::CPU::IREECPUDialect, IREE::GPU::IREEGPUDialect>();
   }
 
   void runOnOperation() override {
     auto funcOp = getOperation();
     auto executableTargetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
-    if (failed(materializeFuncOpEncodings(funcOp, executableTargetAttr))) {
+    if (failed(materializeFuncOpEncodings(funcOp, executableTargetAttr,
+                                          testCLGPUTarget))) {
       return signalPassFailure();
     }
   }
 };
+} // namespace
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp
index 4de4b45..d93cb98 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp
@@ -48,11 +48,9 @@
     MaterializeEncodingTypeConverter typeConverter(
         IREE::Codegen::EncodingNopLayoutAttr::get(context));
     MaterializeEncodingConversionTarget target(*context);
-    populateMaterializeEncodingIntoPackUnPackPatterns(
-        materializeEncodingPattern, typeConverter, materializeEncodingValueFn);
-    populateShapeIndependentMaterializeEncodingPatterns(
-        materializeEncodingPattern, target, typeConverter,
-        materializeEncodingValueFn);
+    populateMaterializeEncodingPatterns(materializeEncodingPattern, target,
+                                        typeConverter,
+                                        materializeEncodingValueFn);
 
     if (failed(applyPartialConversion(operation, target,
                                       std::move(materializeEncodingPattern)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp
similarity index 84%
rename from compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
rename to compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp
index 087d91d..cd3d27e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp
@@ -32,6 +32,7 @@
 namespace mlir::iree_compiler {
 
 using IREE::Codegen::MaterializeEncodingInfo;
+using IREE::Codegen::TileSwizzle;
 
 //===---------------------------------------------------------------------===//
 // Utility methods
@@ -237,6 +238,10 @@
     return rewriter.notifyMatchFailure(
         genericOp, "MaterializeEncodingInfo failed for output");
   }
+  if (outMaterializeEncodingInfo.swizzle) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "generic op lowering does not support swizzle yet");
+  }
 
   auto convertedResultType =
       cast<RankedTensorType>(convertedOutputOperands[0].getType());
@@ -561,60 +566,6 @@
 // the core conversion utilities.
 //===---------------------------------------------------------------------===//
 
-/// Convert `set_encoding` op to `pack` op.
-struct SetEncodingOpToPackOpConversion
-    : public OpMaterializeEncodingPattern<IREE::Encoding::SetEncodingOp> {
-  using OpMaterializeEncodingPattern<
-      IREE::Encoding::SetEncodingOp>::OpMaterializeEncodingPattern;
-
-  LogicalResult
-  matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
-        getTypeConverter());
-    auto packOp = lowerSetEncodingOpToPackOp(rewriter, encodingOp,
-                                             adaptor.getSource(), *converter,
-                                             this->materializeEncodingValueFn);
-    if (failed(packOp)) {
-      Type targetType =
-          getTypeConverter()->convertType(encodingOp.getResultType());
-      Value result = rewriter.createOrFold<tensor::CastOp>(
-          encodingOp.getLoc(), targetType, adaptor.getSource());
-      rewriter.replaceOp(encodingOp, result);
-      return success();
-    }
-    rewriter.replaceOp(encodingOp, packOp.value());
-    return success();
-  }
-};
-
-/// Convert `unset_encoding` op to `unpack` op.
-struct UnsetEncodingOpToUnPackOpConversion
-    : public OpMaterializeEncodingPattern<IREE::Encoding::UnsetEncodingOp> {
-  using OpMaterializeEncodingPattern<
-      IREE::Encoding::UnsetEncodingOp>::OpMaterializeEncodingPattern;
-
-  LogicalResult
-  matchAndRewrite(IREE::Encoding::UnsetEncodingOp encodingOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
-        this->getTypeConverter());
-    auto unpackedValue = lowerUnsetEncodingToUnpackOp(
-        rewriter, encodingOp, adaptor.getSource(), *converter,
-        this->materializeEncodingValueFn);
-    if (failed(unpackedValue)) {
-      Type targetType =
-          getTypeConverter()->convertType(encodingOp.getResultType());
-      Value result = rewriter.createOrFold<tensor::CastOp>(
-          encodingOp.getLoc(), targetType, adaptor.getSource());
-      rewriter.replaceOp(encodingOp, result);
-      return success();
-    }
-    rewriter.replaceOp(encodingOp, unpackedValue.value());
-    return success();
-  }
-};
-
 /// Generic pattern to convert operation that is in Destination Passing Style.
 template <typename OpTy>
 struct MaterializeDPSOperation : public OpMaterializeEncodingPattern<OpTy> {
@@ -685,6 +636,166 @@
   }
 };
 
+static SmallVector<ReassociationIndices>
+getReassociationIndices(int outerDims,
+                        const TileSwizzle::ExpandShapeType &expandShape) {
+  SmallVector<ReassociationIndices> result;
+  int expandedIdx = 0;
+  for (int i = 0; i < outerDims; ++i) {
+    result.push_back({expandedIdx++});
+  }
+  for (auto expandShapeDim : expandShape) {
+    result.push_back({});
+    for (int i = 0, e = expandShapeDim.size(); i < e; ++i) {
+      result.back().push_back(expandedIdx++);
+    }
+  }
+  return result;
+}
+
+/// Convert iree_linalg_ext.set_encoding op to pack + tile swizzling ops. We use
+/// expand_shape + linalg.transpose to represent a tile swizzling op.
+struct SetEncodingOpLoweringConversion
+    : public OpMaterializeEncodingPattern<IREE::Encoding::SetEncodingOp> {
+  using OpMaterializeEncodingPattern<
+      IREE::Encoding::SetEncodingOp>::OpMaterializeEncodingPattern;
+
+  LogicalResult
+  matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
+        getTypeConverter());
+    auto packedValue = lowerSetEncodingOpToPackOp(
+        rewriter, encodingOp, adaptor.getSource(), *converter,
+        this->materializeEncodingValueFn);
+    if (failed(packedValue)) {
+      Type targetType =
+          getTypeConverter()->convertType(encodingOp.getResultType());
+      Value result = rewriter.createOrFold<tensor::CastOp>(
+          encodingOp.getLoc(), targetType, adaptor.getSource());
+      rewriter.replaceOp(encodingOp, result);
+      return success();
+    }
+
+    MaterializeEncodingInfo encodingInfo =
+        converter->getEncodingInfo(encodingOp.getResultType());
+    if (!encodingInfo.swizzle) {
+      rewriter.replaceOp(encodingOp, packedValue.value());
+      return success();
+    }
+
+    Location loc = encodingOp.getLoc();
+
+    // Create expand_shape op to tile the innermost two dimensions.
+    int origRank = encodingOp.getSourceType().getRank();
+    SmallVector<int64_t> expandShapeShape(
+        cast<ShapedType>(packedValue->getType())
+            .getShape()
+            .take_front(origRank));
+    expandShapeShape.append(
+        getExpandedTileShape(encodingInfo.swizzle->expandShape));
+    RankedTensorType expandShapeType =
+        encodingOp.getSourceType().clone(expandShapeShape);
+
+    SmallVector<ReassociationIndices> reassociation =
+        getReassociationIndices(origRank, encodingInfo.swizzle->expandShape);
+    auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, packedValue.value(), reassociation);
+
+    SmallVector<int64_t> transposePerm =
+        llvm::to_vector(llvm::seq<int64_t>(0, origRank));
+    for (auto perm : encodingInfo.swizzle->permutation) {
+      transposePerm.push_back(origRank + perm);
+    }
+    SmallVector<OpFoldResult> transposeResultDims =
+        tensor::getMixedSizes(rewriter, loc, expandShapeOp.getResult());
+    applyPermutationToVector(transposeResultDims, transposePerm);
+
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, transposeResultDims, encodingOp.getSourceType().getElementType());
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, expandShapeOp, emptyTensor, transposePerm);
+    rewriter.replaceOp(encodingOp, transposeOp->getResult(0));
+
+    return success();
+  }
+};
+
+struct UnsetEncodingOpLoweringConversion
+    : public OpMaterializeEncodingPattern<IREE::Encoding::UnsetEncodingOp> {
+  using OpMaterializeEncodingPattern<
+      IREE::Encoding::UnsetEncodingOp>::OpMaterializeEncodingPattern;
+
+  LogicalResult
+  matchAndRewrite(IREE::Encoding::UnsetEncodingOp unsetEncodingOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
+        getTypeConverter());
+
+    MaterializeEncodingInfo encodingInfo =
+        converter->getEncodingInfo(unsetEncodingOp.getSource().getType());
+    if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
+      Type targetType =
+          getTypeConverter()->convertType(unsetEncodingOp.getSourceType());
+      Value result = rewriter.createOrFold<tensor::CastOp>(
+          unsetEncodingOp.getLoc(), targetType, adaptor.getSource());
+      rewriter.replaceOp(unsetEncodingOp, result);
+      return success();
+    }
+
+    Location loc = unsetEncodingOp.getLoc();
+    Value unpackSrc = adaptor.getSource();
+    if (encodingInfo.swizzle) {
+      int targetRank = unsetEncodingOp.getResultType().getRank();
+      auto srcConvertedType =
+          cast<RankedTensorType>(adaptor.getSource().getType());
+      SmallVector<OpFoldResult> emptyShape =
+          tensor::getMixedSizes(rewriter, loc, adaptor.getSource());
+      emptyShape.resize(targetRank);
+      for (auto i : getExpandedTileShape(encodingInfo.swizzle->expandShape)) {
+        emptyShape.push_back(rewriter.getIndexAttr(i));
+      }
+      auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+          loc, emptyShape, unsetEncodingOp.getSourceType().getElementType());
+
+      SmallVector<int64_t> transposePerm =
+          llvm::to_vector(llvm::seq<int64_t>(0, targetRank));
+      for (auto perm : encodingInfo.swizzle->permutation) {
+        transposePerm.push_back(targetRank + perm);
+      }
+      auto invertedTransposePerm = invertPermutationVector(transposePerm);
+      auto transposeOp = rewriter.create<linalg::TransposeOp>(
+          loc, adaptor.getSource(), emptyTensor, invertedTransposePerm);
+
+      SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
+          targetRank, encodingInfo.swizzle->expandShape);
+      SmallVector<int64_t> unpackSrcShape(
+          srcConvertedType.getShape().take_front(targetRank));
+      unpackSrcShape.append(encodingInfo.innerTileSizes.begin(),
+                            encodingInfo.innerTileSizes.end());
+      RankedTensorType unpackSrcType =
+          unsetEncodingOp.getResultType().clone(unpackSrcShape);
+      unpackSrc = rewriter.create<tensor::CollapseShapeOp>(
+          loc, unpackSrcType, transposeOp->getResult(0), reassociation);
+    }
+
+    auto unpackedValue = lowerUnsetEncodingToUnpackOp(
+        rewriter, unsetEncodingOp, unpackSrc, *converter,
+        this->materializeEncodingValueFn);
+    if (failed(unpackedValue)) {
+      Type targetType =
+          getTypeConverter()->convertType(unsetEncodingOp.getResultType());
+      Value result = rewriter.createOrFold<tensor::CastOp>(loc, targetType,
+                                                           adaptor.getSource());
+      rewriter.replaceOp(unsetEncodingOp, result);
+      return success();
+    }
+    rewriter.replaceOp(unsetEncodingOp, unpackedValue.value());
+    return success();
+  }
+};
+
 /// Pattern to convert contraction operations.
 class MaterializeContractionOp
     : public OpInterfaceConversionPattern<linalg::LinalgOp> {
@@ -726,21 +837,7 @@
 
 } // namespace
 
-void populateMaterializeEncodingIntoPackUnPackPatterns(
-    RewritePatternSet &patterns,
-    MaterializeEncodingTypeConverter &typeConverter,
-    MaterializeEncodingValueFn materializeEncodingValueFn) {
-  MLIRContext *context = patterns.getContext();
-  // TODO(hanchung): Move the generic op pattern to ShapeIndependent category
-  // after we add the support for tile swizzling variants.
-  patterns.insert<MaterializeDPSOperation<linalg::GenericOp>,
-                  MaterializeContractionOp, SetEncodingOpToPackOpConversion,
-                  UnsetEncodingOpToUnPackOpConversion>(
-      context, typeConverter, materializeEncodingValueFn);
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
-}
-
-void populateShapeIndependentMaterializeEncodingPatterns(
+void populateMaterializeEncodingPatterns(
     RewritePatternSet &patterns, MaterializeEncodingConversionTarget &target,
     MaterializeEncodingTypeConverter &typeConverter,
     MaterializeEncodingValueFn materializeEncodingValueFn) {
@@ -767,7 +864,10 @@
       });
 
   patterns.insert<
+      MaterializeContractionOp, SetEncodingOpLoweringConversion,
+      UnsetEncodingOpLoweringConversion,
       MaterializeDPSOperation<linalg::FillOp>,
+      MaterializeDPSOperation<linalg::GenericOp>,
       MaterializeOperation<tensor::EmptyOp>, MaterializeOptimizationBarrierOp,
       MaterializeFlowDispatchTensorLoadOp, MaterializeFlowDispatchTensorStoreOp,
       MaterializeInterfaceBindingEncoding>(context, typeConverter,
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 5571aba..5cc0d55 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -431,6 +431,21 @@
   let summary = "Lower micro-kernel wrapper ops into function calls";
 }
 
+def MaterializeHostEncodingPass :
+    Pass<"iree-codegen-materialize-host-encoding", "mlir::ModuleOp"> {
+  let summary = "Materialize the encoding for tensor as specified by the backend.";
+}
+
+def MaterializeDeviceEncodingPass :
+    InterfacePass<"iree-codegen-materialize-device-encoding", "mlir::FunctionOpInterface"> {
+  let summary = "Materialize the encoding for tensor as specified by the backend.";
+  let options = [
+    Option<"testCLGPUTarget", "test-cl-gpu-target", "bool", /*default=*/"false",
+           "Flag used for lit-testing GPU target only. Not for general usage">,
+  ];
+}
+
+// TODO(hanchung): Remove the pass after we deprecate MaterializeHomogeneousEncodingsPass.
 def MaterializeEncodingIntoNopPass :
     InterfacePass<"iree-codegen-materialize-encoding-into-nop", "mlir::FunctionOpInterface"> {
   let summary = "Drop the encodings from tensor types with encodings.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 5de2e3d..f0652d2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -47,12 +47,17 @@
             "fold_tensor_extract_op.mlir",
             "forop_canonicalization.mlir",
             "generic_vectorization.mlir",
+            "gpu_materialize_encoding_gfx1100.mlir",
+            "gpu_materialize_encoding_gfx908.mlir",
+            "gpu_materialize_encoding_gfx90a.mlir",
+            "gpu_materialize_encoding_gfx942.mlir",
             "hoist_statically_bound_allocations.mlir",
             "hoist_unrolled_vector_extract_insert_slice.mlir",
             "iree_comprehensive_bufferize.mlir",
             "iree_expand_strided_metadata.mlir",
             "iree_loop_invariant_code_motion.mlir",
             "link_tuning_specs.mlir",
+            "llvmcpu_materialize_encoding.mlir",
             "lower_ukernel_to_calls.mlir",
             "materialize_encoding_into_nop.mlir",
             "materialize_tuning_specs.mlir",
@@ -74,8 +79,8 @@
             "replace_slow_min_max_ops.mlir",
             "strip_compilation_info.mlir",
             "test_partitionable_loops_interface.mlir",
-            "tile_and_distribute_to_workgroups_func_scope.mlir",
             "tile_and_distribute_to_workgroups.mlir",
+            "tile_and_distribute_to_workgroups_func_scope.mlir",
             "tile_and_distribute_workgroups_using_forall.mlir",
             "tile_large_tensors.mlir",
             "transform_buffer_opt.mlir",
@@ -88,10 +93,11 @@
             "type_propagation.mlir",
             "type_propagation_packing.mlir",
             "unroll_annotated_loops.mlir",
+            "vector_layout_analysis.mlir",
             "vectorize_memref_copy.mlir",
             "vectorize_tensor_pad.mlir",
-            "vector_layout_analysis.mlir",
             "verify_workgroup_distribution.mlir",
+            "vmvx_materialize_encoding.mlir",
         ],
         include = ["*.mlir"],
         exclude = [
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 4dc774c..2d707f6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -43,12 +43,17 @@
     "fold_tensor_extract_op.mlir"
     "forop_canonicalization.mlir"
     "generic_vectorization.mlir"
+    "gpu_materialize_encoding_gfx1100.mlir"
+    "gpu_materialize_encoding_gfx908.mlir"
+    "gpu_materialize_encoding_gfx90a.mlir"
+    "gpu_materialize_encoding_gfx942.mlir"
     "hoist_statically_bound_allocations.mlir"
     "hoist_unrolled_vector_extract_insert_slice.mlir"
     "iree_comprehensive_bufferize.mlir"
     "iree_expand_strided_metadata.mlir"
     "iree_loop_invariant_code_motion.mlir"
     "link_tuning_specs.mlir"
+    "llvmcpu_materialize_encoding.mlir"
     "lower_ukernel_to_calls.mlir"
     "materialize_encoding_into_nop.mlir"
     "materialize_tuning_specs.mlir"
@@ -88,6 +93,7 @@
     "vectorize_memref_copy.mlir"
     "vectorize_tensor_pad.mlir"
     "verify_workgroup_distribution.mlir"
+    "vmvx_materialize_encoding.mlir"
   TOOLS
     FileCheck
     iree-opt
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx1100.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir
rename to compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx1100.mlir
index bb0c610..645fd71 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx1100.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-materialize-device-encoding{test-cl-gpu-target}))" \
 // RUN:   --iree-gpu-test-target=gfx1100 \
 // RUN:   --split-input-file %s | FileCheck %s
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx908.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx908.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx908.mlir
rename to compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx908.mlir
index 4fca563..a9fc2bc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx908.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx908.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-materialize-device-encoding{test-cl-gpu-target}))" \
 // RUN:   --iree-gpu-test-target=gfx908 \
 // RUN:   --split-input-file %s | FileCheck %s
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx90a.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx90a.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx90a.mlir
rename to compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx90a.mlir
index cc9cd9d..89fe357 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx90a.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx90a.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-materialize-device-encoding{test-cl-gpu-target}))" \
 // RUN:   --iree-gpu-test-target=gfx90a \
 // RUN:   --split-input-file %s | FileCheck %s
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx942.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx942.mlir
rename to compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir
index 3338de9..2544fc1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-materialize-device-encoding{test-cl-gpu-target}))" \
 // RUN:   --iree-gpu-test-target=gfx942 \
 // RUN:   --split-input-file %s | FileCheck %s
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/test/llvmcpu_materialize_encoding.mlir
similarity index 96%
rename from compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
rename to compiler/src/iree/compiler/Codegen/Common/test/llvmcpu_materialize_encoding.mlir
index 553c134..25b69a7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/llvmcpu_materialize_encoding.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-device-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-materialize-device-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
@@ -6,7 +6,7 @@
 ]>
 #encoding = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [bf16, bf16, bf16], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 1, 16, 16>>
 func.func @set_encoding_with_padding_semantics_bf16_x86_64_avx512f() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 }{
   %c0 = arith.constant 0 : index
   %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x1000xbf16>>
@@ -44,7 +44,7 @@
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 #encoding = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @set_encoding_7x7x7_matmul_LHS() attributes {
-   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
+   hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
 } {
   %c0 = arith.constant 0 : index
   %8 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<7x7xf32>>
@@ -74,7 +74,7 @@
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 #encoding = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @set_encoding_128x80x32_batch_matmul_LHS() attributes {
-   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
+   hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
 } {
   %c0 = arith.constant 0 : index
   %8 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x80x32xf32>>
@@ -105,7 +105,7 @@
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 #encoding = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @set_encoding_128x32x320_batch_matmul_RHS() attributes {
-   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
+   hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
 } {
   %c0 = arith.constant 0 : index
   %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
@@ -138,7 +138,7 @@
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 #encoding = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @unset_encoding_128x80x320_batch_matmul_RESULT() attributes {
-   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
+   hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
 } {
   %c0 = arith.constant 0 : index
   %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
@@ -176,7 +176,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @pack_gemm_fill_dynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> attributes {
-   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
+   hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx,+avx2,+fma"}>
 } {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -224,7 +224,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 1, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 1, 16>>
 func.func @matvec_shaped_matmul_lowering_f32f32f32_aarch64(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz"}>
 } {
   %c0 = arith.constant 0 : index
   %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<16x16xf32>
@@ -257,7 +257,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f32f32f32_aarch64() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz", ukernels = "all"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz", ukernels = "all"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -323,7 +323,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], round_dims_to = array<i64: 16, 1, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], round_dims_to = array<i64: 16, 1, 16>>
 func.func @matvec_lowering_f32f32f32_aarch64(%arg0: tensor<16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> tensor<16xf32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz"}>
 } {
   %c0 = arith.constant 0 : index
   %3 = iree_encoding.set_encoding %arg0 : tensor<16x16xf32> -> tensor<16x16xf32, #encoding_lhs>
@@ -352,7 +352,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 1, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 1, 16>>
 func.func @matvec_lowering_f32f32f32_aarch64() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz"}>
 } {
   %c0 = arith.constant 0 : index
   %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
@@ -414,7 +414,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f16, f16, f16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f16f16f16_aarch64() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz", ukernels = "all"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz", ukernels = "all"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -485,7 +485,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f32f32f32_x86_64() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -557,7 +557,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f32f32f32_x86_64_avx2() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -628,7 +628,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f32f32f32_x86_64_avx512f() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -699,7 +699,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f16f16f32_x86_64_avx512f() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -770,7 +770,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f16, f16, f16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f16f16f16_x86_64_avx512f() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -841,7 +841,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_bf16bf16f32_x86_64_avx512f() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -912,7 +912,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, bf16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, bf16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_bf16bf16bf16_x86_64_avx512f() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -983,7 +983,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_bf16bf16f32_x86_64_avx512bf16() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f,+avx512bf16"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f,+avx512bf16"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1056,7 +1056,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, bf16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, bf16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_bf16bf16bf16_x86_64_avx512bf16() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f,+avx512bf16"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f,+avx512bf16"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1129,7 +1129,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f16, f16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f16, f16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f32f16f16_aarch64() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz", ukernels = "all"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz", ukernels = "all"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1202,7 +1202,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f16, f16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f16, f16], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f32f16f16_x86_64_avx512f() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f,+avx512bf16"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f,+avx512bf16"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1276,7 +1276,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i8i32_aarch64() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1344,7 +1344,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i8i32_aarch64_dotprod() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz", cpu_features="+dotprod", ukernels = "all"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz", cpu_features="+dotprod", ukernels = "all"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1417,7 +1417,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i8i32_aarch64_i8mm() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz", cpu_features="+dotprod,+i8mm", ukernels = "all"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz", cpu_features="+dotprod,+i8mm", ukernels = "all"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1489,7 +1489,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i4i32_aarch64() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1563,7 +1563,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i4i32_aarch64_dotprod() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz", cpu_features="+dotprod", ukernels = "all"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz", cpu_features="+dotprod", ukernels = "all"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1635,7 +1635,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i4i32_aarch64_i8mm() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz", cpu_features="+dotprod,+i8mm", ukernels = "all"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="aarch64-xyz-xyz", cpu_features="+dotprod,+i8mm", ukernels = "all"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1704,7 +1704,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f32f32f32_aarch64_sve(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {cpu_features = "+sve", target_triple="aarch64-xyz-xyz"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {cpu_features = "+sve", target_triple="aarch64-xyz-xyz"}>
 } {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -1736,7 +1736,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_f32f32f32_riscv(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="riscv32-xyz-xyz"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="riscv32-xyz-xyz"}>
 } {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -1772,7 +1772,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i8i32_riscv32_ukernel() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="riscv32-xyz-xyz", ukernels = "all"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="riscv32-xyz-xyz", ukernels = "all"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1845,7 +1845,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i8i32_x86_64_avx2() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx2"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx2"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1918,7 +1918,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i8i32_x86_64_avx512bw() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512bw"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512bw"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -1991,7 +1991,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i8i8i32_x86_64_avx512vnni() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -2059,7 +2059,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 1, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 1, 16, 16>>
 func.func @extend_batch_vecmat_explicit_unit_dim(%arg0: tensor<32x1x128xi8>, %arg1: tensor<32x128x11008xi8>) -> tensor<32x1x11008xi32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c0_i32 = arith.constant 0 : i32
   %4 = iree_encoding.set_encoding %arg0 : tensor<32x1x128xi8> -> tensor<32x1x128xi8, #encoding_lhs>
@@ -2122,7 +2122,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i16, i16, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i16, i16, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_lowering_i16i16i32_x86_64_avx2() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx2"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx2"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -2195,7 +2195,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i16, ui4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i16, ui4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
 func.func @matmul_lowering_i16ui4i32_x86_64_avx512vnni() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c0 = arith.constant 0 : index
   %M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
@@ -2263,7 +2263,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 1, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 1, 16, 16>>
 func.func @vecmat(%arg0: tensor<128xi8>, %arg1: tensor<128x11008xi8>) -> tensor<11008xi32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c0_i32 = arith.constant 0 : i32
   %4 = iree_encoding.set_encoding %arg0 : tensor<128xi8> -> tensor<128xi8, #encoding_lhs>
@@ -2325,7 +2325,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 1, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 1, 16>>
 func.func @matvec(%arg0: tensor<11008x128xi8>, %arg1: tensor<128xi8>) -> tensor<11008xi32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c0_i32 = arith.constant 0 : i32
   %4 = iree_encoding.set_encoding %arg0 : tensor<11008x128xi8> -> tensor<11008x128xi8, #encoding_lhs>
@@ -2387,7 +2387,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 1, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 1, 16>>
 func.func @matvec_with_narrow_M(%arg0: tensor<15x128xi8>, %arg1: tensor<128xi8>) -> tensor<15xi32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c0_i32 = arith.constant 0 : i32
   %4 = iree_encoding.set_encoding %arg0 : tensor<15x128xi8> -> tensor<15x128xi8, #encoding_lhs>
@@ -2450,7 +2450,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 1, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 1, 16, 16>>
 func.func @batch_vecmat(%arg0: tensor<32x128xi8>, %arg1: tensor<32x128x11008xi8>) -> tensor<32x11008xi32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c0_i32 = arith.constant 0 : i32
   %4 = iree_encoding.set_encoding %arg0 : tensor<32x128xi8> -> tensor<32x128xi8, #encoding_lhs>
@@ -2509,7 +2509,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 1, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 1, 16>>
 func.func @batch_matvec(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<32x11008x128xi8>
   %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<32x128xi8>
@@ -2535,7 +2535,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_transpose_a_f32f32f32(%arg0: tensor<256x128xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<128x512xf32>) -> tensor<128x512xf32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c256 = arith.constant 256 : index
   %c128 = arith.constant 128 : index
@@ -2574,7 +2574,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @matmul_transpose_b_f32f32f32(%arg0: tensor<128x256xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<128x512xf32>) -> tensor<128x512xf32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c128 = arith.constant 128 : index
   %c256 = arith.constant 256 : index
@@ -2612,7 +2612,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @batch_matmul_transpose_a_f32f32f32(%arg0: tensor<2x256x128xf32>, %arg1: tensor<2x256x512xf32>, %arg2: tensor<2x128x512xf32>) -> tensor<2x128x512xf32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c2 = arith.constant 2 : index
   %c256 = arith.constant 256 : index
@@ -2651,7 +2651,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 16, 16, 16>>
 func.func @batch_matmul_transpose_b_f32f32f32(%arg0: tensor<2x128x256xf32>, %arg1: tensor<2x512x256xf32>, %arg2: tensor<2x128x512xf32>) -> tensor<2x128x512xf32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c2 = arith.constant 2 : index
   %c128 = arith.constant 128 : index
@@ -2690,7 +2690,7 @@
 #encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i16, ui4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 1, 32, 32>>
 #encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i16, ui4, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 1, 32, 32>>
 func.func @generic_batch_vecmat_transposed_i16u4i32(%arg0: tensor<32x128xi16>, %arg1: tensor<4096x32x128xi4>, %arg2: tensor<4096x32xi32>) -> tensor<4096x32xi32> attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
   %c0_i32 = arith.constant 0 : i32
   %c0_i4 = arith.constant 0 : i4
@@ -2747,7 +2747,7 @@
 #encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 #encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 func.func @dequantization() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
@@ -2802,7 +2802,7 @@
 #encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 #encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 func.func @broadcast_batch() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
@@ -2841,7 +2841,7 @@
 #encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 #encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1)>, round_dims_to = array<i64: 16, 16, 16>>
 func.func @broadcast_M() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
@@ -2880,7 +2880,7 @@
 #encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 #encoding_bcast = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 func.func @broadcast_N() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
@@ -2919,7 +2919,7 @@
 #encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 #encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
 func.func @broadcast_K() attributes {
-  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
 } {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vmvx_materialize_encoding.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir
rename to compiler/src/iree/compiler/Codegen/Common/test/vmvx_materialize_encoding.mlir
index 85dd416..2f3b91f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/vmvx_materialize_encoding.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-device-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-materialize-device-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
 
 #pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
   #hal.pipeline.binding<storage_buffer>,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 76b2745..1d2b66e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -788,7 +788,7 @@
       // TODO(#13888): This(createExpandF16OpToF32Pass()) pass is being added
       // way to late and should insted be be done during lowering to LLVM.
       .addPass(createExpandF16OpToF32Pass)
-      .addPass(createCPUMaterializeDeviceEncodingPass)
+      .addPass(createMaterializeDeviceEncodingPass)
       // TODO: Remove the following pass the plumb support for
       // #hal.descriptor_type memory space through the stack.
       .addPass(createEraseHALDescriptorTypeFromMemRefPass);
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index f17a353..812bc9b 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -161,6 +161,10 @@
   return "unknown";
 }
 
+bool isLLVMCPUBackend(IREE::HAL::ExecutableTargetAttr targetAttr) {
+  return targetAttr && targetAttr.getBackend().getValue() == "llvm-cpu";
+}
+
 bool isVMVXBackend(IREE::HAL::ExecutableTargetAttr targetAttr) {
   return targetAttr && targetAttr.getBackend().getValue().starts_with("vmvx");
 }
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
index d8f96de..ea3d069 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
@@ -61,9 +61,8 @@
 const char *getIreeArchNameForTargetTriple(llvm::Triple triple);
 
 /// Methods to get target information.
+bool isLLVMCPUBackend(IREE::HAL::ExecutableTargetAttr targetAttr);
 bool isVMVXBackend(IREE::HAL::ExecutableTargetAttr targetAttr);
-
-/// Methods to get target information.
 bool isROCMBackend(IREE::HAL::ExecutableTargetAttr targetAttr);
 
 // Returns true if the ukernel with given `ukernelName` is enabled.
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
index 00c5c9f..a196e31 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
@@ -44,7 +44,7 @@
   }
   modulePassManager.addPass(createMaterializeUserConfigsPass());
   FunctionLikeNest(modulePassManager)
-      .addPass(createCPUMaterializeDeviceEncodingPass)
+      .addPass(createMaterializeDeviceEncodingPass)
       // TODO: Remove the following pass the plumb support for
       // #hal.descriptor_type memory space through the stack.
       .addPass(createEraseHALDescriptorTypeFromMemRefPass);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index d85310e..50ff8a6 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -76,8 +76,6 @@
         ":PassHeaders",
         ":PassesIncGen",
         "//compiler/src/iree/compiler/Codegen/Common",
-        "//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
-        "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Dialect/Encoding/IR",
         "//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index 9ca16ee..6650602 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -91,8 +91,6 @@
     MLIRTransformUtils
     MLIRTransforms
     iree::compiler::Codegen::Common
-    iree::compiler::Codegen::Common::CPU::CommonCPUPasses
-    iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
     iree::compiler::Dialect::Encoding::IR
     iree::compiler::Dialect::Flow::Conversion::TensorToFlow
diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
index adcc129..f7aeb82 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
@@ -4,8 +4,6 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Codegen/Common/CPU/Passes.h"
-#include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
 #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
@@ -82,10 +80,10 @@
     // Only llvm-cpu and rocm backends handle encodings for now, others just go
     // with nop.
     if (executableTarget.getBackend() == "llvm-cpu") {
-      passManager.addPass(createCPUMaterializeHostEncodingPass());
+      passManager.addPass(createMaterializeHostEncodingPass());
     } else if (clEnableExperimentalRocmDataTiling &&
                executableTarget.getBackend() == "rocm") {
-      passManager.addPass(createGPUMaterializeHostEncodingPass());
+      passManager.addPass(createMaterializeHostEncodingPass());
       FunctionLikeNest(passManager).addPass([&]() {
         return createDecomposePackUnPackOpsPass(
             DecomposePackUnPackOpsPassOptions{/*tileOuterToOne=*/false,