[Preprocessing] Add a one-off pattern to fuse attention with transpose. (#17901)

The attention ops in SDXL models are usually followed by a
`tensor.expand_shape` and a `transpose`. It is more natural to fold
these in with the attention for codegeneration. This is done as a
one-off pattern for now. Ideally the attention ops can be fused with any
of its elementwise consumers when attention is handled natively by the
backend pass-pipelines.
More details are in https://github.com/iree-org/iree/issues/17673.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index 1692c78..a481888 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -33,6 +33,7 @@
         "ApplyPDLPatterns.cpp",
         "ConvertConv2DToImg2Col.cpp",
         "ConvertConvToChannelsLast.cpp",
+        "FoldAttentionWithTranspose.cpp",
         "GeneralizeLinalgMatMul.cpp",
         "InterpreterPass.cpp",
         "MakeSingleDispatchForFunction.cpp",
@@ -54,6 +55,7 @@
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
+        "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "//compiler/src/iree/compiler/Dialect/Stream/IR",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index 4613d4b..1bc4c1e 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -29,6 +29,7 @@
     "ApplyPDLPatterns.cpp"
     "ConvertConv2DToImg2Col.cpp"
     "ConvertConvToChannelsLast.cpp"
+    "FoldAttentionWithTranspose.cpp"
     "GeneralizeLinalgMatMul.cpp"
     "InterpreterPass.cpp"
     "MakeSingleDispatchForFunction.cpp"
@@ -65,6 +66,7 @@
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::Flow::Transforms
     iree::compiler::Dialect::HAL::IR
+    iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::Stream::IR
     iree::compiler::Dialect::Util::IR
   PUBLIC
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp b/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
new file mode 100644
index 0000000..2c84abb
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
@@ -0,0 +1,204 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::Preprocessing {
+
+#define GEN_PASS_DEF_FOLDATTENTIONWITHTRANSPOSEPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Attention -> Transpose fusion
+//===----------------------------------------------------------------------===//
+
+/// Pattern to fold
+///
+/// ```mlir
+/// %0 = iree_linalg_ext.attention {
+///     indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+///     ins(%query, %key, %value) ....
+/// %1 = tensor.expand_shape %0 into [[0, 1], [2], [3]] ....
+/// %2 = linalg.generic {
+///     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+///                      affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]}
+///     ins(%1)
+/// ```
+///
+/// to
+///
+/// ```
+/// %0 = iree_linalg_ext.attention {
+///     indexing_maps = [affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1,
+///     d2)>,
+///                      affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3,
+///                      d2)>, affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00,
+///                      d3, d4)>, affine_map<(d0, d00, d1, d2, d3, d4) -> (d0,
+///                      d1, d00, d4)>]}
+///     ins(%expanded_query, %expanded_key, %expanded_value) ....
+/// ```
+///
+///  Do a very specific match for now. Eventually this can be generalized to a
+///  use similar analysis as to what the reshape propagation across Linalg op
+///  does. TODO(#17673)
+///
+struct FoldAttentionAndTranspose
+    : public OpRewritePattern<IREE::LinalgExt::AttentionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::LinalgExt::AttentionOp attentionOp,
+                                PatternRewriter &rewriter) const override {
+    // Check for single use attention op.
+    if (!attentionOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(attentionOp,
+                                         "attention op has multiple uses");
+    }
+    auto expandShapeOp =
+        dyn_cast<tensor::ExpandShapeOp>(*attentionOp->user_begin());
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(attentionOp,
+                                         "user is not an expand shape op.");
+    }
+    // Check for single use of expand shape op.
+    if (!expandShapeOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(attentionOp,
+                                         "expand shape op has multiple uses");
+    }
+    auto transposeLikeOp =
+        dyn_cast<linalg::LinalgOp>(*expandShapeOp->user_begin());
+    if (!transposeLikeOp) {
+      return failure();
+    }
+    if (!(transposeLikeOp.getNumDpsInputs() == 1 &&
+          transposeLikeOp.getNumDpsInits() == 1 &&
+          transposeLikeOp.getBlock()
+              ->front()
+              .hasTrait<OpTrait::IsTerminator>() &&
+          transposeLikeOp.getNumLoops() ==
+              transposeLikeOp.getNumParallelLoops())) {
+      return rewriter.notifyMatchFailure(
+          transposeLikeOp, "expand shape user is not a transpose");
+    }
+
+    // Check attention op indexing maps.
+    AffineExpr d0, d1, d2, d3, d4, d5;
+    bindDims(rewriter.getContext(), d0, d1, d2, d3, d4, d5);
+    auto getIndexingMap = [&](int n, ArrayRef<AffineExpr> results) {
+      return AffineMap::get(n, 0, results, rewriter.getContext());
+    };
+    SmallVector<AffineMap> expectedMaps = {
+        getIndexingMap(5, {d0, d1, d2}), getIndexingMap(5, {d0, d3, d2}),
+        getIndexingMap(5, {d0, d3, d4}), getIndexingMap(5, {d0, d1, d4})};
+    if (attentionOp.getIndexingMapsArray() != expectedMaps) {
+      return rewriter.notifyMatchFailure(
+          attentionOp, "mismatch in expected maps, and maps on attention op");
+    }
+
+    // Check reassociation indexing map.
+    SmallVector<ReassociationIndices> reassociation =
+        expandShapeOp.getReassociationIndices();
+    SmallVector<ReassociationIndices> expectedReassocation = {{0, 1}, {2}, {3}};
+    if (reassociation != expectedReassocation) {
+      return rewriter.notifyMatchFailure(expandShapeOp,
+                                         "unhandled reassocation");
+    }
+
+    // Check the permutation maps for the transpose.
+    SmallVector<AffineMap> expectedTransposeMaps = {
+        getIndexingMap(4, {d0, d1, d2, d3}),
+        getIndexingMap(4, {d0, d2, d1, d3})};
+    if (transposeLikeOp.getIndexingMapsArray() != expectedTransposeMaps) {
+      return rewriter.notifyMatchFailure(transposeLikeOp,
+                                         "unhandled transpose op");
+    }
+
+    Location loc = attentionOp.getLoc();
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(transposeLikeOp);
+
+    SmallVector<OpFoldResult> expandedResultShape =
+        tensor::getMixedSizes(rewriter, loc, expandShapeOp);
+    OpFoldResult dim0_split0 = expandedResultShape[0];
+    OpFoldResult dim0_split1 = expandedResultShape[1];
+    OpFoldResult dim1 = expandedResultShape[2];
+    OpFoldResult dim2 =
+        tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 2);
+    OpFoldResult dim3 =
+        tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 1);
+    OpFoldResult dim4 = expandedResultShape[3];
+
+    SmallVector<OpFoldResult> newQuerySizes = {};
+    SmallVector<Value> tmp;
+    SmallVector<int64_t> newQueryShape;
+    dispatchIndexOpFoldResults(newQuerySizes, tmp, newQueryShape);
+
+    auto getReshape = [&](Value v, ArrayRef<ReassociationIndices> reassociation,
+                          ArrayRef<OpFoldResult> outputShape) -> Value {
+      SmallVector<int64_t> staticShape;
+      SmallVector<Value> dynamicShape;
+      dispatchIndexOpFoldResults(outputShape, dynamicShape, staticShape);
+      Type resultType = RankedTensorType::get(
+          staticShape, cast<RankedTensorType>(v.getType()).getElementType());
+      return rewriter
+          .create<tensor::ExpandShapeOp>(loc, resultType, v, reassociation,
+                                         outputShape)
+          .getResult();
+    };
+
+    Value expandedQuery = getReshape(attentionOp.getQuery(), {{0, 1}, {2}, {3}},
+                                     {dim0_split0, dim0_split1, dim1, dim2});
+    Value expandedKey = getReshape(attentionOp.getKey(), {{0, 1}, {2}, {3}},
+                                   {dim0_split0, dim0_split1, dim3, dim2});
+    Value expandedValue = getReshape(attentionOp.getValue(), {{0, 1}, {2}, {3}},
+                                     {dim0_split0, dim0_split1, dim3, dim4});
+    Value expandedInit = transposeLikeOp.getDpsInitOperand(0)->get();
+
+    SmallVector<AffineMap> newIndexingMaps = {
+        getIndexingMap(6, {d0, d1, d2, d3}),
+        getIndexingMap(6, {d0, d1, d4, d3}),
+        getIndexingMap(6, {d0, d1, d4, d5}),
+        getIndexingMap(6, {d0, d2, d1, d5})};
+    ArrayAttr newIndexingMapsAttr =
+        rewriter.getAffineMapArrayAttr(newIndexingMaps);
+    auto newAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
+        attentionOp.getLoc(), expandedInit.getType(), expandedQuery,
+        expandedKey, expandedValue, attentionOp.getScale(), expandedInit,
+        newIndexingMapsAttr);
+    rewriter.replaceOp(transposeLikeOp, newAttentionOp);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Implementation
+//===----------------------------------------------------------------------===//
+
+struct FoldAttentionWithTransposePass
+    : public impl::FoldAttentionWithTransposePassBase<
+          FoldAttentionWithTransposePass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.insert<FoldAttentionAndTranspose>(context);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index e4921b8..edc1705 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -51,6 +51,11 @@
   ];
 }
 
