[Codegen][GPU] Add pass to fuse and hoist scf.forall ops (#17505)

This pass greedily fuses parallel loops together and tries to hoist them
out of serial loops. It is left as TODO to include greedy fusion of
untiled consumers.
diff --git a/compiler/src/iree/compiler/Codegen/BUILD.bazel b/compiler/src/iree/compiler/Codegen/BUILD.bazel
index 31af81b..a07556c 100644
--- a/compiler/src/iree/compiler/Codegen/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/BUILD.bazel
@@ -25,6 +25,7 @@
         "//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
         "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
         "//compiler/src/iree/compiler/Codegen/LLVMCPU",
         "//compiler/src/iree/compiler/Codegen/LLVMGPU",
         "//compiler/src/iree/compiler/Codegen/SPIRV",
diff --git a/compiler/src/iree/compiler/Codegen/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/CMakeLists.txt
index ae59fc6..bf8407b 100644
--- a/compiler/src/iree/compiler/Codegen/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/CMakeLists.txt
@@ -23,6 +23,7 @@
     iree::compiler::Codegen::Common::CPU::CommonCPUPasses
     iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+    iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
     iree::compiler::Codegen::LLVMCPU
     iree::compiler::Codegen::LLVMGPU
     iree::compiler::Codegen::SPIRV
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
index b6a7be2..632630c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -4,7 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -33,23 +33,46 @@
     ],
 )
 
