[Codegen][GPU] Add vectorization pattern for iree_gpu.multi_mma (#17453)
This allows vectorizing a statically shaped `iree_gpu.multi_mma` op. To
use this pattern with dynamically shaped tensors, a tiling
implementation for this op will need to be added in the future.
Additionally this adds a new set of `IREEGPUTransformExtensions` because
`Codegen/CommonExtensions` has become a rather large file and this is
specific to the dialect.
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 8315cfa..fe2dfda 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -245,6 +245,7 @@
# Dialects
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions:IREEGPUExtensions",
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 28c6a9f..018d6a0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -260,6 +260,7 @@
iree::compiler::Codegen::Common::TransformExtensions::CommonExtensions
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+ iree::compiler::Codegen::Dialect::GPU::TransformExtensions::IREEGPUExtensions
iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
iree::compiler::Codegen::TransformStrategies::Common::TransformStrategies
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 22670d2..800af39 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -159,6 +159,10 @@
return ::llvm::cast<::mlir::ShapedType>(getResult().getType());
}
+ bool hasTensorSemantics() {
+ return isa<RankedTensorType>(getResultType());
+ }
+
llvm::SmallVector<::mlir::AffineMap, 4> getIndexingMapsArray() {
return llvm::to_vector<4>(getIndexingMaps().getAsValueRange<::mlir::AffineMapAttr>());
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel
new file mode 100644
index 0000000..cc15ac0
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel
@@ -0,0 +1,69 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library", "iree_td_library")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_td_library(
+ name = "td_files",
+ srcs = enforce_glob(
+ [
+ "IREEGPUExtensionsOps.td",
+ ],
+ include = ["*.td"],
+ ),
+ deps = [
+ "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
+ "@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:TransformDialectTdFiles",
+ ],
+)
+
+iree_gentbl_cc_library(
+ name = "IREEGPUExtensionsOpGen",
+ tbl_outs = [
+ (
+ ["--gen-op-decls"],
+ "IREEGPUExtensionsOps.h.inc",
+ ),
+ (
+ ["--gen-op-defs"],
+ "IREEGPUExtensionsOps.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "IREEGPUExtensionsOps.td",
+ deps = [
+ ":td_files",
+ ],
+)
+
+iree_compiler_cc_library(
+ name = "IREEGPUExtensions",
+ srcs = [
+ "IREEGPUExtensions.cpp",
+ "IREEGPUExtensionsOps.cpp.inc",
+ ],
+ hdrs = [
+ "IREEGPUExtensions.h",
+ "IREEGPUExtensionsOps.h.inc",
+ ],
+ deps = [
+ ":IREEGPUExtensionsOpGen",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformDialectInterfaces",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/CMakeLists.txt
new file mode 100644
index 0000000..a1b62e2
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/CMakeLists.txt
@@ -0,0 +1,43 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel#
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_tablegen_library(
+ NAME
+ IREEGPUExtensionsOpGen
+ TD_FILE
+ "IREEGPUExtensionsOps.td"
+ OUTS
+ --gen-op-decls IREEGPUExtensionsOps.h.inc
+ --gen-op-defs IREEGPUExtensionsOps.cpp.inc
+)
+
+iree_cc_library(
+ NAME
+ IREEGPUExtensions
+ HDRS
+ "IREEGPUExtensions.h"
+ "IREEGPUExtensionsOps.h.inc"
+ SRCS
+ "IREEGPUExtensions.cpp"
+ "IREEGPUExtensionsOps.cpp.inc"
+ DEPS
+ ::IREEGPUExtensionsOpGen
+ LLVMSupport
+ MLIRIR
+ MLIRTransformDialect
+ MLIRTransformDialectInterfaces
+ iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+ iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
new file mode 100644
index 0000000..4b8fe24
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -0,0 +1,42 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.h"
+
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir::iree_compiler::IREE {
+
+transform_dialect::IREEGPUExtensions::IREEGPUExtensions() {
+ registerTransformOps<
+#define GET_OP_LIST
+#include "iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.cpp.inc"
+ >();
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyVectorizeMultiMmaOp
+//===---------------------------------------------------------------------===//
+
+void transform_dialect::ApplyVectorizeMultiMmaOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ IREE::GPU::populateIREEGPUVectorizationPatterns(patterns);
+}
+
+} // namespace mlir::iree_compiler::IREE
+
+void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(
+ DialectRegistry ®istry) {
+ registry.addExtensions<
+ mlir::iree_compiler::IREE::transform_dialect::IREEGPUExtensions>();
+}
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.h
new file mode 100644
index 0000000..b06ef73
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.h
@@ -0,0 +1,47 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS_H_
+#define IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS_H_
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace func {
+class FuncOp;
+} // namespace func
+
+namespace transform {
+// Types needed for builders.
+class TransformTypeInterface;
+} // namespace transform
+
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.h.inc"
+
+namespace mlir::iree_compiler {
+
+/// Registers transformations for the IREE GPU dialect.
+void registerTransformDialectIREEGPUExtension(DialectRegistry ®istry);
+
+namespace IREE::transform_dialect {
+/// Hook to register common transformations to the transform dialect.
+class IREEGPUExtensions
+ : public transform::TransformDialectExtension<IREEGPUExtensions> {
+public:
+ IREEGPUExtensions();
+};
+} // namespace IREE::transform_dialect
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
new file mode 100644
index 0000000..d42f041
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -0,0 +1,29 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
+#define IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def ApplyVectorizeMultiMmaOp : Op<Transform_Dialect,
+ "apply_patterns.iree.vectorize_multi_mma",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Populate patterns to vectorize static iree_gpu.multi_mma ops.
+ }];
+
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let assemblyFormat = "attr-dict";
+}
+
+#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
new file mode 100644
index 0000000..73f5490
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
@@ -0,0 +1,30 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Tests for common transforms.
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "vectorize_multi_mma.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
new file mode 100644
index 0000000..6de1d7d
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel#
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "vectorize_multi_mma.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_multi_mma.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_multi_mma.mlir
new file mode 100644
index 0000000..e9c3d5c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_multi_mma.mlir
@@ -0,0 +1,73 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+func.func @tensor_multi_mma(%lhs: tensor<2x3x4xf16>, %rhs: tensor<3x5x4xf16>, %acc: tensor<2x5x4xf32>) -> tensor<2x5x4xf32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+ } : tensor<2x3x4xf16>, tensor<3x5x4xf16> into tensor<2x5x4xf32>
+ return %0 : tensor<2x5x4xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.vectorize_multi_mma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @tensor_multi_mma
+
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[CSTF32:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[LHS:.+]] = vector.transfer_read %arg0[%c0, %c0, %c0], %[[CST]] {{.*}} : tensor<2x3x4xf16>, vector<2x3x4xf16>
+// CHECK-DAG: %[[RHS:.+]] = vector.transfer_read %arg1[%c0, %c0, %c0], %[[CST]] {{.*}} : tensor<3x5x4xf16>, vector<3x5x4xf16>
+// CHECK-DAG: %[[ACC:.+]] = vector.transfer_read %arg2[%c0, %c0, %c0], %[[CSTF32]] {{.*}} : tensor<2x5x4xf32>, vector<2x5x4xf32>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: : vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>
+// CHECK: vector.transfer_write %[[MMA]], %arg2[%c0, %c0, %c0] {{.*}} : vector<2x5x4xf32>, tensor<2x5x4xf32>
+
+// -----
+
+#contraction_accesses = [
+ affine_map<() -> ()>,
+ affine_map<() -> ()>,
+ affine_map<() -> ()>
+]
+func.func @tensor_single_multi_mma(%lhs: tensor<4xf16>, %rhs: tensor<4xf16>, %acc: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [],
+ kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+ } : tensor<4xf16>, tensor<4xf16> into tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.vectorize_multi_mma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @tensor_single_multi_mma
+
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[CSTF32:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[LHS:.+]] = vector.transfer_read %arg0[%c0], %[[CST]] {in_bounds = [true]} : tensor<4xf16>, vector<4xf16>
+// CHECK-DAG: %[[RHS:.+]] = vector.transfer_read %arg1[%c0], %[[CST]] {in_bounds = [true]} : tensor<4xf16>, vector<4xf16>
+// CHECK-DAG: %[[ACC:.+]] = vector.transfer_read %arg2[%c0], %[[CSTF32]] {in_bounds = [true]} : tensor<4xf32>, vector<4xf32>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: : vector<4xf16>, vector<4xf16> into vector<4xf32>
+// CHECK: vector.transfer_write %[[MMA]], %arg2[%c0] {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
new file mode 100644
index 0000000..1780dbb
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -0,0 +1,34 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "GPUTransforms",
+ srcs = [
+ "Transforms.cpp",
+ ],
+ hdrs = [
+ "Transforms.h",
+ ],
+ deps = [
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorDialect",
+ "@llvm-project//mlir:VectorUtils",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..07c8f91
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -0,0 +1,33 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ GPUTransforms
+ HDRS
+ "Transforms.h"
+ SRCS
+ "Transforms.cpp"
+ DEPS
+ LLVMSupport
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRIR
+ MLIRTransformUtils
+ MLIRTransforms
+ MLIRVectorDialect
+ MLIRVectorUtils
+ iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
new file mode 100644
index 0000000..d525c0c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -0,0 +1,86 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "iree-codegen-gpu-transforms"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+static LogicalResult vectorizeStaticMultiMmaOp(RewriterBase &rewriter,
+ IREE::GPU::MultiMmaOp mmaOp) {
+ if (!mmaOp.hasTensorSemantics()) {
+ return failure();
+ }
+ if (!mmaOp.getLhsType().hasStaticShape() ||
+ !mmaOp.getRhsType().hasStaticShape() ||
+ !mmaOp.getAccType().hasStaticShape()) {
+ return rewriter.notifyMatchFailure(mmaOp,
+ "non-static shape for vectorization");
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(mmaOp);
+
+ Location loc = mmaOp.getLoc();
+
+ // Construct the (never used) zero padding value for each operand.
+ auto lhsPadValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(mmaOp.getLhsType().getElementType()));
+ auto rhsPadValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(mmaOp.getRhsType().getElementType()));
+ Type resultElementType = mmaOp.getResultType().getElementType();
+ auto accPadValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(resultElementType));
+
+ auto lhs = vector::createReadOrMaskedRead(
+ rewriter, loc, mmaOp.getLhs(), mmaOp.getLhsType().getShape(), lhsPadValue,
+ /*useInBoundsInsteadOfMasking=*/true);
+ auto rhs = vector::createReadOrMaskedRead(
+ rewriter, loc, mmaOp.getRhs(), mmaOp.getRhsType().getShape(), rhsPadValue,
+ /*useInBoundsInsteadOfMasking=*/true);
+ auto acc = vector::createReadOrMaskedRead(
+ rewriter, loc, mmaOp.getAcc(), mmaOp.getAccType().getShape(), accPadValue,
+ /*useInBoundsInsteadOfMasking=*/true);
+ auto newMmaOp = rewriter.create<IREE::GPU::MultiMmaOp>(
+ loc, lhs, rhs, acc, mmaOp.getIndexingMaps(), mmaOp.getIteratorTypes(),
+ mmaOp.getKind());
+
+ // Create the write back to a tensor.
+ int64_t rank = mmaOp.getResultType().getRank();
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ mmaOp,
+ /*vector=*/newMmaOp,
+ /*source=*/mmaOp.getAcc(),
+ /*indices=*/SmallVector<Value>(rank, zero),
+ /*inBounds=*/SmallVector<bool>(rank, true));
+ return success();
+}
+
+namespace {
+struct VectorizeStaticMultiMmaOpPattern final
+ : OpRewritePattern<IREE::GPU::MultiMmaOp> {
+ using OpRewritePattern<IREE::GPU::MultiMmaOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(IREE::GPU::MultiMmaOp mmaOp,
+ PatternRewriter &rewriter) const override {
+ return vectorizeStaticMultiMmaOp(rewriter, mmaOp);
+ }
+};
+} // namespace
+
+void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns) {
+ patterns.add<VectorizeStaticMultiMmaOpPattern>(patterns.getContext());
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
new file mode 100644
index 0000000..273f23b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -0,0 +1,25 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- Transforms.h - Transformations for the IREE GPU dialect ------------===//
+//
+// Defines transformations that apply to IREE GPU ops for use in multiple
+// places.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_
+#define IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir::iree_compiler::IREE::GPU
+
+#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
index c263114..d1af546 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
@@ -52,6 +52,7 @@
"@llvm-project//mlir:LinalgTransforms",
# TransformExtensions
"//compiler/src/iree/compiler/Codegen/Common/TransformExtensions:CommonExtensions",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions:IREEGPUExtensions",
"//compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions:LLVMCPUExtensions",
"//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions",
"//compiler/src/iree/compiler/Dialect/Flow/TransformExtensions:FlowExtensions",
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index 646c497..4367260 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -45,6 +45,7 @@
MLIRVectorTransformOps
MLIRVectorTransforms
iree::compiler::Codegen::Common::TransformExtensions::CommonExtensions
+ iree::compiler::Codegen::Dialect::GPU::TransformExtensions::IREEGPUExtensions
iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
iree::compiler::Dialect::Flow::TransformExtensions::FlowExtensions
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
index 709afd7..e3459ae 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Codegen/Interfaces/Interfaces.h"
+#include "iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
@@ -51,6 +52,7 @@
transform_ext::StructuredTransformOpsExtension>();
registerPartitionableLoopsInterfaceModels(registry);
registerTransformDialectCommonExtension(registry);
+ registerTransformDialectIREEGPUExtension(registry);
registerTransformDialectFlowExtension(registry);
registerTransformDialectLLVMCPUExtension(registry);
registerTransformDialectLLVMGPUExtension(registry);