[codegen][rocdl] Remove ROCDLKernelConfig and ROCDLSelectLoweringStrategy (#21820)

This patch removes:
- ROCDLKernelConfig.h
- ROCDLKernelConfig.cpp
- ROCDLSelectLoweringStrategy.cpp

These are removed because they are not used by any pipeline and are not
properly tested.

---------

Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 8702fd7..1e8bbca 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -113,10 +113,8 @@
         "ROCDLAnnotateKernelForTranslation.cpp",
         "ROCDLBufferInstructionsOptimization.cpp",
         "ROCDLConfigureBufferInstructions.cpp",
-        "ROCDLKernelConfig.cpp",
         "ROCDLLowerExecutableTarget.cpp",
         "ROCDLPrefetching.cpp",
-        "ROCDLSelectLoweringStrategy.cpp",
         "TestLLVMGPUQueryMMAPass.cpp",
         "Verifiers.cpp",
     ],
@@ -124,7 +122,6 @@
         "ConvertToLLVM.h",
         "KernelConfig.h",
         "Passes.h",
-        "ROCDLKernelConfig.h",
         "ROCDLPasses.h",
     ],
     deps = [
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index e446ee4..148c08e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -67,7 +67,6 @@
     "ConvertToLLVM.h"
     "KernelConfig.h"
     "Passes.h"
-    "ROCDLKernelConfig.h"
     "ROCDLPasses.h"
   SRCS
     "AMDGPUEmulateNarrowType.cpp"
@@ -93,10 +92,8 @@
     "ROCDLAnnotateKernelForTranslation.cpp"
     "ROCDLBufferInstructionsOptimization.cpp"
     "ROCDLConfigureBufferInstructions.cpp"
-    "ROCDLKernelConfig.cpp"
     "ROCDLLowerExecutableTarget.cpp"
     "ROCDLPrefetching.cpp"
-    "ROCDLSelectLoweringStrategy.cpp"
     "TestLLVMGPUQueryMMAPass.cpp"
     "Verifiers.cpp"
   DEPS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index a67547d..1d64876 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1199,27 +1199,6 @@
 // ROCDL Pass Pipelines
 //===----------------------------------------------------------------------===//
 
-static void buildROCDLCodegenConfigurationPassPipelineImpl(
-    OpPassManager &modulePassManager) {
-  {
-    FunctionLikeNest funcPassManager(modulePassManager);
-    funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
-    funcPassManager.addPass(createROCDLConfigureBufferInstructionsPass);
-    addCommonTargetExecutablePreprocessingPasses(funcPassManager);
-  }
-  modulePassManager.addPass(createMaterializeTuningSpecsPass());
-  modulePassManager.addPass(createMaterializeUserConfigsPass());
-
-  modulePassManager.addPass(createROCDLSelectLoweringStrategyPass());
-}
-
-void buildROCDLCodegenConfigurationPassPipeline(
-    OpPassManager &variantPassManager) {
-  variantPassManager.addPass(createSpecializeExportsPass());
-  OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
-  buildROCDLCodegenConfigurationPassPipelineImpl(modulePassManager);
-}
-
 void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) {
   {
     OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
@@ -1299,13 +1278,6 @@
   // Generated.
   rocdl::registerPasses();
 
-  static PassPipelineRegistration<> ROCDLConfigPipeline(
-      "iree-codegen-rocdl-configuration-pipeline",
-      "Runs pass pipeline to select a suitable lowering strategy for ROCDL",
-      [](OpPassManager &modulePassManager) {
-        buildROCDLCodegenConfigurationPassPipelineImpl(modulePassManager);
-      });
-
   static PassPipelineRegistration<> LinalgROCDLPipeline(
       "iree-codegen-linalg-to-rocdl-pipeline2",
       "Runs pass pipeline to progressively lower Linalg to ROCDL",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
deleted file mode 100644
index 32ad469..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
+++ /dev/null
@@ -1,202 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h"
-
-#include "compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
-#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-
-namespace mlir::iree_compiler {
-
-namespace {
-
-using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
-
-//===----------------------------------------------------------------------===//
-// Warp Reduction Configuration
-//===----------------------------------------------------------------------===//
-
-static bool isMatvecLike(linalg::LinalgOp linalgOp) {
-  if (linalgOp.getNumParallelLoops() != 2)
-    return false;
-
-  if (linalgOp.getNumReductionLoops() != 1)
-    return false;
-
-  // TODO: Allow for matvec with fused dequantization.
-  FailureOr<linalg::ContractionDimensions> dims =
-      linalg::inferContractionDims(linalgOp);
-  if (failed(dims))
-    return false;
-
-  // TODO: Support batch matvec.
-  if (!dims->batch.empty())
-    return false;
-
-  for (ArrayRef indices : {dims->m, dims->n, dims->k}) {
-    if (!llvm::hasSingleElement(indices))
-      return false;
-  }
-
-  // Check if the first parallel dimension has bound 1, indicating we found a
-  // vector shape.
-  SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
-  if (bounds[dims->m.front()] != 1)
-    return false;
-
-  return true;
-}
-
-//===----------------------------------------------------------------------===//
-// Root Configuration
-//===----------------------------------------------------------------------===//
-
-static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
-                                   mlir::FunctionOpInterface entryPointFn,
-                                   Operation *computeOp) {
-  IREE::GPU::UKernelConfigAttr ukernelConfig = selectUKernel(computeOp);
-  if (succeeded(setDataTiledMultiMmaLoweringConfig(target, entryPointFn,
-                                                   computeOp, ukernelConfig))) {
-    return success();
-  }
-  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(computeOp)) {
-    if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
-                                                     linalgOp))) {
-      return success();
-    }
-    if (succeeded(IREE::GPU::setIGEMMConvolutionLoweringConfig(
-            target, entryPointFn, computeOp))) {
-      return success();
-    }
-    // TODO: Add configurations for matmul here too.
-    if (succeeded(IREE::GPU::setTileAndFuseLoweringConfig(target, entryPointFn,
-                                                          computeOp))) {
-      return success();
-    }
-  }
-
-  return failure();
-}
-
-// Propagates the configuration to the other ops.
-static void propagateLoweringConfig(Operation *rootOp,
-                                    ArrayRef<Operation *> computeOps) {
-  if (IREE::Codegen::LoweringConfigAttrInterface config =
-          getLoweringConfig(rootOp)) {
-    for (auto op : computeOps) {
-      if (op != rootOp)
-        setLoweringConfig(op, config);
-    }
-  }
-}
-
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Entry Point
-//===----------------------------------------------------------------------===//
-
-LogicalResult initROCDLLaunchConfig(FunctionOpInterface funcOp) {
-  IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
-  if (!target)
-    return funcOp.emitError("missing GPU target in #hal.executable.target");
-
-  // First check whether we already have workgroup count set--it's a
-  // "contract" to indicate that we should bypass all tiling and
-  // distribution to go down just the most basic lowering flow.
-  if (auto exportOp = getEntryPoint(funcOp)) {
-    if (Block *body = exportOp->getWorkgroupCountBody()) {
-      auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
-      // For scalar dispatch cases--using just one thread of one workgroup.
-      auto isOne = [](Value value) { return matchPattern(value, m_One()); };
-      if (llvm::all_of(retOp.getOperands(), isOne)) {
-        std::array<int64_t, 3> workgroupSize = {1, 1, 1};
-        auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
-            funcOp.getContext(),
-            IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering,
-            workgroupSize);
-        if (failed(setTranslationInfo(funcOp, translationInfo))) {
-          return failure();
-        }
-        return success();
-      }
-    }
-  }
-
-  SmallVector<Operation *> computeOps = getComputeOps(funcOp);
-  if (IREE::Codegen::TranslationInfoAttr translationInfo =
-          getTranslationInfo(funcOp)) {
-    // Currently ROCDL requires propagation of user lowering configs for
-    // all pipelines except TileAndFuse.
-    if (translationInfo.getDispatchLoweringPassPipeline() !=
-        IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
-      for (auto op : computeOps) {
-        if (getLoweringConfig(op)) {
-          propagateLoweringConfig(op, computeOps);
-          break;
-        }
-      }
-    }
-  }
-
-  Operation *rootOp = nullptr;
-
-  // Find the root operation. linalg.generic and linalg.fill are not root
-  // operations if there are other compute operations present.
-  for (Operation *op : llvm::reverse(computeOps)) {
-    if (!isa<linalg::GenericOp, linalg::FillOp>(op)) {
-      rootOp = op;
-      break;
-    }
-    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
-      // linalg.generic with `reduction` iterator types are roots as well.
-      if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
-        rootOp = op;
-        break;
-      }
-    }
-  }
-
-  if (!rootOp) {
-    for (Operation *op : llvm::reverse(computeOps)) {
-      if (isa<linalg::GenericOp, linalg::FillOp>(op)) {
-        rootOp = op;
-        break;
-      }
-    }
-  }
-
-  if (!rootOp) {
-    // No root operation found, set it to none.
-    auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
-        funcOp.getContext(), CodeGenPipeline::None);
-    if (failed(setTranslationInfo(funcOp, translationInfo))) {
-      return failure();
-    }
-    return success();
-  }
-
-  if (failed(setRootConfig(target, funcOp, rootOp)))
-    return failure();
-
-  if (getTranslationInfo(funcOp).getDispatchLoweringPassPipeline() !=
-      IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
-    propagateLoweringConfig(rootOp, computeOps);
-  }
-  return success();
-}
-
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h
deleted file mode 100644
index b616ff7..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h
+++ /dev/null
@@ -1,18 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_CODEGEN_LLVMGPU_ROCDLKERNELCONFIG_H_
-#define IREE_COMPILER_CODEGEN_LLVMGPU_ROCDLKERNELCONFIG_H_
-
-#include "mlir/Interfaces/FunctionInterfaces.h"
-
-namespace mlir::iree_compiler {
-
-LogicalResult initROCDLLaunchConfig(FunctionOpInterface funcOp);
-
-} // namespace mlir::iree_compiler
-
-#endif // IREE_COMPILER_CODEGEN_LLVMGPU_ROCDLKERNELCONFIG_H_
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td
index a4f0e67..7b12bb8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td
@@ -56,10 +56,4 @@
                 "pass pipeline";
 }
 