+def FoldAttentionWithTransposePass :
+    Pass<"iree-preprocessing-fold-attention-with-transpose", ""> {
+  let summary = "Fold attention operation with transpose";
+}
+
 def InterpreterPass : Pass<"iree-preprocessing-transform-interpreter"> {
   let summary = "transform dialect interpreter";
   let description = [{
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 54ebb11..e484b95 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -18,6 +18,7 @@
         [
             "conv2d_to_img2col.mlir",
             "conv_to_channels_last.mlir",
+            "fold_attention_with_transpose.mlir",
             "generalize_linalg_matmul.mlir",
             "make_single_dispatch_for_function.mlir",
             "pad_linalg_ops.mlir",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index 03c92b7..2c105a2 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -16,6 +16,7 @@
   SRCS
     "conv2d_to_img2col.mlir"
     "conv_to_channels_last.mlir"
+    "fold_attention_with_transpose.mlir"
     "generalize_linalg_matmul.mlir"
     "make_single_dispatch_for_function.mlir"
     "pad_linalg_ops.mlir"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
new file mode 100644
index 0000000..cbb6ebd
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
@@ -0,0 +1,106 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-fold-attention-with-transpose, resolve-shaped-type-result-dims))" --split-input-file --mlir-print-local-scope  %s | FileCheck %s
+
+util.func public @fuse_attention_expand_transpose(
+  %arg0: tensor<?x?x?xf16>, %arg1 : tensor<?x?x?xf16>, %arg2 : tensor<?x?x?xf16>, %arg3 : f16) -> tensor<2x?x?x?xf16> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16>
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16>
+  %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
+  %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
+  %empty = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
+  %attention = iree_linalg_ext.attention {
+    indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+    ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
+    outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+  %split = arith.divsi %d0, %c2 : index
+  %expanded = tensor.expand_shape %attention [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4]
+      : tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
+  %empty2 = tensor.empty(%d1, %split, %d4) : tensor<2x?x?x?xf16>
+  %transpose = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%expanded : tensor<2x?x?x?xf16>) outs(%empty2 : tensor<2x?x?x?xf16>) {
+    ^bb0(%b0 : f16, %b1 : f16):
+      linalg.yield %b0 : f16
+  } -> tensor<2x?x?x?xf16>
+  util.return %transpose : tensor<2x?x?x?xf16>
+}
+// CHECK-LABEL: func public @fuse_attention_expand_transpose(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+//  CHECK-SAME:     %[[ARG3:.+]]: f16)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
+//   CHECK-DAG:   %[[D_SPLIT:.+]] = arith.divsi %[[D0]], %[[C2]]
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[D1]], %[[D_SPLIT]], %[[D4]]) : tensor<2x?x?x?xf16>
+//   CHECK-DAG:   %[[D_SPLIT2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
+//   CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+//   CHECK-DAG:   %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D1]], %[[D2]]{{\]}}
+//   CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D2]]{{\]}}
+//   CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D4]]{{\]}}
+//       CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:           [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
+//  CHECK-SAME:       ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+//  CHECK-SAME:       outs(%[[EMPTY]] :
+//       CHECK:   util.return %[[ATTENTION]]
+
+// -----
+
+util.func public @fuse_attention_expand_transpose_static(
+      %arg0 : tensor<20x4096x16xf16>, %arg1 : tensor<20x1024x16xf16>,
+      %arg2 : tensor<20x1024x64xf16>, %arg3 : f16) -> tensor<2x4096x10x64xf16> {
+  %empty = tensor.empty() : tensor<20x4096x64xf16>
+  %attention = iree_linalg_ext.attention {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+      ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
+      outs(%empty: tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+  %expanded = tensor.expand_shape %attention [[0, 1], [2], [3]]
+      output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
+  %empty2 = tensor.empty() : tensor<2x4096x10x64xf16>
+  %transpose = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%expanded : tensor<2x10x4096x64xf16>) outs(%empty2 : tensor<2x4096x10x64xf16>) {
+    ^bb0(%in: f16, %out: f16):
+      linalg.yield %in : f16
+  } -> tensor<2x4096x10x64xf16>
+  util.return %transpose : tensor<2x4096x10x64xf16>
+}
+// CHECK-LABEL: func public @fuse_attention_expand_transpose_static(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
+//  CHECK-SAME:     %[[ARG3:.+]]: f16)
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
+//   CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
+//   CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
+//   CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
+//       CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:           [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
+//  CHECK-SAME:       ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+//  CHECK-SAME:       outs(%[[EMPTY]] :
+//       CHECK:   util.return %[[ATTENTION]]