[LinalgExt] Masked Attention Implementation (#18525)

Enables float/boolean mask as parameters and created linalg generic ops
to apply masking. This image (https://imgur.com/a/1MePgcy) elaborates on
the main files changed and how they enable masked attention:
- Blue boxes represent changed .cpp and .td files to
enable/pass/decompose the mask
  -  Yellow boxes represent the different op classes
- Red boxes represent test mlir files pertaining to certain .cpp/.td
implementations or ops
  
For quick reference, AggregateOpInterfaceImpl.cpp contains the bulk of
the actual mask decomposition (QK += mask)

And for clarification, TileAttention.cpp only holds the
convertToOnlineAttentionOp and getTileAttentionIndexingMaps functions;
TilingInterfaceImpl.cpp contains the main tiling capabilities in the
form of AttentionOp::getTiledImplementation and
OnlineAttentionOp::getTiledImplementation.
  

Updated version of https://github.com/iree-org/iree/pull/18461. This
version was created to include scale affine map and enable fused
attention (incorporated
https://github.com/IanWood1/iree/tree/raikonen/sdpa_mask).
- To that end, many modifications in tests are for adding the scale
affine map (without much functionality change)
- For tiling and decomposition tests, most functionality tests are
included in "tiling.mlir" and "decompose_online_attention.mlir". On the
other hand, the "tile_attention.mlir and "decompose_attention.mlir" are
old paths intended to be be retired and deprecate soon. Hence, no major
tests were added it there.

Test directory for numerical verification:
https://github.com/rohan-tan-bhowmik/iree-masked-attention-test

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index d775c1c..b27193d 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -73,21 +73,22 @@
 };
 } // namespace
 
-static SmallVector<AffineMap>
-getStandardAttentionIndexingMaps(MLIRContext *ctx) {
+static SmallVector<AffineMap> getStandardAttentionIndexingMaps(MLIRContext *ctx,
+                                                               bool hasMask) {
   AffineExpr m, n, k1, k2;
   bindDims(ctx, m, n, k1, k2);
 
-  AffineMap qMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
-  AffineMap kMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
-  AffineMap vMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
-  AffineMap rMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
-
-  return {qMap, kMap, vMap, rMap};
+  auto qMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
+  auto kMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
+  auto vMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
+  auto sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, ctx);
+  auto rMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
+  if (hasMask) {
+    // Add mask map only if it exists
+    auto mMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k2}, ctx);
+    return {qMap, kMap, vMap, sMap, mMap, rMap};
+  }
+  return {qMap, kMap, vMap, sMap, rMap};
 }
 
 struct AttentionOpConversion
@@ -100,6 +101,7 @@
     Value query = op.getQuery();
     Value key = op.getKey();
     Value value = op.getValue();
+    std::optional<Value> optionalMask = op.getAttnMask();
 
     ShapedType outputType = op.getOutputType();
 
@@ -147,10 +149,14 @@
         loc, targetType, rewriter.getFloatAttr(targetType, dk));
 
     // Add batches to standard attention indexing maps.
-    SmallVector<AffineMap> indexingMaps = getStandardAttentionIndexingMaps(ctx);
+    SmallVector<AffineMap> indexingMaps =
+        getStandardAttentionIndexingMaps(ctx, optionalMask.has_value());
+
     int64_t numBatches = op.getQueryType().getRank() - 2;
     for (AffineMap &map : indexingMaps) {
       map = map.shiftDims(numBatches);
+      if (map.getNumResults() == 0)
+        continue;
       for (int batch : llvm::seq<int>(numBatches)) {
         map = map.insertResult(rewriter.getAffineDimExpr(batch), batch);
       }
@@ -158,7 +164,7 @@
 
     auto attention = rewriter.create<IREE::LinalgExt::AttentionOp>(
         loc, result.getType(), query, key, value, scale, result,
-        rewriter.getAffineMapArrayAttr(indexingMaps));
+        rewriter.getAffineMapArrayAttr(indexingMaps), optionalMask);
 
     rewriter.replaceOp(op, attention.getResult(0));
     return success();
diff --git a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
index 06e85e7..0de566e 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
@@ -8,14 +8,15 @@
 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
 // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
 // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
 
 // CHECK-LABEL:         func.func @attention(
 // CHECK-SAME:         %[[ARG0:.*]]: tensor<5x2x3x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
-// CHECK:         %arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
+// CHECK-SAME:         %[[ARG3:.*]]: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
 // CHECK:         %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x3x4xf32>
-// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
+// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
 // CHECK:         return %[[ATTN]] : tensor<5x2x3x4xf32>
 
 // -----
@@ -27,14 +28,15 @@
 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
 // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
 // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
 
 // CHECK-LABEL:         func.func @attention(
 // CHECK-SAME:         %[[ARG0:.*]]: tensor<5x2x8x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
-// CHECK:         %arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
+// CHECK-SAME:         %[[ARG3:.*]]: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
 // CHECK:         %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x8x4xf32>
-// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
+// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
 // CHECK:         return %[[ATTN]] : tensor<5x2x8x4xf32>
 
 // -----
@@ -46,14 +48,15 @@
 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
 // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
 // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 
 // CHECK-LABEL:         func.func @attention(
 // CHECK-SAME:         %[[ARG0:.*]]: tensor<1x3x4xf32>, %[[ARG1:.*]]: tensor<1x3x4xf32>, %[[ARG2:.*]]: tensor<1x3x4xf32>,
-// CHECK:         %arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
+// CHECK:         %[[ARG3:.*]]: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
 // CHECK:         %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
-// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
 // CHECK:         return %[[ATTN]] : tensor<1x3x4xf32>
 
 // -----
@@ -65,6 +68,7 @@
 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
 // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
 // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 
 // CHECK-LABEL:         func.func @attention_dyn(
@@ -76,5 +80,5 @@
 // CHECK-DAG:         %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
 // CHECK-DAG:         %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
 // CHECK-DAG:         %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4xf32>
-// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
+// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
 // CHECK:         return %[[ATTN]] : tensor<?x?x4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
index 7598316..4010ae8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
@@ -257,6 +257,7 @@
 #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
 #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
 #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
 #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
 #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
 
@@ -272,7 +273,7 @@
   %scale = arith.constant 1.0 : f16
 
   %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR],
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR],
           lowering_config = #config }
         ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
         outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index e67e7af..6533f23 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -1747,6 +1747,7 @@
   %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+    affine_map<(d0, d1, d2, d3, d4) -> ()>,
     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
     ins(%4, %5, %6, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
     outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index 5ebd5c3..227ae1e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -369,6 +369,7 @@
   %8 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
   flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
@@ -407,6 +408,7 @@
   %8 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d1, d2, d3, d4) -> (d1, d2)>,
                      affine_map<(d1, d2, d3, d4) -> (d3, d2)>,
                      affine_map<(d1, d2, d3, d4) -> (d3, d4)>,
+                     affine_map<(d1, d2, d3, d4) -> ()>,
                      affine_map<(d1, d2, d3, d4) -> (d1, d4)>]}
                      ins(%4, %5, %6, %cst : tensor<1024x512xf16>, tensor<128x512xf16>, tensor<128x512xf16>, f16) outs(%7 : tensor<1024x512xf16>) -> tensor<1024x512xf16>
   flow.dispatch.tensor.store %8, %3, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : tensor<1024x512xf16> -> !flow.dispatch.tensor<writeonly:tensor<1024x512xf16>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index eb8f4f1..6ebf8ab 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -574,6 +574,7 @@
         %8 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
                      lowering_config = #config}
                      ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
@@ -637,7 +638,7 @@
         %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
         %7 = tensor.empty() : tensor<64x4608x24x128xf16>
         %8 = tensor.empty() : tensor<24x64x4608x128xf16>
-        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
+        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
         %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
         ^bb0(%in: f16, %out: f16):
           linalg.yield %in : f16
@@ -692,7 +693,7 @@
         %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
         %7 = tensor.empty() : tensor<64x4608x24x128xf16>
         %8 = tensor.empty() : tensor<24x64x4608x128xf16>
-        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
+        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
         %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
         ^bb0(%in: f16, %out: f16):
           linalg.yield %in : f16
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index 649cbd9..3f2c090 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -22,6 +22,7 @@
   %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                    affine_map<(d0, d1, d2, d3, d4) -> ()>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                     ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
   flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
index 13a84ec..109b107 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
@@ -22,6 +22,7 @@
   %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                    affine_map<(d0, d1, d2, d3, d4) -> ()>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                     ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
   flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index 596a925..9607c9e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -149,11 +149,7 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == $_op);
-        if(opOperand->getOperandNumber() >= $_op.getNumDpsInputs()){
-          return $_op.getIndexingMapsForResults()[opOperand->getOperandNumber() - $_op.getNumDpsInputs()];
-        }else {
-          return $_op.getIndexingMapsForOperands()[opOperand->getOperandNumber()];
-        }
+        return getIndexingMapsArray()[opOperand->getOperandNumber()];
       }]
     >,
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 1805597..e08838b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1200,6 +1200,15 @@
 // AttentionOp
 //===----------------------------------------------------------------------===//
 