-def ROCDLSelectLoweringStrategyPass :
-    Pass<"iree-rocdl-select-lowering-strategy", "ModuleOp"> {
-  let summary = "Select a suitable lowering strategy for an IREE "
-                "hal.executable.variant op";
-}
-
 #endif // IREE_CODEGEN_LLVMGPU_ROCDLPASSES
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLSelectLoweringStrategy.cpp
deleted file mode 100644
index 65c855a..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLSelectLoweringStrategy.cpp
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
-#include "iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h"
-#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir::iree_compiler {
-
-#define GEN_PASS_DEF_ROCDLSELECTLOWERINGSTRATEGYPASS
-#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h.inc"
-
-namespace {
-/// Selects a strategy for lowering an IREE hal.executable.variant to ROCDL.
-class ROCDLSelectLoweringStrategyPass final
-    : public impl::ROCDLSelectLoweringStrategyPassBase<
-          ROCDLSelectLoweringStrategyPass> {
-public:
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<IREE::Codegen::IREECodegenDialect, IREE::GPU::IREEGPUDialect,
-                bufferization::BufferizationDialect>();
-  }
-
-  void runOnOperation() override {
-    auto moduleOp = getOperation();
-    for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
-      if (failed(initROCDLLaunchConfig(funcOp))) {
-        funcOp.emitOpError("failed to set configuration");
-        return signalPassFailure();
-      }
-    }
-  }
-};
-} // namespace
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index a2c94ed..d73b766 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -38,7 +38,6 @@
             "config_winograd.mlir",
             "extract_address_computation_gpu.mlir",
             "gpu_pipeline_data_tiling.mlir",
