[codegen][rocdl] Remove ROCDLKernelConfig and ROCDLSelectLoweringStrategy (#21820)
This patch removes:
- ROCDLKernelConfig.h
- ROCDLKernelConfig.cpp
- ROCDLSelectLoweringStrategy.cpp
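They are removed because no pipeline uses them and they are not properly
tested. The `iree-codegen-rocdl-configuration-pipeline` registration, the
`iree-rocdl-select-lowering-strategy` pass definition, and the lit tests that
depended on them (`lowering_scalar_dispatch.mlir` and
`gpu_pipeline_generalize_named_ops.mlir`) are removed as well.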
---------
Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 8702fd7..1e8bbca 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -113,10 +113,8 @@
"ROCDLAnnotateKernelForTranslation.cpp",
"ROCDLBufferInstructionsOptimization.cpp",
"ROCDLConfigureBufferInstructions.cpp",
- "ROCDLKernelConfig.cpp",
"ROCDLLowerExecutableTarget.cpp",
"ROCDLPrefetching.cpp",
- "ROCDLSelectLoweringStrategy.cpp",
"TestLLVMGPUQueryMMAPass.cpp",
"Verifiers.cpp",
],
@@ -124,7 +122,6 @@
"ConvertToLLVM.h",
"KernelConfig.h",
"Passes.h",
- "ROCDLKernelConfig.h",
"ROCDLPasses.h",
],
deps = [
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index e446ee4..148c08e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -67,7 +67,6 @@
"ConvertToLLVM.h"
"KernelConfig.h"
"Passes.h"
- "ROCDLKernelConfig.h"
"ROCDLPasses.h"
SRCS
"AMDGPUEmulateNarrowType.cpp"
@@ -93,10 +92,8 @@
"ROCDLAnnotateKernelForTranslation.cpp"
"ROCDLBufferInstructionsOptimization.cpp"
"ROCDLConfigureBufferInstructions.cpp"
- "ROCDLKernelConfig.cpp"
"ROCDLLowerExecutableTarget.cpp"
"ROCDLPrefetching.cpp"
- "ROCDLSelectLoweringStrategy.cpp"
"TestLLVMGPUQueryMMAPass.cpp"
"Verifiers.cpp"
DEPS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index a67547d..1d64876 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1199,27 +1199,6 @@
// ROCDL Pass Pipelines
//===----------------------------------------------------------------------===//
-static void buildROCDLCodegenConfigurationPassPipelineImpl(
- OpPassManager &modulePassManager) {
- {
- FunctionLikeNest funcPassManager(modulePassManager);
- funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
- funcPassManager.addPass(createROCDLConfigureBufferInstructionsPass);
- addCommonTargetExecutablePreprocessingPasses(funcPassManager);
- }
- modulePassManager.addPass(createMaterializeTuningSpecsPass());
- modulePassManager.addPass(createMaterializeUserConfigsPass());
-
- modulePassManager.addPass(createROCDLSelectLoweringStrategyPass());
-}
-
-void buildROCDLCodegenConfigurationPassPipeline(
- OpPassManager &variantPassManager) {
- variantPassManager.addPass(createSpecializeExportsPass());
- OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
- buildROCDLCodegenConfigurationPassPipelineImpl(modulePassManager);
-}
-
void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) {
{
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
@@ -1299,13 +1278,6 @@
// Generated.
rocdl::registerPasses();
- static PassPipelineRegistration<> ROCDLConfigPipeline(
- "iree-codegen-rocdl-configuration-pipeline",
- "Runs pass pipeline to select a suitable lowering strategy for ROCDL",
- [](OpPassManager &modulePassManager) {
- buildROCDLCodegenConfigurationPassPipelineImpl(modulePassManager);
- });
-
static PassPipelineRegistration<> LinalgROCDLPipeline(
"iree-codegen-linalg-to-rocdl-pipeline2",
"Runs pass pipeline to progressively lower Linalg to ROCDL",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
deleted file mode 100644
index 32ad469..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
+++ /dev/null
@@ -1,202 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h"
-
-#include "compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
-#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-
-namespace mlir::iree_compiler {
-
-namespace {
-
-using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
-
-//===----------------------------------------------------------------------===//
-// Warp Reduction Configuration
-//===----------------------------------------------------------------------===//
-
-static bool isMatvecLike(linalg::LinalgOp linalgOp) {
- if (linalgOp.getNumParallelLoops() != 2)
- return false;
-
- if (linalgOp.getNumReductionLoops() != 1)
- return false;
-
- // TODO: Allow for matvec with fused dequantization.
- FailureOr<linalg::ContractionDimensions> dims =
- linalg::inferContractionDims(linalgOp);
- if (failed(dims))
- return false;
-
- // TODO: Support batch matvec.
- if (!dims->batch.empty())
- return false;
-
- for (ArrayRef indices : {dims->m, dims->n, dims->k}) {
- if (!llvm::hasSingleElement(indices))
- return false;
- }
-
- // Check if the first parallel dimension has bound 1, indicating we found a
- // vector shape.
- SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
- if (bounds[dims->m.front()] != 1)
- return false;
-
- return true;
-}
-
-//===----------------------------------------------------------------------===//
-// Root Configuration
-//===----------------------------------------------------------------------===//
-
-static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
- mlir::FunctionOpInterface entryPointFn,
- Operation *computeOp) {
- IREE::GPU::UKernelConfigAttr ukernelConfig = selectUKernel(computeOp);
- if (succeeded(setDataTiledMultiMmaLoweringConfig(target, entryPointFn,
- computeOp, ukernelConfig))) {
- return success();
- }
- if (auto linalgOp = dyn_cast<linalg::LinalgOp>(computeOp)) {
- if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
- linalgOp))) {
- return success();
- }
- if (succeeded(IREE::GPU::setIGEMMConvolutionLoweringConfig(
- target, entryPointFn, computeOp))) {
- return success();
- }
- // TODO: Add configurations for matmul here too.
- if (succeeded(IREE::GPU::setTileAndFuseLoweringConfig(target, entryPointFn,
- computeOp))) {
- return success();
- }
- }
-
- return failure();
-}
-
-// Propagates the configuration to the other ops.
-static void propagateLoweringConfig(Operation *rootOp,
- ArrayRef<Operation *> computeOps) {
- if (IREE::Codegen::LoweringConfigAttrInterface config =
- getLoweringConfig(rootOp)) {
- for (auto op : computeOps) {
- if (op != rootOp)
- setLoweringConfig(op, config);
- }
- }
-}
-
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Entry Point
-//===----------------------------------------------------------------------===//
-
-LogicalResult initROCDLLaunchConfig(FunctionOpInterface funcOp) {
- IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
- if (!target)
- return funcOp.emitError("missing GPU target in #hal.executable.target");
-
- // First check whether we already have workgroup count set--it's a
- // "contract" to indicate that we should bypass all tiling and
- // distribution to go down just the most basic lowering flow.
- if (auto exportOp = getEntryPoint(funcOp)) {
- if (Block *body = exportOp->getWorkgroupCountBody()) {
- auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
- // For scalar dispatch cases--using just one thread of one workgroup.
- auto isOne = [](Value value) { return matchPattern(value, m_One()); };
- if (llvm::all_of(retOp.getOperands(), isOne)) {
- std::array<int64_t, 3> workgroupSize = {1, 1, 1};
- auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
- funcOp.getContext(),
- IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering,
- workgroupSize);
- if (failed(setTranslationInfo(funcOp, translationInfo))) {
- return failure();
- }
- return success();
- }
- }
- }
-
- SmallVector<Operation *> computeOps = getComputeOps(funcOp);
- if (IREE::Codegen::TranslationInfoAttr translationInfo =
- getTranslationInfo(funcOp)) {
- // Currently ROCDL requires propagation of user lowering configs for
- // all pipelines except TileAndFuse.
- if (translationInfo.getDispatchLoweringPassPipeline() !=
- IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
- for (auto op : computeOps) {
- if (getLoweringConfig(op)) {
- propagateLoweringConfig(op, computeOps);
- break;
- }
- }
- }
- }
-
- Operation *rootOp = nullptr;
-
- // Find the root operation. linalg.generic and linalg.fill are not root
- // operations if there are other compute operations present.
- for (Operation *op : llvm::reverse(computeOps)) {
- if (!isa<linalg::GenericOp, linalg::FillOp>(op)) {
- rootOp = op;
- break;
- }
- if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
- // linalg.generic with `reduction` iterator types are roots as well.
- if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
- rootOp = op;
- break;
- }
- }
- }
-
- if (!rootOp) {
- for (Operation *op : llvm::reverse(computeOps)) {
- if (isa<linalg::GenericOp, linalg::FillOp>(op)) {
- rootOp = op;
- break;
- }
- }
- }
-
- if (!rootOp) {
- // No root operation found, set it to none.
- auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
- funcOp.getContext(), CodeGenPipeline::None);
- if (failed(setTranslationInfo(funcOp, translationInfo))) {
- return failure();
- }
- return success();
- }
-
- if (failed(setRootConfig(target, funcOp, rootOp)))
- return failure();
-
- if (getTranslationInfo(funcOp).getDispatchLoweringPassPipeline() !=
- IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
- propagateLoweringConfig(rootOp, computeOps);
- }
- return success();
-}
-
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h
deleted file mode 100644
index b616ff7..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h
+++ /dev/null
@@ -1,18 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_CODEGEN_LLVMGPU_ROCDLKERNELCONFIG_H_
-#define IREE_COMPILER_CODEGEN_LLVMGPU_ROCDLKERNELCONFIG_H_
-
-#include "mlir/Interfaces/FunctionInterfaces.h"
-
-namespace mlir::iree_compiler {
-
-LogicalResult initROCDLLaunchConfig(FunctionOpInterface funcOp);
-
-} // namespace mlir::iree_compiler
-
-#endif // IREE_COMPILER_CODEGEN_LLVMGPU_ROCDLKERNELCONFIG_H_
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td
index a4f0e67..7b12bb8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLPasses.td
@@ -56,10 +56,4 @@
"pass pipeline";
}
-def ROCDLSelectLoweringStrategyPass :
- Pass<"iree-rocdl-select-lowering-strategy", "ModuleOp"> {
- let summary = "Select a suitable lowering strategy for an IREE "
- "hal.executable.variant op";
-}
-
#endif // IREE_CODEGEN_LLVMGPU_ROCDLPASSES
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLSelectLoweringStrategy.cpp
deleted file mode 100644
index 65c855a..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLSelectLoweringStrategy.cpp
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
-#include "iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h"
-#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir::iree_compiler {
-
-#define GEN_PASS_DEF_ROCDLSELECTLOWERINGSTRATEGYPASS
-#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h.inc"
-
-namespace {
-/// Selects a strategy for lowering an IREE hal.executable.variant to ROCDL.
-class ROCDLSelectLoweringStrategyPass final
- : public impl::ROCDLSelectLoweringStrategyPassBase<
- ROCDLSelectLoweringStrategyPass> {
-public:
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<IREE::Codegen::IREECodegenDialect, IREE::GPU::IREEGPUDialect,
- bufferization::BufferizationDialect>();
- }
-
- void runOnOperation() override {
- auto moduleOp = getOperation();
- for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
- if (failed(initROCDLLaunchConfig(funcOp))) {
- funcOp.emitOpError("failed to set configuration");
- return signalPassFailure();
- }
- }
- }
-};
-} // namespace
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index a2c94ed..d73b766 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -38,7 +38,6 @@
"config_winograd.mlir",
"extract_address_computation_gpu.mlir",
"gpu_pipeline_data_tiling.mlir",
- "gpu_pipeline_generalize_named_ops.mlir",
"gpu_pipeline_relayout_ops.mlir",
"horizontal_fusion_pipeline.mlir",
"link_executables.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 179e2b9..e8edb39 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -35,7 +35,6 @@
"elementwise_pipeline.mlir"
"extract_address_computation_gpu.mlir"
"gpu_pipeline_data_tiling.mlir"
- "gpu_pipeline_generalize_named_ops.mlir"
"gpu_pipeline_relayout_ops.mlir"
"horizontal_fusion_pipeline.mlir"
"legalize.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index b3529db..00a0026 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -29,7 +29,6 @@
"config_vector_distribute_gfx950.mlir",
"config_user_vector_distribute.mlir",
"configure_buffer_instructions.mlir",
- "lowering_scalar_dispatch.mlir",
"pipeline_elementwise_f8fnuz.mlir",
"pipeline_elementwise_f8ocp.mlir",
"pipeline_igemm_tile_and_fuse.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 0655424..f771e64 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -25,7 +25,6 @@
"config_vector_distribute_gfx950.mlir"
"config_vector_distribute_reduction_gfx942.mlir"
"configure_buffer_instructions.mlir"
- "lowering_scalar_dispatch.mlir"
"pipeline_elementwise_f8fnuz.mlir"
"pipeline_elementwise_f8ocp.mlir"
"pipeline_igemm_tile_and_fuse.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/lowering_scalar_dispatch.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/lowering_scalar_dispatch.mlir
deleted file mode 100644
index 0ab27f1..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/lowering_scalar_dispatch.mlir
+++ /dev/null
@@ -1,46 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx90a --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-rocdl-select-lowering-strategy, func.func(iree-rocdl-lower-executable-target)))))' -mlir-print-local-scope %s | FileCheck %s
-
-#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb">
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
- #hal.pipeline.binding<storage_buffer, ReadOnly>,
- #hal.pipeline.binding<storage_buffer>
-]>
-
-hal.executable @scalar_dispatch {
- hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
- hal.executable.export public @scalar_dispatch ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
- %c1 = arith.constant 1 : index
- hal.return %c1, %c1, %c1 : index, index, index
- }
- builtin.module {
- func.func @scalar_dispatch() {
- %c0 = arith.constant 0 : index
- %c6364136223846793005_i64 = arith.constant 6364136223846793005 : i64
- %c1442695040888963407_i64 = arith.constant 1442695040888963407 : i64
- %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<i64>>
- %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<i64>>
- %2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<i64>> -> tensor<i64>
- %extracted = tensor.extract %2[] : tensor<i64>
- %3 = arith.muli %extracted, %c6364136223846793005_i64 : i64
- %4 = arith.addi %3, %c1442695040888963407_i64 : i64
- %inserted = tensor.insert %4 into %2[] : tensor<i64>
- iree_tensor_ext.dispatch.tensor.store %inserted, %1, offsets = [], sizes = [], strides = [] : tensor<i64> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<i64>>
- return
- }
- }
- }
-}
-
-// CHECK-LABEL: func.func @scalar_dispatch()
-// CHECK-SAME: translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUBaseLowering workgroup_size = [1, 1, 1]>
-// CHECK: %[[SPANBIND0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
-// CHECK: %[[ASSUMED_SPAN0:.+]] = memref.assume_alignment %[[SPANBIND0]], 64
-// CHECK: %[[SPAN0:.+]] = amdgpu.fat_raw_buffer_cast %[[ASSUMED_SPAN0]]
-// CHECK: %[[SPANBIND1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
-// CHECK: %[[ASSUMED_SPAN1:.+]] = memref.assume_alignment %[[SPANBIND1]], 64
-// CHECK: %[[SPAN1:.+]] = amdgpu.fat_raw_buffer_cast %[[ASSUMED_SPAN1]]
-// CHECK: memref.load %[[SPAN0]][] : memref<i64, #amdgpu.address_space<fat_raw_buffer>>
-// CHECK: arith.muli {{.+}} : i64
-// CHECK: arith.addi {{.+}} : i64
-// CHECK: memref.store %{{.+}}, %[[SPAN1]][] : memref<i64, #amdgpu.address_space<fat_raw_buffer>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir
deleted file mode 100644
index 986a67e..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir
+++ /dev/null
@@ -1,40 +0,0 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" --iree-gpu-test-target=gfx942 \
-// RUN: --split-input-file %s | FileCheck %s
-
-// RUN: iree-opt --pass-pipeline="builtin.module(iree-codegen-rocdl-configuration-pipeline)" --iree-gpu-test-target=gfx942 \
-// RUN: --split-input-file %s | FileCheck %s
-
-// Make sure that the GPU configuration pipelines generalize named ops,
-// e.g., matmul_transpose_b (linalg.matmul indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]) to linalg.generic.
-
-// CHECK: linalg.fill
-// CHECK-NEXT: linalg.generic
-// CHECK-NOT: linalg.matmul indexing_maps
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>
-]>
-func.func @warp_reduction_large_vector() {
- %cst = arith.constant 0.000000e+00 : f32
- %c128 = arith.constant 128 : index
- %c0 = arith.constant 0 : index
- %c394240 = arith.constant 394240 : index
- %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c128) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1x1280xf32>>
- %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x1280xf32>>
- %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c394240) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x1280xf32>>
- %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1x1280xf32>> -> tensor<1x1280xf32>
- %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x1280xf32>> -> tensor<1280x1280xf32>
- %5 = tensor.empty() : tensor<1x1280xf32>
- %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
- %7 = linalg.matmul
- indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d1, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%3, %4 : tensor<1x1280xf32>, tensor<1280x1280xf32>) outs(%6 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
- iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 1280], strides = [1, 1] : tensor<1x1280xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x1280xf32>>
- return
-}