Add AMDGPU pattern for chained matmuls (#16398)
This PR adds a pattern for chained matmuls operators like Flash
Attention where by swapping the operands of the sequential matmuls, one
can keep the data in registers and avoid a trip to shared memory and/or
additional shuffle instructions.
The pattern swaps the operands of the contract op, and inserts
transposes for the accumulator and result, if the contracts are chained
and satisfy the MMT
indexing maps.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUChainedMatmulPass.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUChainedMatmulPass.cpp
new file mode 100644
index 0000000..bfa778d
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUChainedMatmulPass.cpp
@@ -0,0 +1,111 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+struct AMDGPUPrepareForChainedMatmulPass
+ : public AMDGPUPrepareForChainedMatmulBase<
+ AMDGPUPrepareForChainedMatmulPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
+
+ /// A chained matmul is one where the result of the first matmul
+ /// is used as the first operand of another matmul (
+ /// first matmul lies in the backward slice of the
+ /// LHS of the second matmul).
+ bool
+ isChainedMatmul(SmallVector<vector::ContractionOp> &chainedMatmuls) const {
+ SetVector<Operation *> backwardSlice;
+ getBackwardSlice(chainedMatmuls[1].getLhs(), &backwardSlice);
+ for (auto *sliceOp : backwardSlice) {
+ auto candidateContract = dyn_cast<vector::ContractionOp>(sliceOp);
+ if (!candidateContract)
+ continue;
+ if (candidateContract == chainedMatmuls[0])
+ return true;
+ }
+ return false;
+ }
+
+ /// Given a vector contract of the form
+ /// %output = vector.contract %lhs, %rhs, %acc
+ /// this function swaps the operands (%rhs, %lhs),
+ /// transposes the accumulator and output and updates
+ /// the indexing maps for the new contract op.
+ void swapOperandsAndTranspose(RewriterBase &rewriter,
+ vector::ContractionOp contractOp) const {
+ Value lhs = contractOp.getLhs();
+ Value rhs = contractOp.getRhs();
+ Value acc = contractOp.getAcc();
+ rewriter.setInsertionPoint(contractOp);
+ Value transposed = rewriter.create<vector::TransposeOp>(
+ contractOp.getLoc(), acc, SmallVector<int64_t>{1, 0});
+ AffineExpr m, n, k;
+ bindDims(rewriter.getContext(), m, n, k);
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+ SmallVector<AffineMap> newIndexingMaps = infer({{n, k}, {m, k}, {n, m}});
+ vector::ContractionOp swappedOp = rewriter.create<vector::ContractionOp>(
+ contractOp.getLoc(), rhs, lhs, transposed,
+ rewriter.getAffineMapArrayAttr(newIndexingMaps),
+ contractOp.getIteratorTypesAttr());
+ Value newResult = swappedOp.getResult();
+ transposed = rewriter.create<vector::TransposeOp>(
+ contractOp.getLoc(), newResult, SmallVector<int64_t>{1, 0});
+ rewriter.replaceAllUsesWith(contractOp.getResult(), transposed);
+ }
+
+ /// The only compatible indexing map corresponds to
+ /// the matmul_transpose_b, and is
+ /// (m, n, k) -> (m, k)
+ /// (m, n, k) -> (n, k)
+ /// (m, n, k) -> (m, n)
+ bool isCompatibleIndexingMap(vector::ContractionOp contractOp,
+ MLIRContext *ctx) {
+ AffineExpr m, n, k;
+ bindDims(ctx, m, n, k);
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+ SmallVector<AffineMap> newIndexingMaps = infer({{m, k}, {n, k}, {m, n}});
+ return newIndexingMaps == contractOp.getIndexingMapsArray();
+ }
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ SmallVector<vector::ContractionOp> chainedMatmuls;
+ funcOp.walk([&](vector::ContractionOp contractOp) {
+ if (!isCompatibleIndexingMap(contractOp, funcOp.getContext()))
+ return WalkResult::advance();
+ chainedMatmuls.push_back(contractOp);
+ return WalkResult::advance();
+ });
+ if (chainedMatmuls.size() != 2)
+ return;
+ if (!isChainedMatmul(chainedMatmuls))
+ return;
+ IRRewriter rewriter(funcOp.getContext());
+ for (vector::ContractionOp op : chainedMatmuls) {
+ swapOperandsAndTranspose(rewriter, op);
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createAMDGPUPrepareForChainedMatmulPass() {
+ return std::make_unique<AMDGPUPrepareForChainedMatmulPass>();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index ff7304c..b34125b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -87,6 +87,7 @@
iree_compiler_cc_library(
name = "LLVMGPU",
srcs = [
+ "AMDGPUChainedMatmulPass.cpp",
"ConvertToLLVM.cpp",
"ConvertToNVVM.cpp",
"ConvertToROCDL.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 783b940..58f884e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -72,6 +72,7 @@
"ROCDLKernelConfig.h"
"ROCDLPasses.h"
SRCS
+ "AMDGPUChainedMatmulPass.cpp"
"ConvertToLLVM.cpp"
"ConvertToNVVM.cpp"
"ConvertToROCDL.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index a12bc99..40dce21 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -130,6 +130,30 @@
IREE::Codegen::TranslationInfoAttr translationInfo,
ArrayRef<int64_t> workgroupSize);
+/// Given a chain of matmuls with some or no operations
+/// in between, like
+///
+/// d = matmul_transpose_b(a, b) + c
+/// ...
+/// e = matmul_transpose_b(d, f) + g
+///
+/// this pattern transforms the above IR to
+///
+/// c.t = transpose c
+/// d = matmul_transpose_b(b, a) + c.t
+/// d.t = transpose d
+/// ...
+/// g.t = transpose g
+/// e = matmul_transpose_b(f, d.t) + g.t
+/// e.t = transpose e
+///
+/// On CDNA architectures, where the layouts of the RHS and result
+/// are the same and transposed from the LHS layout, this type
+/// of transformation can avoid trips to shared memory/shuffle instructions
+/// on operators like Flash Attention.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createAMDGPUPrepareForChainedMatmulPass();
+
//----------------------------------------------------------------------------//
// Register LLVMGPU Passes
//----------------------------------------------------------------------------//
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index c96271b..2b3dc7c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -13,6 +13,12 @@
// LLVMGPU Passes (keep alphabetical)
//------------------------------------------------------------------------------
+def AMDGPUPrepareForChainedMatmul :
+ InterfacePass<"iree-amdgpu-prepare-chained-matmul", "mlir::FunctionOpInterface"> {
+ let summary = "Pass to swap operands and transpose accumulator and result";
+ let constructor = "mlir::iree_compiler::createAMDGPUPrepareForChainedMatmulPass()";
+}
+
// TODO: Bring the argument in line with the names used elsewhere.
def ConvertToNVVM :
Pass<"iree-convert-to-nvvm", "ModuleOp"> {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 77ec9c3..69b56cb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -18,6 +18,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "amdgpu_chained_matmul.mlir",
"amdgpu_contraction_distribution.mlir",
"attention.mlir",
"conv_pipeline_test.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 9d7e59b..c6e9113 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "amdgpu_chained_matmul.mlir"
"amdgpu_contraction_distribution.mlir"
"attention.mlir"
"cast_address_space_function.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_chained_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_chained_matmul.mlir
new file mode 100644
index 0000000..3735cf2
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_chained_matmul.mlir
@@ -0,0 +1,90 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-amdgpu-prepare-chained-matmul),canonicalize,cse)" %s | FileCheck %s
+
+#accesses0 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+
+#trait0 = {
+ indexing_maps = #accesses0,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+builtin.module {
+ // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
+ func.func @chained_matmul(%lhs : vector<32x8xf16>, %rhs : vector<16x8xf16>, %acc : vector<32x16xf16>,
+ // CHECK: func.func @chained_matmul(%[[LHS:.*]]: vector<32x8xf16>, %[[RHS:.*]]: vector<16x8xf16>, %[[ACC:.*]]: vector<32x16xf16>
+ // CHECK-SAME: %[[RHS2:.*]]: vector<8x16xf16>, %[[ACC2:.*]]: vector<32x8xf16>
+ %rhs2 : vector<8x16xf16>, %acc2 : vector<32x8xf16>) -> vector<32x8xf16> {
+ // CHECK: %[[TRANS_ACC:.*]] = vector.transpose %[[ACC]], [1, 0] : vector<32x16xf16> to vector<16x32xf16>
+ // CHECK: %[[TRANS_RES:.*]] = vector.contract {indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ // CHECK-SAME: %[[RHS]], %[[LHS]], %[[TRANS_ACC]] : vector<16x8xf16>, vector<32x8xf16> into vector<16x32xf16>
+ // CHECK: %[[RES:.*]] = vector.transpose %[[TRANS_RES]], [1, 0] : vector<16x32xf16> to vector<32x16xf16>
+ %result = vector.contract #trait0 %lhs, %rhs, %acc
+ : vector<32x8xf16>, vector<16x8xf16> into vector<32x16xf16>
+ // CHECK: %[[EXP:.*]] = math.exp2 %[[RES]] : vector<32x16xf16>
+ %exp = math.exp2 %result : vector<32x16xf16>
+ // CHECK: %[[TRANS_ACC2:.*]] = vector.transpose %[[ACC2]], [1, 0] : vector<32x8xf16> to vector<8x32xf16>
+ // CHECK: %[[TRANS_RES2:.*]] = vector.contract {indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ // CHECK-SAME: %[[RHS2]], %[[EXP]], %[[TRANS_ACC2]] : vector<8x16xf16>, vector<32x16xf16> into vector<8x32xf16>
+ // CHECK: %[[RES2:.*]] = vector.transpose %[[TRANS_RES2]], [1, 0] : vector<8x32xf16> to vector<32x8xf16>
+ %result2 = vector.contract #trait0 %exp, %rhs2, %acc2
+ : vector<32x16xf16>, vector<8x16xf16> into vector<32x8xf16>
+ func.return %result2 : vector<32x8xf16>
+ }
+}
+
+// -----
+
+#accesses0 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+
+#trait0 = {
+ indexing_maps = #accesses0,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+builtin.module {
+ func.func @non_chained_matmul(%lhs : vector<32x8xf16>, %rhs : vector<16x8xf16>, %acc : vector<32x16xf16>
+ // CHECK: func.func @non_chained_matmul(%[[LHS:.*]]: vector<32x8xf16>, %[[RHS:.*]]: vector<16x8xf16>, %[[ACC:.*]]: vector<32x16xf16>
+ ) -> vector<32x16xf16> {
+ // CHECK-NOT: vector.transpose
+ %result = vector.contract #trait0 %lhs, %rhs, %acc
+ : vector<32x8xf16>, vector<16x8xf16> into vector<32x16xf16>
+ %exp = math.exp2 %result : vector<32x16xf16>
+ func.return %exp : vector<32x16xf16>
+ }
+}
+
+// -----
+
+#accesses0 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+
+#trait0 = {
+ indexing_maps = #accesses0,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+builtin.module {
+ func.func @chained_matmul_second_operand(%lhs : vector<32x8xf16>, %rhs : vector<16x8xf16>, %acc : vector<32x16xf16>,
+ // CHECK: func.func @chained_matmul_second_operand(%[[LHS:.*]]: vector<32x8xf16>, %[[RHS:.*]]: vector<16x8xf16>, %[[ACC:.*]]: vector<32x16xf16>
+ %lhs2 : vector<32x16xf16>, %acc2 : vector<32x32xf16>) -> vector<32x32xf16> {
+ // CHECK-NOT: vector.transpose
+ %result = vector.contract #trait0 %lhs, %rhs, %acc
+ : vector<32x8xf16>, vector<16x8xf16> into vector<32x16xf16>
+ %exp = math.exp2 %result : vector<32x16xf16>
+ %result2 = vector.contract #trait0 %lhs2, %exp, %acc2
+ : vector<32x16xf16>, vector<32x16xf16> into vector<32x32xf16>
+ func.return %result2 : vector<32x32xf16>
+ }
+}