-            "gpu_pipeline_generalize_named_ops.mlir",
             "gpu_pipeline_relayout_ops.mlir",
             "horizontal_fusion_pipeline.mlir",
             "link_executables.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 179e2b9..e8edb39 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -35,7 +35,6 @@
     "elementwise_pipeline.mlir"
     "extract_address_computation_gpu.mlir"
     "gpu_pipeline_data_tiling.mlir"
-    "gpu_pipeline_generalize_named_ops.mlir"
     "gpu_pipeline_relayout_ops.mlir"
     "horizontal_fusion_pipeline.mlir"
     "legalize.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index b3529db..00a0026 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -29,7 +29,6 @@
             "config_vector_distribute_gfx950.mlir",
             "config_user_vector_distribute.mlir",
             "configure_buffer_instructions.mlir",
-            "lowering_scalar_dispatch.mlir",
             "pipeline_elementwise_f8fnuz.mlir",
             "pipeline_elementwise_f8ocp.mlir",
             "pipeline_igemm_tile_and_fuse.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 0655424..f771e64 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -25,7 +25,6 @@
     "config_vector_distribute_gfx950.mlir"
     "config_vector_distribute_reduction_gfx942.mlir"
     "configure_buffer_instructions.mlir"
-    "lowering_scalar_dispatch.mlir"
     "pipeline_elementwise_f8fnuz.mlir"
     "pipeline_elementwise_f8ocp.mlir"
     "pipeline_igemm_tile_and_fuse.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/lowering_scalar_dispatch.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/lowering_scalar_dispatch.mlir