+iree_gentbl_cc_library(
+    name = "PassesIncGen",
+    tbl_outs = [
+        (
+            ["--gen-pass-decls"],
+            "Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "Passes.td",
+    deps = [
+        "@llvm-project//mlir:PassBaseTdFiles",
+    ],
+)
+
 iree_compiler_cc_library(
     name = "GPUTransforms",
     srcs = [
+        "FuseAndHoistParallelLoops.cpp",
+        "Passes.cpp",
         "Transforms.cpp",
     ],
     hdrs = [
+        "Passes.h",
+        "Passes.h.inc",
         "Transforms.h",
     ],
     deps = [
+        ":PassesIncGen",
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+        "//compiler/src/iree/compiler/Codegen/Transforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:AffineUtils",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
index 58d9364..588ec3f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -28,21 +28,37 @@
   PUBLIC
 )
 
+iree_tablegen_library(
+  NAME
+    PassesIncGen
+  TD_FILE
+    "Passes.td"
+  OUTS
+    --gen-pass-decls Passes.h.inc
+)
+
 iree_cc_library(
   NAME
     GPUTransforms
   HDRS
+    "Passes.h"
+    "Passes.h.inc"
     "Transforms.h"
   SRCS
+    "FuseAndHoistParallelLoops.cpp"
+    "Passes.cpp"
     "Transforms.cpp"
   DEPS
+    ::PassesIncGen
     LLVMSupport
     MLIRAffineDialect
     MLIRAffineUtils
     MLIRArithDialect
     MLIRFuncDialect
+    MLIRFunctionInterfaces
     MLIRGPUDialect
     MLIRIR
+    MLIRPass
     MLIRSCFDialect
     MLIRSupport
     MLIRTensorDialect
@@ -52,6 +68,7 @@
     MLIRVectorTransforms
     MLIRVectorUtils
     iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+    iree::compiler::Codegen::Transforms
   PUBLIC
 )
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
new file mode 100644
index 0000000..46aa775
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
@@ -0,0 +1,68 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+#define GEN_PASS_DEF_FUSEANDHOISTPARALLELLOOPSPASS
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+
+namespace {
+struct FuseAndHoistParallelLoopsPass final
+    : impl::FuseAndHoistParallelLoopsPassBase<FuseAndHoistParallelLoopsPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+struct FuseForalls final : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto sliceParent = sliceOp->getParentOfType<scf::ForallOp>();
+    if (!sliceParent) {
+      return failure();
+    }
+
+    auto producerForall = sliceOp.getSource().getDefiningOp<scf::ForallOp>();
+    if (!producerForall) {
+      return failure();
+    }
+
+    // TODO: Allow extracting multiple uses within the same consumer loop. Still
+    // single producer single consumer loop, but multiple uses within the
+    // consumer.
+    if (!producerForall->hasOneUse()) {
+      return failure();
+    }
+
+    return fuseForallIntoSlice(rewriter, producerForall, sliceParent, sliceOp);
+  }
+};
+
+void FuseAndHoistParallelLoopsPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+
+  // These two patterns are run to a fixed point, allowing fusion within
+  // potentially nested loops, hoisting from said loops, and continued fusion.
+  patterns.add<FuseForalls>(context);
+  populateForallLoopHoistingPattern(patterns);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp
new file mode 100644
index 0000000..4593400
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp
@@ -0,0 +1,23 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir::iree_compiler {
+
+namespace IREE::GPU {
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace
+} // namespace IREE::GPU
+
+void registerIREEGPUPasses() {
+  // Generated.
+  IREE::GPU::registerPasses();
+}
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h
new file mode 100644
index 0000000..e9fa326
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h
@@ -0,0 +1,24 @@
+
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_
+#define IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_
+
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+#define GEN_PASS_DECL
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc" // IWYU pragma: keep
+} // namespace mlir::iree_compiler::IREE::GPU
+
+namespace mlir::iree_compiler {
+/// Register GPU passes.
+void registerIREEGPUPasses();
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
new file mode 100644
index 0000000..25487c1
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -0,0 +1,21 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES
+#define IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def FuseAndHoistParallelLoopsPass :
+    InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
+  let summary = "Checks GPU specific resource usage constraints like shared memory limits";
+  let dependentDialects = [
+    "::mlir::affine::AffineDialect",
+    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
+  ];
+}
+
+#endif // IREE_CODEGEN_DIALECt_GPU_TRANSFORMS_PASSES
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
new file mode 100644
index 0000000..ec698a3
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
@@ -0,0 +1,30 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Tests for iree_gpu transforms.
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = enforce_glob(
+        [
+            "fuse_and_hoist_forall.mlir",
+        ],
+        include = ["*.mlir"],
+    ),
+    cfg = "//compiler:lit.cfg.py",
+    tools = [
+        "//tools:iree-opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
new file mode 100644
index 0000000..bf0669c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel   #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "fuse_and_hoist_forall.mlir"
+  TOOLS
+    FileCheck
+    iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
new file mode 100644
index 0000000..f3bd2f7
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -0,0 +1,70 @@
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 4)>
+#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
+#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+#map4 = affine_map<(d0) -> (d0 * 16)>
+module {
+  func.func @forall_fuse_then_hoist() {
+    %c4 = arith.constant 4 : index
+    %c128 = arith.constant 128 : index
+    %c0 = arith.constant 0 : index
+    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
+    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
+    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
+    %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
+    %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
+    %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
+    %6 = tensor.empty() : tensor<128x4xf16>
+    %7 = tensor.empty() : tensor<4x128xf16>
+    %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
+      %9 = scf.forall (%arg2, %arg3) in (64, 1) shared_outs(%arg4 = %6) -> (tensor<128x4xf16>) {
+        %12 = affine.apply #map(%arg2)
+        %13 = affine.apply #map1(%arg3)
+        %14 = affine.apply #map(%arg2)
+        %15 = affine.apply #map2(%arg3)[%arg0]
+        %extracted_slice = tensor.extract_slice %3[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
+        %extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<128x4xf16> to tensor<2x4xf16>
+        %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
+        }
+      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+      %10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) {
+        %12 = affine.apply #map(%arg2)
+        %13 = affine.apply #map1(%arg3)
+        %14 = affine.apply #map3(%arg2)[%arg0]
+        %15 = affine.apply #map1(%arg3)
+        %extracted_slice = tensor.extract_slice %4[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
+        %extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<4x128xf16> to tensor<2x4xf16>
+        %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16>
+        }
+      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+      %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
+        %12 = affine.apply #map4(%arg2)
+        %13 = affine.apply #map4(%arg3)
+        %extracted_slice = tensor.extract_slice %9[%12, 0] [16, 4] [1, 1] : tensor<128x4xf16> to tensor<16x4xf16>
+        %extracted_slice_0 = tensor.extract_slice %10[0, %13] [4, 16] [1, 1] : tensor<4x128xf16> to tensor<4x16xf16>
+        %extracted_slice_1 = tensor.extract_slice %arg4[%12, %13] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+        %14 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x4xf16>, tensor<4x16xf16>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+        }
+      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+      scf.yield %11 : tensor<128x128xf32>
+    }
+    flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
+    return
+  }
+}
+
+// CHECK-LABEL: func @forall_fuse_then_hoist
+//       CHECK:   %[[OUTER_PARALLEL:.+]] = scf.forall
+//       CHECK:     %[[LOOP:.+]] = scf.for
+//       CHECK:     scf.yield {{.*}} : tensor<16x16xf32>
+//       CHECK:   scf.forall.in_parallel
+//  CHECK-NEXT:     tensor.parallel_insert_slice %[[LOOP]]
+//       CHECK:   flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
diff --git a/compiler/src/iree/compiler/Codegen/Passes.cpp b/compiler/src/iree/compiler/Codegen/Passes.cpp
index 07f62f1..2360717 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Passes.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Codegen/Common/CPU/Passes.h"
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
 #include "iree/compiler/Codegen/LLVMCPU/Passes.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
 #include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
@@ -33,6 +34,7 @@
   registerCodegenSPIRVPasses();
   registerCodegenVMVXPasses();
   registerCodegenWGSLPasses();
+  registerIREEGPUPasses();
 }
 
 } // namespace mlir::iree_compiler