[Preprocessing] Add a one-off pattern to fuse attention with transpose. (#17901)
The attention ops in SDXL models are usually followed by a
`tensor.expand_shape` and a `transpose`. It is more natural to fold
these into the attention op for code generation. This is done as a
one-off pattern for now. Ideally, an attention op could be fused with
any of its elementwise consumers once attention is handled natively by
the backend pass pipelines.
More details are in https://github.com/iree-org/iree/issues/17673.
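For a concrete picture, here is a trimmed sketch of the rewrite using the static shapes from the test added in this patch; attributes, regions, and some types are elided, and the SSA names are placeholders:

```mlir
// Before: the attention result is expanded and then transposed.
%attn = iree_linalg_ext.attention {...}
    ins(%query, %key, %value, %scale : tensor<20x4096x16xf16>,
        tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
    outs(%init : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
%expanded = tensor.expand_shape %attn [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64]
    : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
// Transpose-like linalg.generic with permutation (d0, d1, d2, d3) -> (d0, d2, d1, d3).
%transposed = linalg.generic {...}
    ins(%expanded : tensor<2x10x4096x64xf16>)
    outs(%init2 : tensor<2x4096x10x64xf16>) {...} -> tensor<2x4096x10x64xf16>

// After: the expand_shape moves onto the attention operands, and the transpose
// is absorbed into the result indexing map of the attention op.
%q = tensor.expand_shape %query [[0, 1], [2], [3]] output_shape [2, 10, 4096, 16] : ...
%k = tensor.expand_shape %key [[0, 1], [2], [3]] output_shape [2, 10, 1024, 16] : ...
%v = tensor.expand_shape %value [[0, 1], [2], [3]] output_shape [2, 10, 1024, 64] : ...
%fused = iree_linalg_ext.attention {...}
    ins(%q, %k, %v, %scale : ...)
    outs(%init2 : tensor<2x4096x10x64xf16>) -> tensor<2x4096x10x64xf16>
```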
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index 1692c78..a481888 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -33,6 +33,7 @@
"ApplyPDLPatterns.cpp",
"ConvertConv2DToImg2Col.cpp",
"ConvertConvToChannelsLast.cpp",
+ "FoldAttentionWithTranspose.cpp",
"GeneralizeLinalgMatMul.cpp",
"InterpreterPass.cpp",
"MakeSingleDispatchForFunction.cpp",
@@ -54,6 +55,7 @@
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
+ "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index 4613d4b..1bc4c1e 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -29,6 +29,7 @@
"ApplyPDLPatterns.cpp"
"ConvertConv2DToImg2Col.cpp"
"ConvertConvToChannelsLast.cpp"
+ "FoldAttentionWithTranspose.cpp"
"GeneralizeLinalgMatMul.cpp"
"InterpreterPass.cpp"
"MakeSingleDispatchForFunction.cpp"
@@ -65,6 +66,7 @@
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
PUBLIC
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp b/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
new file mode 100644
index 0000000..2c84abb
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
@@ -0,0 +1,204 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::Preprocessing {
+
+#define GEN_PASS_DEF_FOLDATTENTIONWITHTRANSPOSEPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Attention -> Transpose fusion
+//===----------------------------------------------------------------------===//
+
+/// Pattern to fold
+///
+/// ```mlir
+/// %0 = iree_linalg_ext.attention {
+/// indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+/// affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+/// affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+/// affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+/// ins(%query, %key, %value) ....
+/// %1 = tensor.expand_shape %0 into [[0, 1], [2], [3]] ....
+/// %2 = linalg.generic {
+/// indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+/// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]}
+/// ins(%1)
+/// ```
+///
+/// to
+///
+/// ```
+/// %0 = iree_linalg_ext.attention {
+/// indexing_maps = [affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1, d2)>,
+/// affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d2)>,
+/// affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d4)>,
+/// affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d1, d00, d4)>]}
+/// ins(%expanded_query, %expanded_key, %expanded_value) ....
+/// ```
+///
+/// Do a very specific match for now. Eventually this can be generalized to
+/// use an analysis similar to what reshape propagation across Linalg ops
+/// does. TODO(#17673)
+///
+struct FoldAttentionAndTranspose
+ : public OpRewritePattern<IREE::LinalgExt::AttentionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IREE::LinalgExt::AttentionOp attentionOp,
+ PatternRewriter &rewriter) const override {
+ // Check for single use attention op.
+ if (!attentionOp->hasOneUse()) {
+ return rewriter.notifyMatchFailure(attentionOp,
+ "attention op has multiple uses");
+ }
+ auto expandShapeOp =
+ dyn_cast<tensor::ExpandShapeOp>(*attentionOp->user_begin());
+ if (!expandShapeOp) {
+ return rewriter.notifyMatchFailure(attentionOp,
+ "user is not an expand shape op.");
+ }
+ // Check for single use of expand shape op.
+ if (!expandShapeOp->hasOneUse()) {
+ return rewriter.notifyMatchFailure(attentionOp,
+ "expand shape op has multiple uses");
+ }
+ auto transposeLikeOp =
+ dyn_cast<linalg::LinalgOp>(*expandShapeOp->user_begin());
+ if (!transposeLikeOp) {
+ return failure();
+ }
+ if (!(transposeLikeOp.getNumDpsInputs() == 1 &&
+ transposeLikeOp.getNumDpsInits() == 1 &&
+ transposeLikeOp.getBlock()
+ ->front()
+ .hasTrait<OpTrait::IsTerminator>() &&
+ transposeLikeOp.getNumLoops() ==
+ transposeLikeOp.getNumParallelLoops())) {
+ return rewriter.notifyMatchFailure(
+ transposeLikeOp, "expand shape user is not a transpose");
+ }
+
+ // Check attention op indexing maps.
+ AffineExpr d0, d1, d2, d3, d4, d5;
+ bindDims(rewriter.getContext(), d0, d1, d2, d3, d4, d5);
+ auto getIndexingMap = [&](int n, ArrayRef<AffineExpr> results) {
+ return AffineMap::get(n, 0, results, rewriter.getContext());
+ };
+ SmallVector<AffineMap> expectedMaps = {
+ getIndexingMap(5, {d0, d1, d2}), getIndexingMap(5, {d0, d3, d2}),
+ getIndexingMap(5, {d0, d3, d4}), getIndexingMap(5, {d0, d1, d4})};
+ if (attentionOp.getIndexingMapsArray() != expectedMaps) {
+ return rewriter.notifyMatchFailure(
+ attentionOp, "mismatch between expected maps and maps on attention op");
+ }
+
+ // Check reassociation indexing map.
+ SmallVector<ReassociationIndices> reassociation =
+ expandShapeOp.getReassociationIndices();
+ SmallVector<ReassociationIndices> expectedReassociation = {{0, 1}, {2}, {3}};
+ if (reassociation != expectedReassociation) {
+ return rewriter.notifyMatchFailure(expandShapeOp,
+ "unhandled reassocation");
+ }
+
+ // Check the permutation maps for the transpose.
+ SmallVector<AffineMap> expectedTransposeMaps = {
+ getIndexingMap(4, {d0, d1, d2, d3}),
+ getIndexingMap(4, {d0, d2, d1, d3})};
+ if (transposeLikeOp.getIndexingMapsArray() != expectedTransposeMaps) {
+ return rewriter.notifyMatchFailure(transposeLikeOp,
+ "unhandled transpose op");
+ }
+
+ Location loc = attentionOp.getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(transposeLikeOp);
+
+ SmallVector<OpFoldResult> expandedResultShape =
+ tensor::getMixedSizes(rewriter, loc, expandShapeOp);
+ OpFoldResult dim0_split0 = expandedResultShape[0];
+ OpFoldResult dim0_split1 = expandedResultShape[1];
+ OpFoldResult dim1 = expandedResultShape[2];
+ OpFoldResult dim2 =
+ tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 2);
+ OpFoldResult dim3 =
+ tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 1);
+ OpFoldResult dim4 = expandedResultShape[3];
+
+ SmallVector<OpFoldResult> newQuerySizes = {};
+ SmallVector<Value> tmp;
+ SmallVector<int64_t> newQueryShape;
+ dispatchIndexOpFoldResults(newQuerySizes, tmp, newQueryShape);
+
+ auto getReshape = [&](Value v, ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> outputShape) -> Value {
+ SmallVector<int64_t> staticShape;
+ SmallVector<Value> dynamicShape;
+ dispatchIndexOpFoldResults(outputShape, dynamicShape, staticShape);
+ Type resultType = RankedTensorType::get(
+ staticShape, cast<RankedTensorType>(v.getType()).getElementType());
+ return rewriter
+ .create<tensor::ExpandShapeOp>(loc, resultType, v, reassociation,
+ outputShape)
+ .getResult();
+ };
+
+ Value expandedQuery = getReshape(attentionOp.getQuery(), {{0, 1}, {2}, {3}},
+ {dim0_split0, dim0_split1, dim1, dim2});
+ Value expandedKey = getReshape(attentionOp.getKey(), {{0, 1}, {2}, {3}},
+ {dim0_split0, dim0_split1, dim3, dim2});
+ Value expandedValue = getReshape(attentionOp.getValue(), {{0, 1}, {2}, {3}},
+ {dim0_split0, dim0_split1, dim3, dim4});
+ Value expandedInit = transposeLikeOp.getDpsInitOperand(0)->get();
+
+ SmallVector<AffineMap> newIndexingMaps = {
+ getIndexingMap(6, {d0, d1, d2, d3}),
+ getIndexingMap(6, {d0, d1, d4, d3}),
+ getIndexingMap(6, {d0, d1, d4, d5}),
+ getIndexingMap(6, {d0, d2, d1, d5})};
+ ArrayAttr newIndexingMapsAttr =
+ rewriter.getAffineMapArrayAttr(newIndexingMaps);
+ auto newAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
+ attentionOp.getLoc(), expandedInit.getType(), expandedQuery,
+ expandedKey, expandedValue, attentionOp.getScale(), expandedInit,
+ newIndexingMapsAttr);
+ rewriter.replaceOp(transposeLikeOp, newAttentionOp);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Implementation
+//===----------------------------------------------------------------------===//
+
+struct FoldAttentionWithTransposePass
+ : public impl::FoldAttentionWithTransposePassBase<
+ FoldAttentionWithTransposePass> {
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ patterns.insert<FoldAttentionAndTranspose>(context);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index e4921b8..edc1705 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -51,6 +51,11 @@
];
}
+def FoldAttentionWithTransposePass :
+ Pass<"iree-preprocessing-fold-attention-with-transpose", ""> {
+ let summary = "Fold attention operation with transpose";
+}
+
def InterpreterPass : Pass<"iree-preprocessing-transform-interpreter"> {
let summary = "transform dialect interpreter";
let description = [{
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 54ebb11..e484b95 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -18,6 +18,7 @@
[
"conv2d_to_img2col.mlir",
"conv_to_channels_last.mlir",
+ "fold_attention_with_transpose.mlir",
"generalize_linalg_matmul.mlir",
"make_single_dispatch_for_function.mlir",
"pad_linalg_ops.mlir",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index 03c92b7..2c105a2 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -16,6 +16,7 @@
SRCS
"conv2d_to_img2col.mlir"
"conv_to_channels_last.mlir"
+ "fold_attention_with_transpose.mlir"
"generalize_linalg_matmul.mlir"
"make_single_dispatch_for_function.mlir"
"pad_linalg_ops.mlir"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
new file mode 100644
index 0000000..cbb6ebd
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
@@ -0,0 +1,106 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-fold-attention-with-transpose, resolve-shaped-type-result-dims))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
+
+util.func public @fuse_attention_expand_transpose(
+ %arg0: tensor<?x?x?xf16>, %arg1 : tensor<?x?x?xf16>, %arg2 : tensor<?x?x?xf16>, %arg3 : f16) -> tensor<2x?x?x?xf16> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16>
+ %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16>
+ %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
+ %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
+ %empty = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
+ %attention = iree_linalg_ext.attention {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+ ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
+ outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+ %split = arith.divsi %d0, %c2 : index
+ %expanded = tensor.expand_shape %attention [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4]
+ : tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
+ %empty2 = tensor.empty(%d1, %split, %d4) : tensor<2x?x?x?xf16>
+ %transpose = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%expanded : tensor<2x?x?x?xf16>) outs(%empty2 : tensor<2x?x?x?xf16>) {
+ ^bb0(%b0 : f16, %b1 : f16):
+ linalg.yield %b0 : f16
+ } -> tensor<2x?x?x?xf16>
+ util.return %transpose : tensor<2x?x?x?xf16>
+}
+// CHECK-LABEL: func public @fuse_attention_expand_transpose(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME: %[[ARG3:.+]]: f16)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
+// CHECK-DAG: %[[D_SPLIT:.+]] = arith.divsi %[[D0]], %[[C2]]
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D1]], %[[D_SPLIT]], %[[D4]]) : tensor<2x?x?x?xf16>
+// CHECK-DAG: %[[D_SPLIT2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D1]], %[[D2]]{{\]}}
+// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D2]]{{\]}}
+// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D4]]{{\]}}
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
+// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: util.return %[[ATTENTION]]
+
+// -----
+
+util.func public @fuse_attention_expand_transpose_static(
+ %arg0 : tensor<20x4096x16xf16>, %arg1 : tensor<20x1024x16xf16>,
+ %arg2 : tensor<20x1024x64xf16>, %arg3 : f16) -> tensor<2x4096x10x64xf16> {
+ %empty = tensor.empty() : tensor<20x4096x64xf16>
+ %attention = iree_linalg_ext.attention {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+ ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
+ outs(%empty: tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+ %expanded = tensor.expand_shape %attention [[0, 1], [2], [3]]
+ output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
+ %empty2 = tensor.empty() : tensor<2x4096x10x64xf16>
+ %transpose = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%expanded : tensor<2x10x4096x64xf16>) outs(%empty2 : tensor<2x4096x10x64xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ linalg.yield %in : f16
+ } -> tensor<2x4096x10x64xf16>
+ util.return %transpose : tensor<2x4096x10x64xf16>
+}
+// CHECK-LABEL: func public @fuse_attention_expand_transpose_static(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
+// CHECK-SAME: %[[ARG3:.+]]: f16)
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
+// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
+// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
+// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
+// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: util.return %[[ATTENTION]]