[Codegen][GPU] Add pass to fuse and hoist scf.forall ops (#17505)
This pass greedily fuses parallel loops together and tries to hoist them
out of serial loops. It is left as TODO to include greedy fusion of
untiled consumers.
diff --git a/compiler/src/iree/compiler/Codegen/BUILD.bazel b/compiler/src/iree/compiler/Codegen/BUILD.bazel
index 31af81b..a07556c 100644
--- a/compiler/src/iree/compiler/Codegen/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/BUILD.bazel
@@ -25,6 +25,7 @@
"//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
"//compiler/src/iree/compiler/Codegen/LLVMGPU",
"//compiler/src/iree/compiler/Codegen/SPIRV",
diff --git a/compiler/src/iree/compiler/Codegen/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/CMakeLists.txt
index ae59fc6..bf8407b 100644
--- a/compiler/src/iree/compiler/Codegen/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/CMakeLists.txt
@@ -23,6 +23,7 @@
iree::compiler::Codegen::Common::CPU::CommonCPUPasses
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+ iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
iree::compiler::Codegen::LLVMCPU
iree::compiler::Codegen::LLVMGPU
iree::compiler::Codegen::SPIRV
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
index b6a7be2..632630c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")
package(
default_visibility = ["//visibility:public"],
@@ -33,23 +33,46 @@
],
)
+iree_gentbl_cc_library(
+ name = "PassesIncGen",
+ tbl_outs = [
+ (
+ ["--gen-pass-decls"],
+ "Passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "Passes.td",
+ deps = [
+ "@llvm-project//mlir:PassBaseTdFiles",
+ ],
+)
+
iree_compiler_cc_library(
name = "GPUTransforms",
srcs = [
+ "FuseAndHoistParallelLoops.cpp",
+ "Passes.cpp",
"Transforms.cpp",
],
hdrs = [
+ "Passes.h",
+ "Passes.h.inc",
"Transforms.h",
],
deps = [
+ ":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+ "//compiler/src/iree/compiler/Codegen/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
index 58d9364..588ec3f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -28,21 +28,37 @@
PUBLIC
)
+iree_tablegen_library(
+ NAME
+ PassesIncGen
+ TD_FILE
+ "Passes.td"
+ OUTS
+ --gen-pass-decls Passes.h.inc
+)
+
iree_cc_library(
NAME
GPUTransforms
HDRS
+ "Passes.h"
+ "Passes.h.inc"
"Transforms.h"
SRCS
+ "FuseAndHoistParallelLoops.cpp"
+ "Passes.cpp"
"Transforms.cpp"
DEPS
+ ::PassesIncGen
LLVMSupport
MLIRAffineDialect
MLIRAffineUtils
MLIRArithDialect
MLIRFuncDialect
+ MLIRFunctionInterfaces
MLIRGPUDialect
MLIRIR
+ MLIRPass
MLIRSCFDialect
MLIRSupport
MLIRTensorDialect
@@ -52,6 +68,7 @@
MLIRVectorTransforms
MLIRVectorUtils
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+ iree::compiler::Codegen::Transforms
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
new file mode 100644
index 0000000..46aa775
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
@@ -0,0 +1,68 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+#define GEN_PASS_DEF_FUSEANDHOISTPARALLELLOOPSPASS
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+
+namespace {
+/// Pass driver. The tablegen-generated base class supplies the boilerplate
+/// for the pass declared in Passes.td; the work happens in runOnOperation().
+struct FuseAndHoistParallelLoopsPass final
+    : impl::FuseAndHoistParallelLoopsPassBase<FuseAndHoistParallelLoopsPass> {
+  void runOnOperation() override;
+};
+
+/// Fuses a producer scf.forall into a consumer scf.forall, anchored on the
+/// tensor.extract_slice that reads the producer's result.
+///
+/// The pattern applies only when:
+///   * the slice is nested inside an scf.forall (the consumer loop),
+///   * the slice source is defined by an scf.forall (the producer loop), and
+///   * the producer loop has exactly one use.
+/// The rewrite itself is delegated to fuseForallIntoSlice.
+///
+/// Kept in the anonymous namespace so this file-local pattern has internal
+/// linkage and cannot collide with a same-named pattern in another
+/// translation unit.
+struct FuseForalls final : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Consumer: the nearest scf.forall enclosing the slice.
+    auto sliceParent = sliceOp->getParentOfType<scf::ForallOp>();
+    if (!sliceParent) {
+      return failure();
+    }
+
+    // Producer: the scf.forall whose result the slice reads from.
+    auto producerForall = sliceOp.getSource().getDefiningOp<scf::ForallOp>();
+    if (!producerForall) {
+      return failure();
+    }
+
+    // TODO: Allow extracting multiple uses within the same consumer loop. Still
+    // single producer single consumer loop, but multiple uses within the
+    // consumer.
+    if (!producerForall->hasOneUse()) {
+      return failure();
+    }
+
+    return fuseForallIntoSlice(rewriter, producerForall, sliceParent, sliceOp);
+  }
+};
+} // namespace
+
+void FuseAndHoistParallelLoopsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+
+  // Fusion and hoisting live in one pattern set so the greedy driver iterates
+  // them together to a fixed point: fusion inside (possibly nested) loops can
+  // enable hoisting out of those loops, which can enable further fusion.
+  RewritePatternSet fuseAndHoistPatterns(ctx);
+  fuseAndHoistPatterns.add<FuseForalls>(ctx);
+  populateForallLoopHoistingPattern(fuseAndHoistPatterns);
+
+  LogicalResult status = applyPatternsAndFoldGreedily(
+      getOperation(), std::move(fuseAndHoistPatterns));
+  if (failed(status)) {
+    signalPassFailure();
+  }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp
new file mode 100644
index 0000000..4593400
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp
@@ -0,0 +1,23 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+namespace {
+// Tablegen-generated registration hooks, kept local to this translation unit.
+#define GEN_PASS_REGISTRATION
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace
+} // namespace mlir::iree_compiler::IREE::GPU
+
+namespace mlir::iree_compiler {
+
+/// Registers the passes generated from this directory's Passes.td by
+/// forwarding to the tablegen-generated registerPasses() hook.
+void registerIREEGPUPasses() { IREE::GPU::registerPasses(); }
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h
new file mode 100644
index 0000000..e9fa326
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h
@@ -0,0 +1,24 @@
+
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_
+#define IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_
+
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+#define GEN_PASS_DECL
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc" // IWYU pragma: keep
+} // namespace mlir::iree_compiler::IREE::GPU
+
+namespace mlir::iree_compiler {
+/// Registers the IREE GPU dialect transform passes declared in Passes.td.
+void registerIREEGPUPasses();
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
new file mode 100644
index 0000000..25487c1
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -0,0 +1,21 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES
+#define IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+// Function-level pass that fuses producer scf.forall loops into their
+// consumer scf.forall loops and hoists scf.forall loops out of serial loops.
+def FuseAndHoistParallelLoopsPass :
+    InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
+  let summary = "Greedily fuses and hoists parallel loops";
+  let description = [{
+    Greedily fuses parallel loops (scf.forall) together and tries to hoist
+    them out of serial loops. Fusion and hoisting are applied together until a
+    fixed point is reached. Greedy fusion of untiled consumers is not yet
+    supported.
+  }];
+  let dependentDialects = [
+    "::mlir::affine::AffineDialect",
+    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
+  ];
+}
+
+#endif // IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
new file mode 100644
index 0000000..ec698a3
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
@@ -0,0 +1,30 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Tests for iree_gpu transforms.
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "fuse_and_hoist_forall.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
new file mode 100644
index 0000000..bf0669c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "fuse_and_hoist_forall.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
new file mode 100644
index 0000000..f3bd2f7
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -0,0 +1,70 @@
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 4)>
+#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
+#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+#map4 = affine_map<(d0) -> (d0 * 16)>
+module {
+ func.func @forall_fuse_then_hoist() {
+ %c4 = arith.constant 4 : index
+ %c128 = arith.constant 128 : index
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
+ %6 = tensor.empty() : tensor<128x4xf16>
+ %7 = tensor.empty() : tensor<4x128xf16>
+ %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
+ %9 = scf.forall (%arg2, %arg3) in (64, 1) shared_outs(%arg4 = %6) -> (tensor<128x4xf16>) {
+ %12 = affine.apply #map(%arg2)
+ %13 = affine.apply #map1(%arg3)
+ %14 = affine.apply #map(%arg2)
+ %15 = affine.apply #map2(%arg3)[%arg0]
+ %extracted_slice = tensor.extract_slice %3[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
+ %extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<128x4xf16> to tensor<2x4xf16>
+ %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ %10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) {
+ %12 = affine.apply #map(%arg2)
+ %13 = affine.apply #map1(%arg3)
+ %14 = affine.apply #map3(%arg2)[%arg0]
+ %15 = affine.apply #map1(%arg3)
+ %extracted_slice = tensor.extract_slice %4[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
+ %extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<4x128xf16> to tensor<2x4xf16>
+ %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
+ %12 = affine.apply #map4(%arg2)
+ %13 = affine.apply #map4(%arg3)
+ %extracted_slice = tensor.extract_slice %9[%12, 0] [16, 4] [1, 1] : tensor<128x4xf16> to tensor<16x4xf16>
+ %extracted_slice_0 = tensor.extract_slice %10[0, %13] [4, 16] [1, 1] : tensor<4x128xf16> to tensor<4x16xf16>
+ %extracted_slice_1 = tensor.extract_slice %arg4[%12, %13] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %14 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x4xf16>, tensor<4x16xf16>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ scf.yield %11 : tensor<128x128xf32>
+ }
+ flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
+ return
+ }
+}
+
+// CHECK-LABEL: func @forall_fuse_then_hoist
+// CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall
+// CHECK: %[[LOOP:.+]] = scf.for
+// CHECK: scf.yield {{.*}} : tensor<16x16xf32>
+// CHECK: scf.forall.in_parallel
+// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
+// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
diff --git a/compiler/src/iree/compiler/Codegen/Passes.cpp b/compiler/src/iree/compiler/Codegen/Passes.cpp
index 07f62f1..2360717 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Passes.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
@@ -33,6 +34,7 @@
registerCodegenSPIRVPasses();
registerCodegenVMVXPasses();
registerCodegenWGSLPasses();
+ registerIREEGPUPasses();
}
} // namespace mlir::iree_compiler