+void AttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        TypeRange results, Value query, Value key, Value value,
+                        Value scale, ValueRange outputs, ArrayAttr indexingMaps,
+                        std::optional<Value> mask) {
+  Value maskIn = mask.value_or(Value());
+  build(odsBuilder, odsState, results, query, key, value, scale, maskIn,
+        outputs, indexingMaps);
+}
+
 LogicalResult AttentionOp::verify() {
   AttentionOp attnOp = *this;
 
@@ -1212,6 +1221,9 @@
 
   // Check if indexing maps can represent attention.
   SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
+  if (indexingMaps.size() != getOperation()->getNumOperands()) {
+    return attnOp->emitOpError("expected an indexing map for each operand");
+  }
   FailureOr<AttentionOpDetail> maybeOpInfo =
       AttentionOpDetail::get(indexingMaps);
   if (failed(maybeOpInfo)) {
@@ -1246,8 +1258,8 @@
       }
       if (shape[pos] != valShape[i]) {
         return attnOp->emitError("Shape Mismatch for ")
-               << operandName << ". Expected: " << shape[pos]
-               << " Got: " << valShape[i];
+               << operandName << " at position " << i
+               << ". Expected: " << shape[pos] << " Got: " << valShape[i];
       }
     }
     return success();
@@ -1261,6 +1273,38 @@
     return failure();
   }
 
+  // Additional check case if mask exists
+  if (auto maskMap = getMaskMap()) {
+    if (failed(checkShape("Mask", getMaskType()->getShape(), *maskMap)))
+      return failure();
+  }
+
+  int expectedSymbols = getQueryMap().getNumInputs();
+  auto checkDomain =
+      [&attnOp, &expectedSymbols](StringRef operandName,
+                                  AffineMap indexingMap) -> LogicalResult {
+    if (expectedSymbols != indexingMap.getNumInputs()) {
+      return attnOp->emitError("Mismatched map domain for ")
+             << operandName << ". Expected: " << expectedSymbols
+             << " Got: " << indexingMap.getNumInputs();
+    }
+    return success();
+  };
+
+  if (failed(checkDomain("Query", getQueryMap())) ||
+      failed(checkDomain("Key", getKeyMap())) ||
+      failed(checkDomain("Value", getValueMap())) ||
+      failed(checkDomain("Scale", getScaleMap())) ||
+      failed(checkDomain("Output", getOutputMap()))) {
+    return failure();
+  }
+
+  // Additional check case if mask exists
+  if (auto maskMap = getMaskMap()) {
+    if (failed(checkDomain("Mask", *maskMap)))
+      return failure();
+  }
+
   if (isTiled) {
     // Tiled/Flash attention.
     Type maxElementType = getMaxType()->getElementType();
@@ -1324,20 +1368,29 @@
 
 SmallVector<AffineMap> AttentionOp::getIndexingMapsForOperands() {
   auto maps = getIndexingMapsArray();
-  return SmallVector<AffineMap>(maps.begin(),
-                                maps.begin() + getNumDpsInputs() - 1);
+  maps.resize(getNumDpsInputs());
+  return maps;
 }
 
 SmallVector<AffineMap> AttentionOp::getIndexingMapsForResults() {
   auto maps = getIndexingMapsArray();
-  return SmallVector<AffineMap>(maps.begin() + getNumDpsInputs() - 1,
-                                maps.end());
+  return SmallVector<AffineMap>(maps.begin() + getNumDpsInputs(), maps.end());
 }
 
 //===----------------------------------------------------------------------===//
 // OnlineAttentionOp
 //===----------------------------------------------------------------------===//
 
+void OnlineAttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                              TypeRange results, Value query, Value key,
+                              Value value, Value scale, Value output, Value max,
+                              Value sum, ArrayAttr indexingMaps,
+                              std::optional<Value> mask) {
+  Value maskIn = mask.value_or(Value());
+  build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output,
+        max, sum, indexingMaps);
+}
+
 LogicalResult OnlineAttentionOp::verify() {
   OnlineAttentionOp attnOp = *this;
 
@@ -1389,11 +1442,46 @@
     return failure();
   }
 
+  // Additional check case if mask exists
+  if (auto maskMap = getMaskMap()) {
+    if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap)))
+      return failure();
+  }
+
+  int expectedSymbols = getQueryMap().getNumInputs();
+  auto checkDomain =
+      [&attnOp, &expectedSymbols](StringRef operandName,
+                                  AffineMap indexingMap) -> LogicalResult {
+    if (expectedSymbols != indexingMap.getNumInputs()) {
+      return attnOp->emitError("Mismatched map domain for ")
+             << operandName << ". Expected: " << expectedSymbols
+             << " Got: " << indexingMap.getNumInputs();
+    }
+    return success();
+  };
+
+  if (failed(checkDomain("Query", getQueryMap())) ||
+      failed(checkDomain("Key", getKeyMap())) ||
+      failed(checkDomain("Value", getValueMap())) ||
+      failed(checkDomain("Scale", getScaleMap())) ||
+      failed(checkDomain("Output", getOutputMap())) ||
+      failed(checkDomain("Max", getMaxMap())) ||
+      failed(checkDomain("Sum", getSumMap()))) {
+    return failure();
+  }
+
+  // Additional check case if mask exists
+  if (auto maskMap = getMaskMap()) {
+    if (failed(checkDomain("Mask", *maskMap)))
+      return failure();
+  }
+
   return success();
 }
 
 MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