deleted file mode 100644
index 0ab27f1..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/lowering_scalar_dispatch.mlir
+++ /dev/null
@@ -1,46 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx90a --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-rocdl-select-lowering-strategy, func.func(iree-rocdl-lower-executable-target)))))' -mlir-print-local-scope %s | FileCheck %s
-
-#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb">
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer, ReadOnly>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-
-hal.executable @scalar_dispatch {
-  hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
-    hal.executable.export public @scalar_dispatch ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
-      %c1 = arith.constant 1 : index
-      hal.return %c1, %c1, %c1 : index, index, index
-    }
-    builtin.module {
-      func.func @scalar_dispatch() {
-        %c0 = arith.constant 0 : index
-        %c6364136223846793005_i64 = arith.constant 6364136223846793005 : i64
-        %c1442695040888963407_i64 = arith.constant 1442695040888963407 : i64
-        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<i64>>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<i64>>
-        %2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<i64>> -> tensor<i64>
-        %extracted = tensor.extract %2[] : tensor<i64>
-        %3 = arith.muli %extracted, %c6364136223846793005_i64 : i64
-        %4 = arith.addi %3, %c1442695040888963407_i64 : i64
-        %inserted = tensor.insert %4 into %2[] : tensor<i64>
-        iree_tensor_ext.dispatch.tensor.store %inserted, %1, offsets = [], sizes = [], strides = [] : tensor<i64> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<i64>>
-        return
-      }
-    }
-  }
-}
-
-// CHECK-LABEL: func.func @scalar_dispatch()
-//  CHECK-SAME: translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUBaseLowering workgroup_size = [1, 1, 1]>
-//       CHECK:   %[[SPANBIND0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
-//       CHECK:   %[[ASSUMED_SPAN0:.+]] = memref.assume_alignment %[[SPANBIND0]], 64
-//       CHECK:   %[[SPAN0:.+]] = amdgpu.fat_raw_buffer_cast %[[ASSUMED_SPAN0]]
-//       CHECK:   %[[SPANBIND1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
-//       CHECK:   %[[ASSUMED_SPAN1:.+]] = memref.assume_alignment %[[SPANBIND1]], 64
-//       CHECK:   %[[SPAN1:.+]] = amdgpu.fat_raw_buffer_cast %[[ASSUMED_SPAN1]]
-//       CHECK:   memref.load %[[SPAN0]][] : memref<i64, #amdgpu.address_space<fat_raw_buffer>>
-//       CHECK:   arith.muli {{.+}} : i64
-//       CHECK:   arith.addi {{.+}} : i64
-//       CHECK:   memref.store %{{.+}}, %[[SPAN1]][] : memref<i64, #amdgpu.address_space<fat_raw_buffer>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir
deleted file mode 100644
index 986a67e..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir
+++ /dev/null
@@ -1,40 +0,0 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" --iree-gpu-test-target=gfx942 \
-// RUN:   --split-input-file %s | FileCheck %s
-
-// RUN: iree-opt --pass-pipeline="builtin.module(iree-codegen-rocdl-configuration-pipeline)" --iree-gpu-test-target=gfx942 \
-// RUN:   --split-input-file %s | FileCheck %s
-
-// Make sure that the GPU configuration pipelines generalize named ops,
-// e.g., matmul_transpose_b (linalg.matmul indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]) to linalg.generic.
-
-// CHECK:      linalg.fill
-// CHECK-NEXT: linalg.generic
-// CHECK-NOT:  linalg.matmul indexing_maps
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-func.func @warp_reduction_large_vector() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c128 = arith.constant 128 : index
-  %c0 = arith.constant 0 : index
-  %c394240 = arith.constant 394240 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c128) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1x1280xf32>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x1280xf32>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c394240) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x1280xf32>>
-  %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1x1280xf32>> -> tensor<1x1280xf32>
-  %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x1280xf32>> -> tensor<1280x1280xf32>
-  %5 = tensor.empty() : tensor<1x1280xf32>
-  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
-  %7 = linalg.matmul
-    indexing_maps = [
-      affine_map<(d0, d1, d2) -> (d0, d2)>,
-      affine_map<(d0, d1, d2) -> (d1, d2)>,
-      affine_map<(d0, d1, d2) -> (d0, d1)>
-    ]
-    ins(%3, %4 : tensor<1x1280xf32>, tensor<1280x1280xf32>) outs(%6 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
-  iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 1280], strides = [1, 1] : tensor<1x1280xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x1280xf32>>
-  return
-}