-  return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3);
+  return MutableOperandRange(*this, /*numInputs=*/getMask() ? 5 : 4,
+                             /*numInits=*/3);
 }
 
 LogicalResult OnlineAttentionOp::reifyResultShapes(
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index c292452..c642fd0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -461,7 +461,7 @@
        "getLoopIteratorTypes",
        "getResultTilePosition",
        "getTiledImplementation",
-       "generateResultTileValue"]>]> {
+       "generateResultTileValue"]>, AttrSizedOperandSegments]> {
   let summary = "Attention operator";
   let description = [{
     Computes the scaled dot product attention function:
@@ -471,6 +471,10 @@
     Here Q, K, V are given tensors and scale is a scalar value specifying
     the scale to use.
 
+    If an additional mask argument M is included, the result of the first matmul is modified according to:
+
+    Q @ K.T += M
+
     For self-attention, all inputs and the result have the same shape BxNxd
     where B is the batch dimension, N is the sequence length and d is head
     dimension. Typically N >>> d. Usually, this operator also performs
@@ -495,6 +499,7 @@
                        AnyShaped:$key,
                        AnyShaped:$value,
                        AnyFloat:$scale,
+                       Optional<AnyShaped>:$mask,
                        Variadic<AnyShaped>:$outputs,
                        AffineMapArrayAttr:$indexing_maps
   );
@@ -505,11 +510,22 @@
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     attr-dict
-    `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
+    `ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)? `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )? `)`
     `outs` `(` $outputs `:` type($outputs) `)`
     (`->` type($results)^)?
   }];
 
+  let builders = [
+    OpBuilder<(ins "TypeRange":$results,
+    "Value":$query,
+    "Value":$key,
+    "Value":$value,
+    "Value":$scale,
+    "ValueRange":$outputs,
+    "ArrayAttr":$indexing_maps,
+    CArg<"std::optional<Value>", "std::nullopt">:$mask)>
+  ];
+
   let extraClassDeclaration = [{
     // Method to implement for specifying output range for
     // DestinationStyleOpInterface
@@ -530,9 +546,18 @@
     AffineMap getValueMap() {
       return cast<AffineMap>(getIndexingMapsArray()[2]);
     }
-    AffineMap getOutputMap() {
+    AffineMap getScaleMap() {
       return cast<AffineMap>(getIndexingMapsArray()[3]);
     }
+    std::optional<AffineMap> getMaskMap() {
+      if (getMask()) {
+        return cast<AffineMap>(getIndexingMapsArray()[4]);
+      }
+      return std::nullopt;
+    }
+    AffineMap getOutputMap() {
+      return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]);
+    }
     int64_t getIterationDomainRank() {
       return getQueryMap().getNumDims();
     }
@@ -545,6 +570,11 @@
     ShapedType getValueType() {
       return cast<ShapedType>(getValue().getType());
     }
+    std::optional<ShapedType> getMaskType() {
+      std::optional<Value> mask = getMask();
+      if (!mask) return std::nullopt;
+      return cast<ShapedType>(mask->getType());
+    }
     FloatType getScaleType() {
       return cast<FloatType>(cast<ShapedType>(getScale().getType()));
     }
@@ -559,12 +589,12 @@
     std::optional<AffineMap> getMaxMap() {
       if (getNumResults() < 2)
         return std::nullopt;
-      return cast<AffineMap>(getIndexingMapsArray()[4]);
+      return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 1]);
     }
     std::optional<AffineMap> getSumMap() {
       if (getNumResults() < 3)
         return std::nullopt;
-      return cast<AffineMap>(getIndexingMapsArray()[5]);
+      return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 2]);
     }
     Value getOutput() {
       return getDpsInitOperand(0)->get();
@@ -599,6 +629,12 @@
     int64_t getValueRank() {
       return getValueType().getRank();
     }
+    std::optional<int64_t> getMaskRank() {
+      std::optional<ShapedType> maskType = getMaskType();
+      if (!maskType)
+        return std::nullopt;
+      return maskType->getRank();
+    }
     int64_t getOutputRank() {
       return getOutputType().getRank();
     }
@@ -650,8 +686,12 @@
     online_attention(Q, K, V, scale, running_max, running_sum)
       = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V
 
+    If an additional mask argument M is included, the result of the first matmul is modified according to:
+
+    Q @ K.T += M
+
     The advantage of this online_normalizer is that it can be tiled along
-    it's reduction dimension, making the online_attention operator:
+    its reduction dimension, making the online_attention operator:
       - Tilable along softmax reduction dimension
       - Associative along softmax reduction dimension
       - Commutative along softmax associative dimension
@@ -666,6 +706,7 @@
                        AnyShaped:$key,
                        AnyShaped:$value,
                        AnyFloat:$scale,
+                       Optional<AnyShaped>:$mask,
                        AnyShaped:$output,
                        AnyShaped:$max,
                        AnyShaped:$sum,
@@ -677,11 +718,25 @@
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     attr-dict
-    `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
+    `ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)?  `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )?`)`
     `outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
     (`->` type($results)^)?
   }];
 
+  let builders = [
+    OpBuilder<(ins "TypeRange":$results,
+    "Value":$query,
+    "Value":$key,
+    "Value":$value,
+    "Value":$scale,
+    "Value":$output,
+    "Value":$max,
+    "Value":$sum,
+    "ArrayAttr":$indexing_maps,
+    CArg<"std::optional<Value>", "std::nullopt">:$mask)>
+  ];
+
+
   let extraClassDeclaration = [{
     // Method to implement for specifying output range for
     // DestinationStyleOpInterface
@@ -698,14 +753,23 @@
     AffineMap getValueMap() {
       return getIndexingMapsArray()[2];
     }
-    AffineMap getOutputMap() {
+    AffineMap getScaleMap() {
       return getIndexingMapsArray()[3];
     }
+    std::optional<AffineMap> getMaskMap() {
+      if (getMask()) {
+        return cast<AffineMap>(getIndexingMapsArray()[4]);
+      }
+      return std::nullopt;
+    }
+    AffineMap getOutputMap() {
+      return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]);
+    }
     AffineMap getMaxMap() {
-      return getIndexingMapsArray()[4];
+      return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 1]);
     }
     AffineMap getSumMap() {
-      return getIndexingMapsArray()[5];
+      return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 2]);
     }
 
     int64_t getIterationDomainRank() {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 2f8d8ef..b3d5152 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -712,6 +712,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+                    affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
                     ins(%query, %key, %value, %scale : tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, f32) outs(%0 : tensor<6x12x20x8xf32>) -> tensor<6x12x20x8xf32>
   return %1 : tensor<6x12x20x8xf32>
@@ -728,6 +729,9 @@
   %1:3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
                        affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> ()>,
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
                        affine_map<(d0, d1, d2, d3) -> (d0)>]}
                        ins(%query, %key, %value, %scale : tensor<20xf32>, tensor<20x8xf32>, tensor<20x8xf32>, f32) outs(%result, %max, %sum : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>
   return %1#0, %1#1, %1#2 : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>
@@ -738,10 +742,11 @@
 func.func @illegal_attention_inputs(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: f32) -> tensor<192x1024x64xf32> {
   %0 = tensor.empty() : tensor<192x1024x64xf32>
   %scale = arith.constant 1.0 : f32
-  // expected-error @+5 {{custom op 'iree_linalg_ext.attention' invalid kind of type specified}}
+  // expected-error @+6 {{custom op 'iree_linalg_ext.attention' invalid kind of type specified}}
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+                    affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
                     ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
@@ -749,12 +754,28 @@
 
 // -----
 
-func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
+func.func @attention_missing_affine_map(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
   %0 = tensor.empty() : tensor<192x1024x64xf32>
   %scale = arith.constant 1.0 : f32
-  // expected-error @below {{'iree_linalg_ext.attention' op failed to verify op's indexing maps}}
+  // expected-error @below {{'iree_linalg_ext.attention' op expected an indexing map for each operand}}
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+                     ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  return %1 : tensor<192x1024x64xf32>
+}
+
+// -----
+
+func.func @attention_affine_map_domain_mismatch(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
+  %0 = tensor.empty() : tensor<192x1024x64xf32>
+  %scale = arith.constant 1.0 : f32
+  // expected-error @below {{Mismatched map domain for Scale. Expected: 5 Got: 4}}
+  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 77ba7aa..0f117cc 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -1087,6 +1087,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
@@ -1095,6 +1096,7 @@
 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 // CHECK:      func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1103,7 +1105,7 @@
 // CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
 // CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]}
+// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
 // CHECK-SAME:                ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
@@ -1118,6 +1120,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
@@ -1125,6 +1128,7 @@
 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 // CHECK:      func.func @cross_attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1133,7 +1137,7 @@
 // CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
 // CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]}
+// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
 // CHECK-SAME:                ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
@@ -1150,6 +1154,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
@@ -1157,6 +1162,7 @@
 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 // CHECK:      func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1165,7 +1171,7 @@
 // CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
 // CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]}
+// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
 // CHECK-SAME:                ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
@@ -1179,6 +1185,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
@@ -1186,6 +1193,7 @@
 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 // CHECK:      func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1193,7 +1201,7 @@
 // CHECK-SAME:   {
 // CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]}
+// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
 // CHECK-SAME:                ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
 // CHECK-SAME:     tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%[[ARG3]] :
 // CHECK-SAME:     tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index e310589..2f2e046 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -100,10 +100,10 @@
         // used by attention's exp2 who's value is always > 0.
         Value mx = builder.create<arith::ConstantOp>(
             loc, builder.getFloatAttr(srcTy, mxDbl));
-        Value clamped = b.create<arith::MinimumFOp>(loc, mx, args[0]);
+        Value clamp = b.create<arith::MinimumFOp>(loc, mx, args[0]);
 
         // Convert scale to the same datatype as input.
-        Value trunc = convertScalarToDtype(b, loc, clamped, dstTy,
+        Value trunc = convertScalarToDtype(b, loc, clamp, dstTy,
                                            /*isUnsignedCast=*/false);
         b.create<linalg::YieldOp>(loc, trunc);
       });
@@ -175,6 +175,53 @@
   return genericOp.getResult(0);
 }
 
+static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
+                       AffineMap maskMap, Value qk, Value mask) {
+
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{qkMap, maskMap});
+  qkMap = compressedMaps[0];
+  maskMap = compressedMaps[1];
+
+  SmallVector<utils::IteratorType> iteratorTypes(qkMap.getNumDims(),
+                                                 utils::IteratorType::parallel);
+
+  Value zero = builder.create<arith::ConstantOp>(
+      loc, builder.getFloatAttr(getElementTypeOrSelf(qk.getType()), 0.0));
+  Value negInf = builder.create<arith::ConstantOp>(
+      loc, builder.getFloatAttr(getElementTypeOrSelf(qk.getType()),
+                                -std::numeric_limits<double>::infinity()));
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, qk.getType(), SmallVector<Value>{mask}, qk,
+      SmallVector<AffineMap>{maskMap, qkMap}, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value qkVal = args[1];
+        Value maskVal = args[0];
+
+        // TODO: Replace bool mask condition once treated as i1 (instead of i8)
+        if (maskVal.getType().isInteger()) {
+          maskVal =
+              b.create<arith::TruncIOp>(loc, builder.getI1Type(), maskVal);
+          maskVal = b.create<arith::SelectOp>(loc, maskVal, zero, negInf);
+        } else {
+          maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(),
+                                         /*isUnsignedCast=*/false);
+          // Scaling to compensate for base-2 softmax
+          Value log2e = b.create<arith::ConstantOp>(
+              loc, b.getFloatAttr(qkVal.getType(), M_LOG2E));
+          maskVal = b.create<arith::MulFOp>(loc, maskVal, log2e);
+        }
+        // Finally, set the returned value to the qk element plus the mask
+        // element (or 0/-infinity if bool mask). We opt for a AddFOp (instead
+        // of a SelectFOp to stay consistent with the additive definition of
+        // attention masking)
+        Value add = b.create<arith::AddFOp>(loc, qkVal, maskVal);
+        b.create<linalg::YieldOp>(loc, add);
+      });
+
+  return genericOp.getResult(0);
+}
+
 // Compute output = exp2(output - input)
 static Value computeSubAndExp2(OpBuilder &builder, Location loc,
                                AffineMap inputMap, AffineMap outputMap,
@@ -240,6 +287,7 @@
   Value query = getQuery();
   Value key = getKey();
   Value value = getValue();
+  std::optional<Value> mask = getMask();
   Value oldAcc = getOutput();
   Value oldMax = getMax();
   Value oldSum = getSum();
@@ -265,6 +313,9 @@
   auto qETy = getElementTypeOrSelf(query.getType());
   auto vETy = getElementTypeOrSelf(value.getType());
 
+  AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(),
+                                      /*symbolCount=*/0, getContext());
+
   // In the original algorithm, the scaling is done after the softmax:
   //        softmax(Q @ K.T * scale) @ V
   //
@@ -275,8 +326,6 @@
   // significantly affect numerics.
   if (qETy.getIntOrFloatBitWidth() > 8) {
     AffineMap qMap = getQueryMap();
-    AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
-                                        /*symbolCount=*/0, getContext());
     query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
                                                    query, scale);
   }
@@ -325,6 +374,11 @@
                                                offset);
   }
 
+  // S += mask
+  if (mask != nullptr) {
+    s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value());
+  }
+
   // TODO: This decomposition should be in a seperate op called
   // "online softmax".
   // ---- Online Softmax ----
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index ef5c21c..a01e08a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -8,6 +8,7 @@
 // The content of this file is adapted from linalg's ElemenwiseOpFusion.cpp and
 // modified to work with LinalgExt ops, specifically `LinalgExt::AttentionOp`.
 
+#include <optional>
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -106,7 +107,7 @@
   SmallVector<AffineExpr> newExprs;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned pos = cast<AffineDimExpr>(expr).getPosition();
-    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
+    auto expandedExprs = llvm::to_vector_of<AffineExpr, 6>(
         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
           return builder.getAffineDimExpr(static_cast<unsigned>(v));
         }));
@@ -187,8 +188,7 @@
                       : collapsingReshapeOp.getReassociationMaps(),
           expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return std::nullopt;
-
-  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
+  auto expandedOpIndexingMaps = llvm::to_vector_of<AffineMap, 6>(
       llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) {
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));
@@ -254,12 +254,18 @@
     }
   }
 
+  Value maskOperand;
+  if (expandedOpOperands.size() > 4) {
+    maskOperand = expandedOpOperands[4];
+  }
+
   // Create a new `AttentionOp` that has the computed operands/indexing maps.
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp = rewriter.create<AttentionOp>(
       attentionOp.getLoc(), resultTypes, expandedOpOperands[0],
       expandedOpOperands[1], expandedOpOperands[2], expandedOpOperands[3],
-      outputs, rewriter.getAffineMapArrayAttr(expandedOpIndexingMaps));
+      outputs, rewriter.getAffineMapArrayAttr(expandedOpIndexingMaps),
+      maskOperand);
 
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index d68a2eb..8089cd0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -161,6 +161,7 @@
       AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
   AffineMap vMap =
       AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
+  AffineMap sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {}, ctx);
   AffineMap rMap =
       AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
   AffineMap maxMap =
@@ -174,7 +175,7 @@
     vMap = AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), vDims, ctx);
   }
 
-  SmallVector<AffineMap> attentionMaps = {qMap, kMap,   vMap,
+  SmallVector<AffineMap> attentionMaps = {qMap, kMap,   vMap,  sMap,
                                           rMap, maxMap, sumMap};
   // Add batches to standard attention indexing maps.
   int64_t numBatches = tiledInputRank - 2;
@@ -417,10 +418,14 @@
   SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
   indexingMaps.push_back(maxMap);
   indexingMaps.push_back(sumMap);
+
+  Value mask = attnOp.getMask() ? attnOp.getMask() : Value();
+
   OnlineAttentionOp onlineAttn = rewriter.create<OnlineAttentionOp>(
       loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()},
       attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(),
-      accFill, maxFill, sumFill, rewriter.getAffineMapArrayAttr(indexingMaps));
+      mask, accFill, maxFill, sumFill,
+      rewriter.getAffineMapArrayAttr(indexingMaps));
   onlineAttn->setDiscardableAttrs(attnOp->getDiscardableAttrDictionary());
   ops.push_back(onlineAttn);
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index afa3578..016f9d0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -1887,6 +1887,16 @@
   // Scale
   tiledOperands.emplace_back(scale);
 
+  // Mask
+  Value attnMask = getMask();
+  if (attnMask) {
+    SmallVector<Range> maskSlice =
+        getPermutedSlice(*getMaskMap(), offsets, sizes);
+    Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice);
+    tiledOperands.emplace_back(maskSliceOp->getResult(0));
+    slices.push_back(maskSliceOp);
+  }
+
   // Output
   {
     Operation *outputSliceOp = getSlice(builder, loc, getOutput(), outputSlice);
@@ -1909,7 +1919,7 @@
     slices.push_back(maxSliceOp);
   }
 
-  std::optional<Value> sum = getMax();
+  std::optional<Value> sum = getSum();
   if (sum) {
     SmallVector<Range> sumSlice =
         getPermutedSlice(*getSumMap(), offsets, sizes);
@@ -1923,12 +1933,13 @@
 
   SmallVector<Type> resultTypes;
   if (hasPureTensorSemantics()) {
-    resultTypes.push_back(tiledOperands[4].getType());
+    int64_t baseIdx = attnMask ? 5 : 4;
+    resultTypes.push_back(tiledOperands[baseIdx].getType());
     if (max) {
-      resultTypes.push_back(tiledOperands[5].getType());
+      resultTypes.push_back(tiledOperands[baseIdx + 1].getType());
     }
     if (sum) {
-      resultTypes.push_back(tiledOperands[6].getType());
+      resultTypes.push_back(tiledOperands[baseIdx + 2].getType());
     }
   }
 
@@ -2024,6 +2035,11 @@
   SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
   SmallVector<Range> valueSlice =
       getPermutedSlice(getValueMap(), offsets, sizes);
+  std::optional<SmallVector<Range>> maskSlice;
+  if (auto maskMap = getMaskMap()) {
+    maskSlice = getPermutedSlice(*maskMap, offsets, sizes);
+  }
+
   SmallVector<Range> outputSlice =
       getPermutedSlice(getOutputMap(), offsets, sizes);
   SmallVector<Range> maxSlice = getPermutedSlice(getMaxMap(), offsets, sizes);
@@ -2065,6 +2081,16 @@
 
   tiledOperands.emplace_back(scale);
 
+  // Mask
+  Value attnMask = getMask();
+  if (attnMask) {
+    SmallVector<Range> maskSlice =
+        getPermutedSlice(*getMaskMap(), offsets, sizes);
+    Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice);
+    tiledOperands.emplace_back(maskSliceOp->getResult(0));
+    slices.push_back(maskSliceOp);
+  }
+
   /// Output
   {
     Operation *outputSliceOp = getSlice(builder, loc, getOutput(), outputSlice);
@@ -2096,9 +2122,9 @@
   }
 
   SmallVector<Type> resultTypes;
-  resultTypes.push_back(tiledOperands[4].getType());
-  resultTypes.push_back(tiledOperands[5].getType());
-  resultTypes.push_back(tiledOperands[6].getType());
+  resultTypes.push_back(tiledOperands[tiledOperands.size() - 3].getType());
+  resultTypes.push_back(tiledOperands[tiledOperands.size() - 2].getType());
+  resultTypes.push_back(tiledOperands[tiledOperands.size() - 1].getType());
 
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
index 7eb4c0a..202ab11 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
@@ -3,14 +3,15 @@
 #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
-#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
 
 func.func @attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16>, %v: tensor<2x10x4096x128xf16>)
                      -> tensor<2x10x4096x128xf16> {
   %scale = arith.constant 0.125 : f16
   %acc = tensor.empty() : tensor<2x10x4096x128xf16>
   %out = iree_linalg_ext.attention
-         {indexing_maps = [#map, #map1, #map2, #map3]}
+         {indexing_maps = [#map, #map1, #map2, #map3, #map4]}
          ins(%q, %k, %v, %scale : tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, f16)
          outs(%acc : tensor<2x10x4096x128xf16>) -> tensor<2x10x4096x128xf16>
   func.return %out : tensor<2x10x4096x128xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
index 0344eb7..19d6a6b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
@@ -6,6 +6,7 @@
   %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
   return %1 : tensor<1x1024x64xf32>
@@ -108,6 +109,7 @@
   %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
@@ -213,6 +215,7 @@
   %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
   return %1 : tensor<1x1024x64xf16>
@@ -333,6 +336,7 @@
   %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
   return %1 : tensor<1x1024x64xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index df46f9f..e0aa548 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -3,6 +3,7 @@
 #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
 #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
 #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
 #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
 #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
 
@@ -16,7 +17,7 @@
   %scale = arith.constant 1.0 : f16
 
   %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
         ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
         outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
         -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
@@ -82,6 +83,7 @@
 #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
 #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
 #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
 #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
 #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
 
@@ -95,7 +97,7 @@
   %scale = arith.constant 1.0 : f32
 
   %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
         ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32)
         outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
         -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
@@ -165,3 +167,67 @@
 // CHECK:   arith.mulf
 // CHECK:   arith.addf
 // CHECK:   linalg.yield
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+                              %key: tensor<192x1024x64xf8E4M3FNUZ>,
+                              %value: tensor<192x1024x64xf8E4M3FNUZ>,
+                              %mask: tensor<192x1024x1024xf8E4M3FNUZ>,
+                              %output: tensor<192x1024x64xf32>,
+                              %max: tensor<192x1024xf32>,
+                              %sum: tensor<192x1024xf32>)
+                              -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+  %scale = arith.constant 1.0 : f16
+
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16, tensor<192x1024x1024xf8E4M3FNUZ>)
+        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+// CHECK-LABEL: @attention_f8_masked
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// S = S * scale
+// CHECK:   linalg.generic
+// CHECK:   arith.mulf
+// S = S + mask
+// CHECK:   arith.addf
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK:   arith.maximumf
+// CHECK:   linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// newSum = normSum + rowMax(P)
+// CHECK: linalg.generic
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
index be9c9da..51d3c51 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
@@ -8,6 +8,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
   return %1 : tensor<1x1024x64xf32>
@@ -61,6 +62,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
@@ -117,6 +119,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
   return %1 : tensor<1x1024x64xf16>
@@ -151,6 +154,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
   return %1 : tensor<1x1024x64xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 147d592..9bfa8b4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -628,6 +628,7 @@
     transform.yield
   }
 }
+
 // CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
 // CHECK:       func.func @topk_tile_tensor
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
@@ -1537,6 +1538,7 @@
   %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
@@ -1553,6 +1555,7 @@
 // CHECK-DAG:  #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 // CHECK-DAG:  #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG:  #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG:  #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 // CHECK:      func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1580,7 +1583,7 @@
 // CHECK:            %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], %[[ARG5]], 0] [%[[D2]],
 // CHECK-SAME:         %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
 // CHECK:            %[[D5:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                     {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_O]]]}
+// CHECK-SAME:                     {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_O]]]}
 // CHECK-SAME:                    ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]],
 // CHECK-SAME:         %[[EXTRACTED_SLICE_1]], %[[C1_F32]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32)
 // CHECK-SAME:         outs(%[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32>
@@ -1595,11 +1598,152 @@
 
 // -----
 
+func.func @attention_float_mask(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>, %mask: tensor<192x1024x1024xf32>) -> tensor<192x1024x64xf32> {
+  %0 = tensor.empty() : tensor<192x1024x64xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+                     ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  return %1 : tensor<192x1024x64xf32>
+}
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 30] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 192, 10)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 1024, 30)>
+// CHECK-DAG:  #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG:  #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG:  #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG:  #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG:  #[[MAP_M:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG:  #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+// CHECK:      func.func @attention_float_mask(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME:   tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<192x1024x1024xf32>) -> tensor<192x1024x64xf32>
+// CHECK-SAME:   {
+// CHECK-DAG:    %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG:    %[[C1_F32:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C192:.+]] = arith.constant 192 : index
+// CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK-DAG:    %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG:    %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK:        %[[D1:.+]] = scf.for %[[ARG4:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C192]] step %[[C10]]
+// CHECK-SAME:     iter_args(%[[ARG5:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<192x1024x64xf32>) {
+// CHECK:          %[[D2:.+]] = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C30]]
+// CHECK-SAME:       iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG5]]) -> (tensor<192x1024x64xf32>) {
+// CHECK-DAG:        %[[D3:.+]] = affine.min #[[MAP]](%[[ARG4]])
+// CHECK-DAG:        %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG6]])
+// CHECK:            %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME:         %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
+// CHECK:            %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1,
+// CHECK-SAME:         1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32>
+// CHECK:            %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1,
+// CHECK-SAME:         1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32>
+// CHECK:            %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG3]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME:         %[[D4]], 1024] [1, 1, 1] : tensor<192x1024x1024xf32> to tensor<?x?x1024xf32>
+// CHECK:            %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG7]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME:         %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
+// CHECK:            %[[D5:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:                     {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_M]], #[[MAP_O]]]}
+// CHECK-SAME:                    ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]],
+// CHECK-SAME:         %[[EXTRACTED_SLICE_1]], %[[C1_F32]], %[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32, tensor<?x?x1024xf32>)
+// CHECK-SAME:         outs(%[[EXTRACTED_SLICE_3]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32>
+// CHECK:            %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG7]][%[[ARG4]], %[[ARG6]], 0]
+// CHECK-SAME:         [%[[D3]], %[[D4]], 64] [1, 1, 1] : tensor<?x?x64xf32> into tensor<192x1024x64xf32>
+// CHECK:            scf.yield %[[INSERTED_SLICE]] : tensor<192x1024x64xf32>
+// CHECK:          }
+// CHECK:          scf.yield %[[D2]] : tensor<192x1024x64xf32>
+// CHECK:        }
+// CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
+// CHECK:      }
+
+// -----
+
+func.func @attention_bool_mask(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>, %mask: tensor<192x1024x1024xi1>) -> tensor<192x1024x64xf32> {
+  %0 = tensor.empty() : tensor<192x1024x64xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+                     ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xi1>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  return %1 : tensor<192x1024x64xf32>
+}
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 30] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 192, 10)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 1024, 30)>
+// CHECK-DAG:  #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG:  #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG:  #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG:  #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG:  #[[MAP_M:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG:  #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+// CHECK:      func.func @attention_bool_mask(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME:   tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<192x1024x1024xi1>) -> tensor<192x1024x64xf32>
+// CHECK-SAME:   {
+// CHECK-DAG:    %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG:    %[[C1_F32:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C192:.+]] = arith.constant 192 : index
+// CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK-DAG:    %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG:    %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK:        %[[D1:.+]] = scf.for %[[ARG4:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C192]] step %[[C10]]
+// CHECK-SAME:     iter_args(%[[ARG5:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<192x1024x64xf32>) {
+// CHECK:          %[[D2:.+]] = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C30]]
+// CHECK-SAME:       iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG5]]) -> (tensor<192x1024x64xf32>) {
+// CHECK-DAG:        %[[D3:.+]] = affine.min #[[MAP]](%[[ARG4]])
+// CHECK-DAG:        %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG6]])
+// CHECK:            %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME:         %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
+// CHECK:            %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1,
+// CHECK-SAME:         1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32>
+// CHECK:            %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1,
+// CHECK-SAME:         1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32>
+// CHECK:            %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG3]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME:         %[[D4]], 1024] [1, 1, 1] : tensor<192x1024x1024xi1> to tensor<?x?x1024xi1>
+// CHECK:            %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG7]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME:         %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
+// CHECK:            %[[D5:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:                     {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_M]], #[[MAP_O]]]}
+// CHECK-SAME:                    ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]],
+// CHECK-SAME:         %[[EXTRACTED_SLICE_1]], %[[C1_F32]], %[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32, tensor<?x?x1024xi1>)
+// CHECK-SAME:         outs(%[[EXTRACTED_SLICE_3]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32>
+// CHECK:            %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG7]][%[[ARG4]], %[[ARG6]], 0]
+// CHECK-SAME:         [%[[D3]], %[[D4]], 64] [1, 1, 1] : tensor<?x?x64xf32> into tensor<192x1024x64xf32>
+// CHECK:            scf.yield %[[INSERTED_SLICE]] : tensor<192x1024x64xf32>
+// CHECK:          }
+// CHECK:          scf.yield %[[D2]] : tensor<192x1024x64xf32>
+// CHECK:        }
+// CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
+// CHECK:      }
+
+// -----
+
 func.func @attention_memref(%query: memref<192x1024x64xf32>, %key: memref<192x1024x64xf32>, %value: memref<192x1024x64xf32>, %output: memref<192x1024x64xf32>) {
   %scale = arith.constant 1.0 : f32
   iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : memref<192x1024x64xf32>, memref<192x1024x64xf32>, memref<192x1024x64xf32>, f32) outs(%output : memref<192x1024x64xf32>)
   return
@@ -1616,6 +1760,7 @@
 // CHECK-DAG:  #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 // CHECK-DAG:  #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG:  #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG:  #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 // CHECK:      func.func @attention_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1640,7 +1785,7 @@
 // CHECK:            %[[SUBVIEW_2:.+]] = memref.subview %[[ARG3]][%[[ARG4]], %[[ARG5]], 0] [%[[D0]], %[[D1]], 64] [1, 1,
 // CHECK-SAME:         1] : memref<192x1024x64xf32> to memref<?x?x64xf32, strided<[65536, 64, 1], offset: ?>>
 // CHECK:            iree_linalg_ext.attention
-// CHECK-SAME:                     {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_O]]]}
+// CHECK-SAME:                     {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_O]]]}
 // CHECK-SAME:         ins(%[[SUBVIEW]], %[[SUBVIEW_0]], %[[SUBVIEW_1]], %[[C1_F32]] : memref<?x?x64xf32,
 // CHECK-SAME:         strided<[65536, 64, 1], offset: ?>>, memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>,
 // CHECK-SAME:         memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>, f32) outs(%[[SUBVIEW_2]] :
@@ -1662,6 +1807,7 @@
       indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
                        affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
                        affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+                       affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
                        affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
       ins(%query, %key, %value, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16)
       outs(%0 : tensor<2x10x4096x64xf16>) -> tensor<2x10x4096x64xf16>
@@ -1706,6 +1852,7 @@
 #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
 #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
 #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
 #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
 #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
 
@@ -1723,7 +1870,7 @@
   %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
 
   %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
         ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32)
         outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
         -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
@@ -1737,8 +1884,9 @@
 // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
-// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
 // CHECK-LABEL: @online_attention
 // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
 // CHECK-DAG:   %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
@@ -1751,7 +1899,7 @@
 // CHECK-DAG:  %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
 // CHECK-DAG:  %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
 // CHECK-DAG: iree_linalg_ext.online_attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP4]]]}
+// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP5]]]}
 // CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}} : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32)
 // CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
 // CHECK: scf.forall.in_parallel
@@ -1763,3 +1911,150 @@
     transform.yield
   }
 }
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_float_mask(%query: tensor<192x1024x64xf32>,
+                            %key: tensor<192x1024x64xf32>,
+                            %value: tensor<192x1024x64xf32>,
+                            %mask: tensor<192x1024x1024xf32>)
+                            -> tensor<192x1024x64xf32> {
+  %scale = arith.constant 1.0 : f32
+
+  %output_empty = tensor.empty() : tensor<192x1024x64xf32>
+  %row_red_empty = tensor.empty() : tensor<192x1024xf32>
+
+  %sum_ident = arith.constant 0.000000e+00 : f32
+  %max_ident = arith.constant -3.40282347E+38 : f32
+
+  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+
+  // Adjust the operation to correctly handle the mask
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xf32>)
+        outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0 : tensor<192x1024x64xf32>
+}
+
+// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-LABEL: @online_attention_float_mask
+// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
+// CHECK-DAG:   %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
+// CHECK-DAG:   %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]])
+// CHECK-DAG:   %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]])
+// CHECK-DAG:  %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32>
+// CHECK-DAG:  %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32>
+// CHECK-DAG:  %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32>
+// CHECK-DAG:  %[[MASK:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 1024] [1, 1, 1] : tensor<192x1024x1024xf32> to tensor<4x128x1024xf32>
+// CHECK-DAG:  %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32>
+// CHECK-DAG:  %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG:  %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG: iree_linalg_ext.online_attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP6]], #[[$MAP6]]]}
+// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}}, %[[MASK]] : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32, tensor<4x128x1024xf32>)
+// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
+// CHECK: scf.forall.in_parallel
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_bool_mask(%query: tensor<192x1024x64xf32>,
+                            %key: tensor<192x1024x64xf32>,
+                            %value: tensor<192x1024x64xf32>,
+                            %mask: tensor<192x1024x1024xi1>)
+                            -> tensor<192x1024x64xf32> {
+  %scale = arith.constant 1.0 : f32
+
+  %output_empty = tensor.empty() : tensor<192x1024x64xf32>
+  %row_red_empty = tensor.empty() : tensor<192x1024xf32>
+
+  %sum_ident = arith.constant 0.000000e+00 : f32
+  %max_ident = arith.constant -3.40282347E+38 : f32
+
+  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+
+  // Adjust the operation to correctly handle the mask
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xi1>)
+        outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0 : tensor<192x1024x64xf32>
+}
+
+
+// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-LABEL: @online_attention_bool_mask
+// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
+// CHECK-DAG:   %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
+// CHECK-DAG:   %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]])
+// CHECK-DAG:   %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]])
+// CHECK-DAG:  %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32>
+// CHECK-DAG:  %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32>
+// CHECK-DAG:  %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32>
+// CHECK-DAG:  %[[MASK:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 1024] [1, 1, 1] : tensor<192x1024x1024xi1> to tensor<4x128x1024xi1>
+// CHECK-DAG:  %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32>
+// CHECK-DAG:  %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG:  %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG: iree_linalg_ext.online_attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP6]], #[[$MAP6]]]}
+// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}}, %[[MASK]] : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32, tensor<4x128x1024xi1>)
+// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
+// CHECK: scf.forall.in_parallel
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
index 428f240..90797a0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "llvm/ADT/SetOperations.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
@@ -83,10 +84,6 @@
 
 FailureOr<AttentionOpDetail>
 AttentionOpDetail::get(ArrayRef<AffineMap> indexingMaps) {
-  if (indexingMaps.size() != 4 && indexingMaps.size() != 6) {
-    return failure();
-  }
-
   AttentionOpDetail opInfo;
   opInfo.inferFromIndexingMaps(indexingMaps);
   opInfo.maps = SmallVector<AffineMap>(indexingMaps);
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index 6b02036..e0e0cef 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -3,11 +3,12 @@
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 util.func public @attention_static(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x4096x64xf16> {
   %0 = tensor.empty() : tensor<20x4096x64xf16>
-  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
   %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
   util.return %expanded : tensor<2x10x4096x64xf16>
 }
@@ -36,11 +37,51 @@
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @attention_static_masked(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16, %arg4: tensor<20x4096x1024xf16>) -> tensor<2x10x4096x64xf16> {
+  %0 = tensor.empty() : tensor<20x4096x64xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16, tensor<20x4096x1024xf16>) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+  %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
+  util.return %expanded : tensor<2x10x4096x64xf16>
+}
+
+//CHECK-LABEL: func public @attention_static_masked(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
+// CHECK-SAME:     %[[ARG3:.+]]: f16
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<20x4096x1024xf16>)
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x4096x64xf16>
+//  CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
+//  CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
+//  CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
+//  CHECK-DAG:   %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 1024]
+//      CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:       indexing_maps =
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME:       ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] :
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   util.return %[[ATTENTION]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 util.func public @attention_expand_all(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x2048x2x2x32xf16> {
   %0 = tensor.empty() : tensor<20x4096x64xf16>
-  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
   %expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16>
   util.return %expanded : tensor<2x10x2048x2x2x32xf16>
 }
@@ -66,11 +107,50 @@
 
 // -----
 
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @attention_expand_all_masked(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16, %arg4: tensor<20x4096x1024xf16>) -> tensor<2x10x2048x2x2x32xf16> {
+  %0 = tensor.empty() : tensor<20x4096x64xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4: tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16, tensor<20x4096x1024xf16>) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+  %expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16>
+  util.return %expanded : tensor<2x10x2048x2x2x32xf16>
+}
+
+//CHECK-LABEL: func public @attention_expand_all_masked(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
+// CHECK-SAME:     %[[ARG3:.+]]: f16
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<20x4096x1024xf16>)
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x2048x2x2x32xf16>
+//  CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 16] : tensor<20x4096x16xf16> into tensor<2x10x2048x2x16xf16>
+//  CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 10, 1024, 16] : tensor<20x1024x16xf16> into tensor<2x10x1024x16xf16>
+//  CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [2, 10, 1024, 2, 32] : tensor<20x1024x64xf16> into tensor<2x10x1024x2x32xf16>
+//  CHECK-DAG:   %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 1024] : tensor<20x4096x1024xf16> into tensor<2x10x2048x2x1024xf16>
+//      CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d4)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d5)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d6, d7)>
+// CHECK-SAME:       ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] :
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   util.return %[[ATTENTION]]
+
+// -----
 
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16) -> tensor<2x?x?x?xf16> {
   %c0 = arith.constant 0 : index
@@ -83,7 +163,7 @@
   %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
   %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
   %0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
-  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
   %split = arith.divsi %d0, %c2 : index
   %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4]
       : tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
@@ -130,7 +210,78 @@
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16, %arg4: tensor<?x?x?xf16>) -> tensor<2x?x?x?xf16> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16>
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16>
+  %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
+  %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
+  %0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16, tensor<?x?x?xf16>) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+  %split = arith.divsi %d0, %c2 : index
+  %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4]
+      : tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
+  util.return %expanded : tensor<2x?x?x?xf16>
+}
+
+//CHECK-LABEL: func public @attention_dynamic_masked(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME:     %[[ARG3:.+]]: f16
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+//  CHECK-DAG:   %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
+//  CHECK-DAG:   %[[SPLIT0:.+]] = arith.divui %[[D0]]
+//  CHECK-DAG:   %[[VAL:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[VAL]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
+//  CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
+//  CHECK-DAG:   %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+//  CHECK-DAG:   %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//  CHECK-DAG:   %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+//  CHECK-DAG:   %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]]
+//  CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
+//  CHECK-DAG:   %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+//  CHECK-DAG:   %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
+//  CHECK-DAG:   %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]]
+//  CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
+//  CHECK-DAG:   %[[D10:.+]] = tensor.dim %[[ARG4]], %[[C0]]
+//  CHECK-DAG:   %[[D11:.+]] = tensor.dim %[[ARG4]], %[[C1]]
+//  CHECK-DAG:   %[[D12:.+]] = tensor.dim %[[ARG4]], %[[C2]]
+//  CHECK-DAG:   %[[SPLIT3:.+]] = arith.divui %[[D10]], %[[C2]]
+//  CHECK-DAG:   %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT3]], %[[D11]], %[[D12]]]
+//      CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME:       ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] :
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   util.return %[[ATTENTION]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 util.func public @sink_through_attention(%0 : tensor<4x32x64x128xf16>, %1 : tensor<4x32x64x128xf16>, %2 : tensor<4x32x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
   %13 = tensor.empty() : tensor<4x32x64x128xf16>
@@ -138,7 +289,7 @@
   %collapsed_13 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
   %collapsed_14 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
   %17 = tensor.empty() : tensor<128x64x128xf16>
-  %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
+  %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
   util.return %18 : tensor<128x64x128xf16>
 }
 
@@ -162,13 +313,52 @@
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @sink_through_attention_masked(%0 : tensor<4x32x64x128xf16>, %1 : tensor<4x32x64x128xf16>, %2 : tensor<4x32x64x128xf16>, %cst : f16, %3 : tensor<4x32x64x64xf16>) -> (tensor<128x64x128xf16>) {
+  %13 = tensor.empty() : tensor<4x32x64x128xf16>
+  %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+  %collapsed_13 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+  %collapsed_14 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+  %collapsed_15 = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<4x32x64x64xf16> into tensor<128x64x64xf16>
+  %17 = tensor.empty() : tensor<128x64x128xf16>
+  %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst, %collapsed_15 : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16, tensor<128x64x64xf16>) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
+  util.return %18 : tensor<128x64x128xf16>
+}
+
+// CHECK-LABEL: util.func public @sink_through_attention_masked
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG3:.+]]: f16
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]:
+//       CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]] :
+//       CHECK:   %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+//       CHECK:   util.return %[[RET]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
 util.func public @sink_single_collapse(%0 : tensor<4x32x64x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
   %13 = tensor.empty() : tensor<4x32x64x128xf16>
   %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
   %17 = tensor.empty() : tensor<128x64x128xf16>
-  %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%collapsed_12, %1, %2, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
+  %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%collapsed_12, %1, %2, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
   util.return %18 : tensor<128x64x128xf16>
 }
 
@@ -188,3 +378,41 @@
 //  CHECK-SAME:       ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]] :
 //       CHECK:   %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
 //       CHECK:   util.return %[[RET]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16, %3 : tensor<128x64x64xf16>) -> (tensor<128x64x128xf16>) {
+  %13 = tensor.empty() : tensor<4x32x64x128xf16>
+  %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+  %17 = tensor.empty() : tensor<128x64x128xf16>
+  %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed_12, %1, %2, %cst, %3 : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16, tensor<128x64x64xf16>) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
+  util.return %18 : tensor<128x64x128xf16>
+}
+
+// CHECK-LABEL: util.func public @sink_single_collapse_masked
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG3:.+]]: f16
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]:
+//   CHECK-DAG:   %[[EXPANDED1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128]
+//   CHECK-DAG:   %[[EXPANDED2:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128]
+//   CHECK-DAG:   %[[EXPANDED3:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 64]
+//       CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+//  CHECK-SAME:       ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]], %[[EXPANDED3]] :
+//       CHECK:   %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+//       CHECK:   util.return %[[RET]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
index 7f0ed69..edfbfcf 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
@@ -388,7 +388,7 @@
   } -> tensor<4x1x4x16x32x128xf16>
   %3 = tensor.empty() : tensor<4x1x32x1x128xf16>
   %4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) {
-    %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
+    %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
     flow.return %5 : tensor<4x1x32x1x128xf16>
   }
   %collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16>
@@ -402,3 +402,98 @@
 //       CHECK:      %[[ATTENTION:.+]] = iree_linalg_ext.attention
 //       CHECK:        ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
 //       CHECK:      flow.return %[[ATTENTION]]
+
+// -----
+
+// Clone 'gather-like' operations
+util.func public @clone_gather_like(%arg0: tensor<4x1x4xi64>, %arg1: tensor<16384x16x32x128xf16>, %arg2: f16, %arg3: i64, %arg4: tensor<4x4x16x32x128xf16>, %arg5: tensor<4x1x32x1x128xf16>) -> tensor<4x32x1x128xf16> {
+  %0 = tensor.empty() : tensor<4x1x4x16x32x128xf16>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %5 = arith.index_cast %in : i64 to index
+    %6 = linalg.index 3 : index
+    %7 = linalg.index 4 : index
+    %8 = linalg.index 5 : index
+    %extracted = tensor.extract %arg1[%5, %6, %7, %8] : tensor<16384x16x32x128xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4x1x4x16x32x128xf16>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %5 = arith.addi %in, %arg3 : i64
+    %6 = arith.index_cast %5 : i64 to index
+    %7 = linalg.index 3 : index
+    %8 = linalg.index 4 : index
+    %9 = linalg.index 5 : index
+    %extracted = tensor.extract %arg1[%6, %7, %8, %9] : tensor<16384x16x32x128xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4x1x4x16x32x128xf16>
+  %3 = tensor.empty() : tensor<4x1x32x1x128xf16>
+  %4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) {
+    %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
+    flow.return %5 : tensor<4x1x32x1x128xf16>
+  }
+  %collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16>
+  util.return %collapsed : tensor<4x32x1x128xf16>
+}
+
+// CHECK-LABEL:  util.func public @clone_gather_lik
+//       CHECK:    %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:      %[[GATHER0:.+]] = linalg.generic
+//       CHECK:      %[[GATHER1:.+]] = linalg.generic
+//       CHECK:      %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//       CHECK:        ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
+//       CHECK:      flow.return %[[ATTENTION]]
+
+// -----
+
+util.func public @clone_gather_like(%arg0: tensor<4x1x4xi64>, %arg1: tensor<16384x16x32x128xf16>, %arg2: f16, %arg3: i64, %arg4: tensor<4x4x16x32x128xf16>, %arg5: tensor<4x1x32x1x128xf16>, %arg6: tensor<4x1x32x128xf16>) -> tensor<4x32x1x128xf16> {
+  %0 = tensor.empty() : tensor<4x1x4x16x32x128xf16>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %5 = arith.index_cast %in : i64 to index
+    %6 = linalg.index 3 : index
+    %7 = linalg.index 4 : index
+    %8 = linalg.index 5 : index
+    %extracted = tensor.extract %arg1[%5, %6, %7, %8] : tensor<16384x16x32x128xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4x1x4x16x32x128xf16>
+
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %5 = arith.addi %in, %arg3 : i64
+    %6 = arith.index_cast %5 : i64 to index
+    %7 = linalg.index 3 : index
+    %8 = linalg.index 4 : index
+    %9 = linalg.index 5 : index
+    %extracted = tensor.extract %arg1[%6, %7, %8, %9] : tensor<16384x16x32x128xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4x1x4x16x32x128xf16>
+
+  %3 = tensor.empty() : tensor<4x1x32x1x128xf16>
+
+  %4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) {
+    %5 = iree_linalg_ext.attention {
+      indexing_maps = [
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d4)>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
+      ]
+    } ins(%arg5, %1, %2, %arg2, %arg6 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16, tensor<4x1x32x128xf16>) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
+
+    flow.return %5 : tensor<4x1x32x1x128xf16>
+  }
+
+  %collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16>
+  util.return %collapsed : tensor<4x32x1x128xf16>
+}
+
+// CHECK-LABEL:  util.func public @clone_gather_lik
+//       CHECK:    %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:      %[[GATHER0:.+]] = linalg.generic
+//       CHECK:      %[[GATHER1:.+]] = linalg.generic
+//       CHECK:      %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//       CHECK:        ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
+//       CHECK:      flow.return %[[ATTENTION]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
index 993b690..1d10dfc 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
@@ -97,7 +97,7 @@
     linalg.yield %5 : f16
   } -> tensor<?x?x?xf16>
 
-  %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+  %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
 
   %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
   ^bb0(%in: f16, %out: f16):
@@ -123,3 +123,49 @@
 //       CHECK:         %[[GEN2:.+]] = linalg.generic
 //  CHECK-SAME:           ins(%[[ATTN]]
 //       CHECK:         flow.return %[[GEN2]]
+
+// -----
+
+util.func public @attention_dispatch_masked(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16, %arg4: tensor<?x?x?xf16>, %arg5: tensor<?x?x?xf16>, %arg6: tensor<?x?x?xf16>, %arg7: tensor<?x?x?xf16>) -> tensor<?x?x?xf16> {
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5 = arith.mulf %in, %in : f16
+    linalg.yield %5 : f16
+  } -> tensor<?x?x?xf16>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5 = arith.mulf %in, %in : f16
+    linalg.yield %5 : f16
+  } -> tensor<?x?x?xf16>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5 = arith.mulf %in, %in : f16
+    linalg.yield %5 : f16
+  } -> tensor<?x?x?xf16>
+
+  %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3, %arg4: tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16, tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+
+  %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5 = arith.mulf %in, %in : f16
+    linalg.yield %5 : f16
+  } -> tensor<?x?x?xf16>
+  util.return %4 : tensor<?x?x?xf16>
+}
+
+// CHECK-LABEL:     util.func public @attention_dispatch_masked
+//       CHECK:       %[[DISPATCH0:.+]] = flow.dispatch.region
+//  CHECK-NEXT:         %[[GEN0:.+]] = linalg.generic
+//       CHECK:         flow.return %[[GEN0]]
+//       CHECK:       %[[DISPATCH1:.+]] = flow.dispatch.region
+//  CHECK-NEXT:         %[[GEN1:.+]] = linalg.generic
+//       CHECK:         flow.return %[[GEN1]]
+//       CHECK:       %[[DISPATCH2:.+]] = flow.dispatch.region
+//  CHECK-NEXT:         %[[GEN2:.+]] = linalg.generic
+//       CHECK:         flow.return %[[GEN2]]
+//       CHECK:       %[[RESULT:.+]] = flow.dispatch.region
+//       CHECK:         %[[ATTN:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:           ins(%[[DISPATCH0]], %[[DISPATCH1]], %[[DISPATCH2]]
+//       CHECK:         %[[GEN2:.+]] = linalg.generic
+//  CHECK-SAME:           ins(%[[ATTN]]
+//       CHECK:         flow.return %[[GEN2]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
index 079acf3..d0dc854 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
@@ -15,7 +15,7 @@
     linalg.yield %in : f16
   } -> tensor<4x32x64x128xf16>
   %4 = tensor.empty() : tensor<4x32x64x128xf16>
-  %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16) outs(%4 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
+  %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16) outs(%4 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
   %6 = tensor.empty() : tensor<4x64x32x128xf16>
   %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<4x32x64x128xf16>) outs(%6 : tensor<4x64x32x128xf16>) {
   ^bb0(%in: f16, %out: f16):
@@ -35,6 +35,55 @@
 //  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3)>
 //  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d3)>
 //  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d5)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+//  CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]
+
+// -----
+
+util.func public @transposed_attention_masked(%arg0: tensor<4x64x32x128xf16>, %arg1: tensor<4x64x32x128xf16>, %arg2: tensor<4x64x32x128xf16>, %arg3: f16, %arg4: tensor<4x64x32x64xf16>) -> tensor<4x64x4096xf16> {
+  %0 = tensor.empty() : tensor<4x32x64x128xf16>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x32x64x128xf16>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x32x64x128xf16>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x32x64x128xf16>
+  %empty = tensor.empty() : tensor<4x32x64x64xf16>
+  %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<4x64x32x64xf16>) outs(%empty : tensor<4x32x64x64xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x32x64x64xf16>
+  %5 = tensor.empty() : tensor<4x32x64x128xf16>
+  %6 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3, %4 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16, tensor<4x32x64x64xf16>) outs(%5 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
+  %7 = tensor.empty() : tensor<4x64x32x128xf16>
+  %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<4x32x64x128xf16>) outs(%7 : tensor<4x64x32x128xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x64x32x128xf16>
+  %collapsed = tensor.collapse_shape %8 [[0], [1], [2, 3]] : tensor<4x64x32x128xf16> into tensor<4x64x4096xf16>
+  util.return %collapsed : tensor<4x64x4096xf16>
+}
+
+// CHECK-LABEL: util.func public @transposed_attention_masked
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor
+//  CHECK-SAME:   %[[ARG1:[A-Za-z0-9]+]]: tensor
+//  CHECK-SAME:   %[[ARG2:[A-Za-z0-9]+]]: tensor
+//  CHECK-SAME:   %[[ARG3:[A-Za-z0-9]+]]: f16
+//  CHECK-SAME:   %[[ARG4:[A-Za-z0-9]+]]: tensor
+//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:     indexing_maps =
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d3)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d5)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d4)>
 //  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
 //  CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]
 
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
index cbb6ebd..447ca7d 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
@@ -15,6 +15,7 @@
     indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
     ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
     outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
@@ -39,26 +40,30 @@
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-//   CHECK-DAG:   %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
-//   CHECK-DAG:   %[[D_SPLIT:.+]] = arith.divsi %[[D0]], %[[C2]]
-//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[D1]], %[[D_SPLIT]], %[[D4]]) : tensor<2x?x?x?xf16>
-//   CHECK-DAG:   %[[D_SPLIT2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
-//   CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-//   CHECK-DAG:   %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-//   CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D1]], %[[D2]]{{\]}}
-//   CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D2]]{{\]}}
-//   CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D4]]{{\]}}
+//   CHECK-DAG:   %[[D:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG2]], %[[C2]]
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[D]], %[[D0]], %[[D1]]) : tensor<?x?x?xf16>
 //       CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
 //  CHECK-SAME:       indexing_maps =
-//  CHECK-SAME:           [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
-//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
-//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
-//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
-//  CHECK-SAME:       ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+//  CHECK-SAME:           [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4) -> ()>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] :
 //  CHECK-SAME:       outs(%[[EMPTY]] :
-//       CHECK:   util.return %[[ATTENTION]]
+//   CHECK-DAG:   %[[D_SPLIT:.+]] = arith.divsi %[[D]], %[[C2]]
+//   CHECK-DAG:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT]], %[[D0]], %[[D1]]]
+//   CHECK-DAG:   %[[OUTS:.+]] = tensor.empty(%[[D0]], %[[D_SPLIT]], %[[D1]]) : tensor<2x?x?x?xf16>
+//   CHECK-DAG:   %[[TRANSPOSE:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:           [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]
+//  CHECK-SAME:       ins(%[[EXPANDED]] :
+//  CHECK-SAME:       outs(%[[OUTS]] :
+//       CHECK:   linalg.yield
+//       CHECK:   util.return %[[TRANSPOSE]]
 
 // -----
 
@@ -70,6 +75,7 @@
       indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> ()>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
       ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
       outs(%empty: tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
@@ -91,16 +97,23 @@
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
 //  CHECK-SAME:     %[[ARG3:.+]]: f16)
-//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
-//   CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
-//   CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
-//   CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<20x4096x64xf16>
 //       CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
 //  CHECK-SAME:       indexing_maps =
-//  CHECK-SAME:           [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
-//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
-//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
-//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
-//  CHECK-SAME:       ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+//  CHECK-SAME:           [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4) -> ()>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] :
 //  CHECK-SAME:       outs(%[[EMPTY]] :
-//       CHECK:   util.return %[[ATTENTION]]
+//   CHECK-DAG:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 64]
+//   CHECK-DAG:   %[[OUTS:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
+//   CHECK-DAG:   %[[TRANSPOSE:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:           [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+//  CHECK-SAME:            affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]
+//  CHECK-SAME:       ins(%[[EXPANDED]] :
+//  CHECK-SAME:       outs(%[[OUTS]] :
+//       CHECK:   linalg.yield
+//       CHECK:   util.return %[[TRANSPOSE]]
diff --git a/tests/e2e/attention/generate_e2e_attention_tests.py b/tests/e2e/attention/generate_e2e_attention_tests.py
index d258dc1..feaad0e 100644
--- a/tests/e2e/attention/generate_e2e_attention_tests.py
+++ b/tests/e2e/attention/generate_e2e_attention_tests.py
@@ -217,6 +217,7 @@
         f"      indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n"
         f"                       affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n"
         f"                       affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n"
+        f"                       affine_map<(batch, m, n, k1, k2) -> ()>,\n"
         f"                       affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}"
         f"      ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n"
         f"      outs(%result0: {result_tensor_type}) -> {result_tensor_type}\n"
diff --git a/tests/e2e/linalg_ext_ops/attention.mlir b/tests/e2e/linalg_ext_ops/attention.mlir
index c418809..c2ca832 100644
--- a/tests/e2e/linalg_ext_ops/attention.mlir
+++ b/tests/e2e/linalg_ext_ops/attention.mlir
@@ -14,6 +14,7 @@
   %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<1x3x4xf32>,
         tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%init : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
@@ -44,6 +45,7 @@
   %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<1x4x4xf32>,
         tensor<1x4x4xf32>, tensor<1x4x4xf32>, f32) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
@@ -90,6 +92,7 @@
   %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
                      ins(%query, %key, %value, %scale : tensor<3x3x4xf32>,
         tensor<3x3x4xf32>, tensor<3x3x4xf32>, f32) outs(%init : tensor<3x3x4xf32>) -> tensor<3x3x4xf32>