[LinalgExt] Add scaling to attention op (#16679)

This commit add support for a scale input to the attention
op. This is to follow semantics of scaled dot product
semantics for attention like explained in
https://paperswithcode.com/method/scaled

Right now we only thread through the lower part of the
stack, meaning linalg ext ops and its code generation.
We also need to eventually fix upper layers to get the
scale from the input. For now we only support the PyTorch
default scale case--rsqrt(d).

---------

Co-authored-by: Lei Zhang <antiagainst@gmail.com>
diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index e09363f..05988db 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -11,6 +11,8 @@
 #include "compiler/plugins/input/Torch/InputConversion/Passes.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -145,8 +147,36 @@
     Value collapsedResult = rewriter.create<tensor::EmptyOp>(
         loc, collapsedResultShape, elementType);
 
+    // TODO: This is a hack. This should be replaced with a simple getScale()
+    // when support for scaling is plumbed to TMTensor on the torch-mlir side.
+    // Until then, we are using the default value used in scaled dot product
+    // attention by PyTorch (most models use the default value because it makes
+    // the variance of the result of softmax 1 when the mean of Q, K is 0).
+    // We use scale = 1 / sqrt(d), where d is the head dimension.
+    // See https://paperswithcode.com/method/scaled for more details.
+    //
+    // TODO: We are currently assuming that head dimension is dim = -1. Once we
+    // have support for batch dims using more general indexing maps, we should
+    // change this and rely on more general mechanisms.
+    // TODO: We are currently not handling dynamic shape of head dimensions at
+    // all. This is because it messes with dispatch formation. This should be
+    // fixed.
+    ArrayRef<int64_t> queryShape = op.getQueryType().getShape();
+    int64_t headDim = queryShape.back();
+    if (headDim == ShapedType::kDynamic) {
+      return op->emitOpError("NYI: Dynamic head dimension");
+    }
+
+    // Attention only works for FloatType.
+    FloatType targetType = cast<FloatType>(op.getQueryType().getElementType());
+
+    double dk = static_cast<double>(headDim);
+    dk = 1.0 / std::sqrt(dk);
+    Value scale = rewriter.create<arith::ConstantOp>(
+        loc, targetType, rewriter.getFloatAttr(targetType, dk));
+
     auto attention = rewriter.create<IREE::LinalgExt::AttentionOp>(
-        loc, collapsedResultType, SmallVector<Value>{query, key, value},
+        loc, collapsedResultType, SmallVector<Value>{query, key, value, scale},
         collapsedResult);
 
     if (sizes.size() > 3)
diff --git a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
index 0d895ab..17000c0 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
@@ -1,7 +1,4 @@
-// RUN: iree-opt --split-input-file --torch-iree-tm-tensor-to-linalg-ext %s | FileCheck %s
-
-// https://github.com/openxla/iree/issues/14916
-// XFAIL: *
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(torch-iree-tm-tensor-to-linalg-ext))" %s | FileCheck %s
 
 func.func @attention(%arg0: tensor<5x2x3x4xf32>, %arg1: tensor<5x2x3x4xf32>, %arg2: tensor<5x2x3x4xf32>, %arg3: tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xf32>) {
   %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>) outs(%arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
@@ -11,11 +8,12 @@
 // CHECK-LABEL:         func.func @attention(
 // CHECK-SAME:         %[[ARG0:.*]]: tensor<5x2x3x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
 // CHECK:         %arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
+// CHECK:         %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK:         %[[COL:.*]] = tensor.collapse_shape %[[ARG0]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
 // CHECK:         %[[COL0:.*]] = tensor.collapse_shape %[[ARG1]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
 // CHECK:         %[[COL1:.*]] = tensor.collapse_shape %[[ARG2]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
 // CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<10x3x4xf32>
-// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]] : tensor<10x3x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>) outs(%[[EMPTY]] : tensor<10x3x4xf32>) -> tensor<10x3x4xf32>
+// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]], %[[SCALE]] : tensor<10x3x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<10x3x4xf32>) -> tensor<10x3x4xf32>
 // CHECK:         %[[RET:.*]] = tensor.expand_shape %[[ATTN]] {{.*}} : tensor<10x3x4xf32> into tensor<5x2x3x4xf32>
 // CHECK:         return %[[RET]] : tensor<5x2x3x4xf32>
 
@@ -28,11 +26,12 @@
 // CHECK-LABEL:         func.func @attention(
 // CHECK-SAME:         %[[ARG0:.*]]: tensor<5x2x8x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
 // CHECK:         %arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
+// CHECK:         %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK:         %[[COL:.*]] = tensor.collapse_shape %[[ARG0]] {{.*}} : tensor<5x2x8x4xf32> into tensor<10x8x4xf32>
 // CHECK:         %[[COL0:.*]] = tensor.collapse_shape %[[ARG1]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
 // CHECK:         %[[COL1:.*]] = tensor.collapse_shape %[[ARG2]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
 // CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<10x8x4xf32>
-// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]] : tensor<10x8x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>) outs(%[[EMPTY]] : tensor<10x8x4xf32>) -> tensor<10x8x4xf32>
+// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]], %[[SCALE]] : tensor<10x8x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<10x8x4xf32>) -> tensor<10x8x4xf32>
 // CHECK:         %[[RET:.*]] = tensor.expand_shape %[[ATTN]] {{.*}} : tensor<10x8x4xf32> into tensor<5x2x8x4xf32>
 // CHECK:         return %[[RET]] : tensor<5x2x8x4xf32>
 
@@ -45,6 +44,7 @@
 // CHECK-LABEL:         func.func @attention(
 // CHECK-SAME:         %[[ARG0:.*]]: tensor<1x3x4xf32>, %[[ARG1:.*]]: tensor<1x3x4xf32>, %[[ARG2:.*]]: tensor<1x3x4xf32>,
 // CHECK:         %arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
+// CHECK:         %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
-// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK:         %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
 // CHECK:         return %[[ATTN]] : tensor<1x3x4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index d295614..7af1e58 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -2334,6 +2334,7 @@
   builtin.module {
     func.func @attention() {
       %c0 = arith.constant 0 : index
+      %scale = arith.constant 0.125 : f16
       %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
       %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
@@ -2343,7 +2344,7 @@
       %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
       %7 = tensor.empty() : tensor<20x4096x64xf16>
       %8 = iree_linalg_ext.attention
-        ins(%4, %5, %6 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>)
+        ins(%4, %5, %6, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
         outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
       flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
       return
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index ce450ba..dc7c533 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -12,6 +12,7 @@
     builtin.module {
       func.func @_attention_dispatch_0() {
         %c0 = arith.constant 0 : index
+        %scale = arith.constant 0.125 : f16
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
@@ -20,7 +21,7 @@
         %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
         %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
         %7 = tensor.empty() : tensor<192x1024x64xf16>
-        %8 = iree_linalg_ext.attention ins(%4, %5, %6 : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
+        %8 = iree_linalg_ext.attention ins(%4, %5, %6, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
         flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
         return
       }
@@ -47,6 +48,7 @@
 // CHECK-DAG:    %[[C128:.+]] = arith.constant 128 : index
 // CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
 // CHECK-DAG:    %[[CST_5:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-dAG:    %[[CST_6:.+]] = arith.constant dense<1.802980e-01> : vector<128x64xf16>
 // CHECK:        %[[D0:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64)
 // CHECK-SAME:     offset(%[[C0]]) flags(ReadOnly) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
 // CHECK:        memref.assume_alignment %[[D0]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
@@ -92,7 +94,6 @@
 // CHECK-DAG:    %[[D8:.+]] = affine.apply #[[MAP2]]()[%[[D5]], %[[D6]], %[[D7]]]
 // CHECK:        %[[D9:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D8]], %[[C0]]], %[[CST_4]] {in_bounds = [true,
 // CHECK-SAME:     true]} : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>, vector<32x64xf16>
-// CHECK:        %[[D10:.+]] = arith.extf %[[D9]] : vector<32x64xf16> to vector<32x64xf32>
 // CHECK:        %[[D11:.+]]:3 = scf.for %[[ARG0:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C128]]
 // CHECK-SAME:     iter_args(%[[ARG1:[a-zA-Z0-9_]+]] = %[[CST_0]], %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST_1]],
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]] = %[[CST]]) -> (vector<32xf32>, vector<32xf32>, vector<32x64xf32>) {
@@ -100,6 +101,8 @@
 // CHECK-SAME:       : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
 // CHECK:          %[[SUBVIEW_9:.+]] = memref.subview %[[D2]][%[[WORKGROUP_ID_X]], %[[ARG0]], 0] [1, 128, 64] [1, 1, 1]
 // CHECK-SAME:       : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
+// CHECK:          %[[ALLOC_12:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf16, #gpu.address_space<workgroup>>
+// CHECK:          vector.transfer_write %[[CST_6:.+]], %[[ALLOC_12]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<128x64xf16>, memref<128x64xf16, #gpu.address_space<workgroup>>
 // CHECK:          %[[ALLOC_10:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf16,
 // CHECK-SAME:       #[[GPU]].address_space<workgroup>>
 // CHECK:          gpu.barrier
@@ -120,8 +123,11 @@
 // CHECK:            linalg.yield %[[IN]] : f16
 // CHECK:          }
 // CHECK:          gpu.barrier
+// CHECK:          %[[READ:.+]] = vector.transfer_read %[[ALLOC_12]][%[[D8]], %[[C0]]], %{{.+}} : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<32x64xf16>
+// CHECK:          %[[MUL:.+]] = arith.mulf %[[D9]], %[[READ]] : vector<32x64xf16>
 // CHECK:          %[[D13:.+]] = vector.transfer_read %[[ALLOC_10]][%[[C0]], %[[C0]]], %[[CST_4]] {in_bounds = [true,
 // CHECK-SAME:       true]} : memref<128x64xf16, #[[GPU]].address_space<workgroup>>, vector<128x64xf16>
+// CHECK:          %[[D10:.+]] = arith.extf %[[MUL]] : vector<32x64xf16> to vector<32x64xf32>
 // CHECK:          %[[D14:.+]] = arith.extf %[[D13]] : vector<128x64xf16> to vector<128x64xf32>
 // CHECK:          %[[D15:.+]] = vector.contract {indexing_maps = [#[[MAP4]], #[[MAP5]], #[[MAP6]]], iterator_types =
 // CHECK-SAME:       ["parallel", "parallel", "reduction"], kind = #[[VECTOR:.+]].kind<add>} %[[D10]], %[[D14]],
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
index a871167..6148fc3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
@@ -17,6 +17,7 @@
         // CHECK: scf.for {{.*}} = %c0 to %c16384 step %c64 {{.*}} -> (vector<2xf32>, vector<2xf32>, vector<8x2x4xf32>)
         // CHECK-COUNT-128: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32}
         %c0 = arith.constant 0 : index
+        %scale = arith.constant 0.08838834764 : f16
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
@@ -25,7 +26,7 @@
         %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
         %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
         %7 = tensor.empty() : tensor<16x16384x128xf16>
-        %8 = iree_linalg_ext.attention ins(%4, %5, %6 : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
+        %8 = iree_linalg_ext.attention ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
         flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
         return
       }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir
index 9bf926b..486d2b9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir
@@ -34,9 +34,10 @@
     %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.iree.tile_attention %attention4 {tile_size = 64} :
       (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
+    %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
         = transform.iree.decompose_tiled_attention %blocked_attention {tile_size = 64} :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
+                              !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 
     // Promote key and value operands
     // ==========================================
@@ -81,9 +82,18 @@
     } : !transform.any_op
     transform.apply_cse to %func : !transform.any_op
 
+    %f10, %loop10 = transform.structured.fuse_into_containing_op %scale_q into %loop9 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.apply_cse to %func : !transform.any_op
     // Distribute fills
     // ==========================================
-    %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill : !transform.any_op
+
+    // Get all fills that haven't been distributed to warps.
+    %fills = transform.include @get_undistributed_fills failures(propagate) (%variant_op)  : (!transform.any_op) -> !transform.any_op
+
     %tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     // Distribute last_truncate and fuse final_scaling into it
@@ -153,9 +163,11 @@
 
     // Get the vector.contract ops.
     %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op_3 :  (!transform.any_op) -> !transform.any_op
+    %contract1, %contract2 = transform.split_handle %contracts : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
-    transform.iree.set_contraction_layout_attributes %contracts, %layout16x16x16 : !transform.any_op, !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract1, %layout16x16x16 : !transform.any_op, !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract2, %layout16x16x16 : !transform.any_op, !transform.any_param
 
     %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %distribute_func test_conversion : !transform.any_op
@@ -178,4 +190,14 @@
 
     transform.yield
   }
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %arg0 ["linalg.fill"] : !transform.any_op
+    %0 = transform.get_parent_op %arg0 {allow_empty_results, nth_parent = 2 : i64, op_name = "scf.forall"} : (!transform.any_op) -> !transform.any_op
+    transform.match.operation_empty %0 : !transform.any_op
+    transform.yield %arg0 : !transform.any_op
+  }
+  transform.named_sequence @get_undistributed_fills(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield %0 : !transform.any_op
+  }
 } ////  module
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
index 1fd8749..f819a30 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
@@ -24,7 +24,7 @@
     // Promote query and output operands
     // ==========================================
     %attention3 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %promoted_attention, %alloc_a0, %alloc_a1 = transform.iree.promote_operands %attention3 [0, 3]
+    %promoted_attention, %alloc_a0, %alloc_a1 = transform.iree.promote_operands %attention3 [0, 4]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
 
     // Tile and decompose attention
@@ -32,9 +32,10 @@
     %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.iree.tile_attention %attention4 :
       (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
+    %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
         = transform.iree.decompose_tiled_attention %blocked_attention :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
+                              !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 
     // Promote key and value operands
     // ==========================================
@@ -79,6 +80,13 @@
     } : !transform.any_op
     transform.apply_cse to %func : !transform.any_op
 
+    %f10, %loop10 = transform.structured.fuse_into_containing_op %scale_q into %loop9 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.apply_cse to %func : !transform.any_op
+
     // Distribute fills
     // ==========================================
     %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill : !transform.any_op
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 95003be..6629920 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -2209,7 +2210,6 @@
 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
-
   Location loc = getLoc();
   auto one = builder.getIndexAttr(1);
   auto zero = builder.getIndexAttr(0);
@@ -2381,7 +2381,6 @@
 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes) {
-
   Location loc = getLoc();
   auto one = builder.getIndexAttr(1);
   auto zero = builder.getIndexAttr(0);
@@ -2460,8 +2459,8 @@
 /// Utility function to check whether a given ShapedType has the expected rank.
 static LogicalResult checkShapeRank(Operation *op, StringRef operandName,
                                     ShapedType shapedType,
-                                    unsigned rankToCompareWith) {
-  unsigned opRank = shapedType.getRank();
+                                    int64_t rankToCompareWith) {
+  int64_t opRank = shapedType.getRank();
   if (opRank != rankToCompareWith) {
     return op->emitOpError("expected ")
            << operandName << " to have rank " << rankToCompareWith
@@ -2472,13 +2471,36 @@
 
 LogicalResult AttentionOp::verify() {
   Operation *op = getOperation();
-  unsigned numOperands = getNumOperands();
-  unsigned rankToCompareWith = 3;
-  if (numOperands == 6)
+
+  int numInputs = getNumDpsInputs();
+  int numOutputs = getNumDpsInits();
+
+  if (numInputs != 4) {
+    return op->emitOpError(
+        "expected 4 input operands: Query, Key, Value and Scale");
+  }
+
+  if (numOutputs != 1 && numOutputs != 3) {
+    return op->emitOpError(
+        "expected 1 or 3 output operands: Output, [Max and Sum]");
+  }
+
+  bool isTiled = numOutputs == 3;
+
+  int64_t rankToCompareWith;
+  if (isTiled) {
     rankToCompareWith = 2;
-  else if (numOperands != 4)
-    return op->emitOpError("expected operand count 4 or 6, but got")
-           << numOperands;
+  } else {
+    rankToCompareWith = 3;
+  }
+
+  if (!llvm::all_of(llvm::drop_end(getDpsInputs()), [](Value input) {
+        return isa<ShapedType>(input.getType());
+      })) {
+    return op->emitOpError(
+        "expected Query, Key, Value inputs to be of shaped type");
+  }
+
   ShapedType queryType = getQueryType();
   ShapedType keyType = getKeyType();
   ShapedType valueType = getValueType();
@@ -2487,6 +2509,12 @@
   Type keyElementType = keyType.getElementType();
   Type valueElementType = valueType.getElementType();
   Type outputElementType = outputType.getElementType();
+
+  FloatType scaleElementType = dyn_cast<FloatType>(getScale().getType());
+  if (!scaleElementType) {
+    return op->emitOpError("expected scale to be of floating point type");
+  }
+
   if (failed(checkShapeRank(op, "query", queryType, rankToCompareWith))) {
     return failure();
   }
@@ -2515,11 +2543,13 @@
     return op->emitOpError("incompatible output shape");
   }
   if (queryElementType != keyElementType ||
-      keyElementType != valueElementType) {
+      queryElementType != valueElementType ||
+      queryElementType != scaleElementType) {
     return op->emitOpError(
-        "element types of (Q)uery, (K)ey and (V)value should be same");
+        "element types of (Q)uery, (K)ey and (V)alue and scale should be "
+        "same");
   }
-  if (numOperands == 4) {
+  if (!isTiled) {
     // Vanilla attention.
     if (queryElementType != outputElementType) {
       return op->emitOpError("expected element type for Output ")
@@ -2530,7 +2560,7 @@
       return op->emitOpError("query and key head dimension mismatch");
     }
   }
-  if (numOperands == 6) {
+  if (isTiled) {
     // Tiled/Flash attention.
     ShapedType maxType = *getMaxType();
     ShapedType sumType = *getSumType();
@@ -2556,6 +2586,7 @@
       return op->emitOpError("Query and max dimension-0 mismatch");
     }
   }
+
   return success();
 }
 
@@ -2609,6 +2640,8 @@
   keyValueSizes[0] = sizes[0];
   keyValueOffsets[0] = offsets[0];
 
+  Value scale = getScale();
+
   SmallVector<Value> tiledOperands;
   tiledOperands.emplace_back(getSlice(builder, loc, getQuery(),
                                       queryOutputOffsets, queryOutputSizes,
@@ -2617,13 +2650,14 @@
                                       keyValueSizes, keyValueStrides));
   tiledOperands.emplace_back(getSlice(builder, loc, getValue(), keyValueOffsets,
                                       keyValueSizes, keyValueStrides));
+  tiledOperands.emplace_back(scale);
   tiledOperands.emplace_back(getSlice(builder, loc, getOutput(),
                                       queryOutputOffsets, queryOutputSizes,
                                       queryOutputStrides));
 
   SmallVector<Type> resultTypes;
   if (hasPureTensorSemantics())
-    resultTypes.push_back(tiledOperands[3].getType());
+    resultTypes.push_back(tiledOperands[4].getType());
 
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 3620746..4f9ad3a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1,4 +1,4 @@
-  // Copyright 2021 The IREE Authors
+// Copyright 2021 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -520,26 +520,41 @@
        "getTiledImplementation"]>]> {
   let summary = "Attention operator";
   let description = [{
-    This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
-    the attention. For self-attention, all inputs have the same shape BxNxd where B is the
-    of the batch dimension, N is the sequence length and d is head dimension.
-    Typically N >>> d. Mathematically, the attention is defined as
-    matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
-    this operator also performs scaling, masking and dropout, but we leave
-    that out of the current implementation. For cross-attention, the query and output
-    have the same shape (BxNxd), while the key and value differ in sequence length
-    (they have shape BxLxd, where L != N).
-    This operator after tiling results in a tiled result as per flash attention and results
-    in the current `max` and `sum` statistics while processing the current tile.
+    Computes the scaled dot product attention function:
+
+    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
+
+    Here Q, K, V are given tensors and scale is a scalar value specifying
+    the scale to use.
+
+    For self-attention, all inputs and the result have the same shape BxNxd
+    where B is the batch dimension, N is the sequence length and d is head
+    dimension. Typically N >>> d. Usually, this operator also performs
+    masking and dropout, but we leave that out of the current implementation.
+    For cross-attention, the query and output have the same shape (BxNxd),
+    while the key and value differ in sequence length (they have shape BxLxd,
+    where L != N).
+
+    This operator after tiling results in a tiled result as per
+    FlashAttention 2 and optionally results in the current `max` and `sum`
+    statistics while processing the current tile.
+
+    If transpose_v is speciifed, the V tensor passed as input is assumed to
+    be transposed:
+
+    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V.T
+
+    TODO: We should be moving to using a indexing map like approach so we
+    can generalize which tensor is transposed and which is not.
   }];
 
-  let arguments = (ins Variadic<AnyShaped>:$inputs,
+  let arguments = (ins Variadic<AnyType>:$inputs,
                        Variadic<AnyShaped>:$outputs,
                        DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_v
   );
 
   let builders = [
-    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>,
   ];
 
   let results = (outs Variadic<AnyRankedTensor>:$results);
@@ -561,16 +576,19 @@
     Value getValue() {
       return getDpsInputOperand(2)->get();
     }
+    Value getScale() {
+      return getDpsInputOperand(3)->get();
+    }
     Value getOutput() {
       return getDpsInitOperand(0)->get();
     }
     std::optional<Value> getMax() {
-      if (getNumResults() == 1)
+      if (getNumResults() < 2)
         return std::nullopt;
       return getDpsInitOperand(1)->get();
     }
     std::optional<Value> getSum() {
-      if (getNumResults() == 1)
+      if (getNumResults() < 3)
         return std::nullopt;
       return getDpsInitOperand(2)->get();
     }
@@ -583,18 +601,21 @@
     ShapedType getValueType() {
       return getValue().getType().cast<ShapedType>();
     }
+    FloatType getScaleType() {
+      return getScale().getType().cast<FloatType>();
+    }
     ShapedType getOutputType() {
       return getOutput().getType().cast<ShapedType>();
     }
     std::optional<ShapedType> getMaxType() {
-      if (!getMax().has_value())
-        return std::nullopt;
-      return (*getMax()).getType().cast<ShapedType>();
+      std::optional<Value> maxVal = getMax();
+      if (!maxVal) return std::nullopt;
+      return maxVal->getType().cast<ShapedType>();
     }
     std::optional<ShapedType> getSumType() {
-      if (!getSum().has_value())
-        return std::nullopt;
-      return (*getSum()).getType().cast<ShapedType>();
+      std::optional<Value> sumVal = getSum();
+      if (!sumVal) return std::nullopt;
+      return sumVal->getType().cast<ShapedType>();
     }
     int64_t getQueryRank() {
       return getQueryType().getRank();
@@ -609,14 +630,14 @@
       return getOutputType().getRank();
     }
     std::optional<int64_t> getMaxRank() {
-      if (!getMax())
-        return std::nullopt;
-      return (*getMaxType()).getRank();
+      std::optional<ShapedType> maxType = getMaxType();
+      if (!maxType) return std::nullopt;
+      return maxType->getRank();
     }
     std::optional<int64_t> getSumRank() {
-      if (!getSum().has_value())
-        return std::nullopt;
-      return (*getSumType()).getRank();
+      std::optional<ShapedType> sumType = getSumType();
+      if (!sumType) return std::nullopt;
+      return sumType->getRank();
     }
     int64_t getIterationDomainRank() {
       return 2;
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 8c90638..efce33b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -725,8 +725,9 @@
 
 func.func @illegal_attention_inputs(%query: tensor<6x12x20x8xf32>, %key: tensor<6x12x20x8xf32>, %value: tensor<6x12x20x8xf32>) {
   %0 = tensor.empty() : tensor<6x12x20x8xf32>
+  %scale = arith.constant 1.0 : f32
   // expected-error @+1 {{'iree_linalg_ext.attention' op expected query to have rank 3 but found 4}}
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>) outs(%0 : tensor<6x12x20x8xf32>) -> tensor<6x12x20x8xf32>
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, f32) outs(%0 : tensor<6x12x20x8xf32>) -> tensor<6x12x20x8xf32>
   return %1 : tensor<6x12x20x8xf32>
 }
 
@@ -736,9 +737,18 @@
   %result = tensor.empty() : tensor<20x8xf32>
   %max = tensor.empty() : tensor<8xf32>
   %sum = tensor.empty() : tensor<8xf32>
+  %scale = arith.constant 1.0 : f32
   // expected-error @+1 {{'iree_linalg_ext.attention' op expected query to have rank 2 but found 1}}
-  %1:3 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<20xf32>, tensor<20x8xf32>, tensor<20x8xf32>) outs(%result, %max, %sum : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>
+  %1:3 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<20xf32>, tensor<20x8xf32>, tensor<20x8xf32>, f32) outs(%result, %max, %sum : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>
   return %1#0, %1#1, %1#2 : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>
 }
 
 // -----
+
+func.func @illegal_attention_inputs(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: f32) -> tensor<192x1024x64xf32> {
+  %0 = tensor.empty() : tensor<192x1024x64xf32>
+  %scale = arith.constant 1.0 : f32
+  // expected-error @+1 {{expected Query, Key, Value inputs to be of shaped type}}
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  return %1 : tensor<192x1024x64xf32>
+}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index a5a006e..8d6a575 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -1083,15 +1083,17 @@
 
 func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
   %0 = tensor.empty() : tensor<192x1024x64xf32>
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
 }
 // CHECK:      func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
 // CHECK-SAME:   tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
 // CHECK-SAME:   {
 // CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>) outs(%[[D0]] :
+// CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
 // CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
 // CHECK:      }
@@ -1100,15 +1102,17 @@
 
 func.func @cross_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x2048x64xf32>, %value: tensor<192x2048x64xf32>) -> tensor<192x1024x64xf32> {
   %0 = tensor.empty() : tensor<192x1024x64xf32>
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
 }
 // CHECK:      func.func @cross_attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
 // CHECK-SAME:   tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>) -> tensor<192x1024x64xf32>
 // CHECK-SAME:   {
 // CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>) outs(%[[D0]] :
+// CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
 // CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
 // CHECK:      }
@@ -1117,15 +1121,17 @@
 
 func.func @cross_attention_transposev(%query: tensor<192x1024x64xf32>, %key: tensor<192x2048x64xf32>, %value: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32> {
   %0 = tensor.empty() : tensor<192x1024x64xf32>
-  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
 }
 // CHECK:      func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
 // CHECK-SAME:   tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32>
 // CHECK-SAME:   {
 // CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>) outs(%[[D0]] :
+// CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] :
 // CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
 // CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
 // CHECK:      }
@@ -1133,14 +1139,16 @@
 // -----
 
 func.func @cross_attention_transposev_dyn(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
 // CHECK:      func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
 // CHECK-SAME:   tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK-SAME:   {
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
-// CHECK-SAME:     tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG3]] :
+// CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME:     tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%[[ARG3]] :
 // CHECK-SAME:     tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK:        return %[[D1]] : tensor<?x?x?xf32>
 // CHECK:      }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/tiling.mlir
index b474608..31db8b2 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/tiling.mlir
@@ -1020,7 +1020,8 @@
 
 func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
   %0 = tensor.empty() : tensor<192x1024x64xf32>
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
   return %1 : tensor<192x1024x64xf32>
 }
 module attributes { transform.with_named_sequence } {
@@ -1036,6 +1037,7 @@
 // CHECK-SAME:   tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
 // CHECK-SAME:   {
 // CHECK-DAG:    %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG:    %[[C1_F32:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:    %[[C192:.+]] = arith.constant 192 : index
 // CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
@@ -1056,7 +1058,7 @@
 // CHECK:            %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], %[[ARG5]], 0] [%[[D2]],
 // CHECK-SAME:         %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
 // CHECK:            %[[D5:.+]] = iree_linalg_ext.attention ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]],
-// CHECK-SAME:         %[[EXTRACTED_SLICE_1]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>)
+// CHECK-SAME:         %[[EXTRACTED_SLICE_1]], %[[C1_F32]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32)
 // CHECK-SAME:         outs(%[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32>
 // CHECK:            %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0]
 // CHECK-SAME:         [%[[D2]], %[[D4]], 64] [1, 1, 1] : tensor<?x?x64xf32> into tensor<192x1024x64xf32>
@@ -1070,7 +1072,8 @@
 // -----
 
 func.func @attention_memref(%query: memref<192x1024x64xf32>, %key: memref<192x1024x64xf32>, %value: memref<192x1024x64xf32>, %output: memref<192x1024x64xf32>) {
-  iree_linalg_ext.attention ins(%query, %key, %value : memref<192x1024x64xf32>, memref<192x1024x64xf32>, memref<192x1024x64xf32>) outs(%output : memref<192x1024x64xf32>)
+  %scale = arith.constant 1.0 : f32
+  iree_linalg_ext.attention ins(%query, %key, %value, %scale : memref<192x1024x64xf32>, memref<192x1024x64xf32>, memref<192x1024x64xf32>, f32) outs(%output : memref<192x1024x64xf32>)
   return
 }
 module attributes { transform.with_named_sequence } {
@@ -1085,11 +1088,12 @@
 // CHECK:      func.func @attention_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
 // CHECK-SAME:   memref<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: memref<192x1024x64xf32>, %[[ARG3:[a-zA-Z0-9_]+]]:
 // CHECK-SAME:   memref<192x1024x64xf32>) {
-// CHECK:        %[[C30:.+]] = arith.constant 30 : index
-// CHECK:        %[[C0:.+]] = arith.constant 0 : index
-// CHECK:        %[[C192:.+]] = arith.constant 192 : index
-// CHECK:        %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK:        %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG:    %[[C1_F32:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:        %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG:        %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:        %[[C192:.+]] = arith.constant 192 : index
+// CHECK-DAG:        %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK-DAG:        %[[C10:.+]] = arith.constant 10 : index
 // CHECK:        scf.for %[[ARG4:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C192]] step %[[C10]] {
 // CHECK:          scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C30]] {
 // CHECK-DAG:        %[[D0:.+]] = affine.min #[[MAP]](%[[ARG4]])
@@ -1102,9 +1106,9 @@
 // CHECK-SAME:         memref<192x1024x64xf32> to memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>
 // CHECK:            %[[SUBVIEW_2:.+]] = memref.subview %[[ARG3]][%[[ARG4]], %[[ARG5]], 0] [%[[D0]], %[[D1]], 64] [1, 1,
 // CHECK-SAME:         1] : memref<192x1024x64xf32> to memref<?x?x64xf32, strided<[65536, 64, 1], offset: ?>>
-// CHECK:            iree_linalg_ext.attention ins(%[[SUBVIEW]], %[[SUBVIEW_0]], %[[SUBVIEW_1]] : memref<?x?x64xf32,
+// CHECK:            iree_linalg_ext.attention ins(%[[SUBVIEW]], %[[SUBVIEW_0]], %[[SUBVIEW_1]], %[[C1_F32]] : memref<?x?x64xf32,
 // CHECK-SAME:         strided<[65536, 64, 1], offset: ?>>, memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>,
-// CHECK-SAME:         memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>) outs(%[[SUBVIEW_2]] :
+// CHECK-SAME:         memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>, f32) outs(%[[SUBVIEW_2]] :
 // CHECK-SAME:         memref<?x?x64xf32, strided<[65536, 64, 1], offset: ?>>)
 // CHECK:          }
 // CHECK:        }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAndDecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAndDecomposeAttention.cpp
index 4c15d75..657e16f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAndDecomposeAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAndDecomposeAttention.cpp
@@ -8,16 +8,12 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -437,10 +433,12 @@
   Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
                                   headDimension, elementType, loc, rewriter);
 
+  Value scale = attnOp.getScale();
+
   auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
       attnOp.getLoc(),
       SmallVector<Type>{accumulatorF32.getType(), sum.getType(), max.getType()},
-      SmallVector<Value>{querySlice, keySlice, valueSlice},
+      SmallVector<Value>{querySlice, keySlice, valueSlice, scale},
       SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum});
 
   if (attnOp.getTransposeV())
@@ -478,6 +476,34 @@
   return tiledAttentionOp;
 }
 
+Value scaleQuery(Value querySlice, Value scale, RewriterBase &rewriter) {
+  ShapedType queryType = cast<ShapedType>(querySlice.getType());
+  Location loc = querySlice.getLoc();
+
+  // Create a fill op for scale.
+  SmallVector<OpFoldResult> queryDims =
+      tensor::getMixedSizes(rewriter, loc, querySlice);
+  Value empty = rewriter.create<tensor::EmptyOp>(loc, queryDims,
+                                                 queryType.getElementType());
+  auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{scale}, empty)
+                    .getResult(0);
+
+  // Create a generic op to multiply the query by the scale.
+  SmallVector<utils::IteratorType> iteratorTypes(2,
+                                                 utils::IteratorType::parallel);
+  auto identityMap =
+      AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+  SmallVector<AffineMap> indexingMaps(2, identityMap);
+  auto scaleOp = rewriter.create<linalg::GenericOp>(
+      loc, TypeRange{fillOp.getType()}, ValueRange{querySlice},
+      ValueRange{fillOp}, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<arith::MulFOp>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return scaleOp.getResult(0);
+}
+
 /// Decompose tiled iree_linalg_ext.attention op.
 /// TODO: Adopt decomposeOperation with this.
 void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
@@ -492,9 +518,6 @@
   Value max = *tiledAttnOp.getMax();
   Value sum = *tiledAttnOp.getSum();
 
-  assert(max && "expected max statistic operand to be present");
-  assert(sum && "expected sum statistic operand to be present");
-
   OpBuilder::InsertionGuard withinScfLoop(rewriter);
   rewriter.setInsertionPointAfter(tiledAttnOp);
   SmallVector<OpFoldResult> queryDimValues =
@@ -505,6 +528,24 @@
       tileSize ? rewriter.getIndexAttr(tileSize.value()) : sequenceTileLength;
 
   Type elementType = tiledAttnOp.getQueryType().getElementType();
+
+  // Since we use exp2 for attention instead of the original exp, we have to
+  // multiply the scale by log2(e). We use exp2 instead of exp as most GPUs
+  // have better support for exp2.
+  Value scale = tiledAttnOp.getScale();
+  Value log2e = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getFloatAttr(elementType, M_LOG2E));
+  scale = rewriter.create<arith::MulFOp>(loc, scale, log2e);
+
+  // In the original algorithm, the scaling is done after the softmax:
+  //        softmax(Q @ K.T * scale) @ V
+  //
+  // But, it is mathematically equivalent to do it on Q first and then multiply
+  // it by K.T. This just allows us to do the scaling once, instead of each
+  // iteration of the loop.
+  querySlice = scaleQuery(querySlice, scale, rewriter);
+  ops.push_back(querySlice.getDefiningOp());
+
   auto [result, newMax, newSum] = createAttentionBody(
       keySlice, valueSlice, querySlice, tiledResult, max, sum,
       sequenceTileLength, keyValueTileLength, headDimension, elementType, ops,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_and_decompose_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_and_decompose_attention.mlir
index 0aa8de7..91424fc 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_and_decompose_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_and_decompose_attention.mlir
@@ -4,7 +4,8 @@
 
 func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
   %0 = tensor.empty() : tensor<1x1024x64xf32>
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
+  %scale = arith.constant 0.05 : f32
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
   return %1 : tensor<1x1024x64xf32>
 }
 
@@ -34,10 +35,11 @@
 // TILESIZE-SAME:       tensor<1x1024x64xf32> to tensor<32x64xf32>
 // TILESIZE:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // TILESIZE-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
+// TILESIZE:          %[[SCALE_Q:.+]] = linalg.generic {{.+}} ins(%[[EXTRACTED_SLICE_2]] : tensor<1024x64xf32>)
 // TILESIZE:          %[[D8:.+]] = tensor.empty() : tensor<1024x32xf32>
 // TILESIZE:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x32xf32>) ->
 // TILESIZE-SAME:       tensor<1024x32xf32>
-// TILESIZE:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]] :
+// TILESIZE:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE]] :
 // TILESIZE-SAME:       tensor<1024x64xf32>, tensor<32x64xf32>) outs(%[[D9]] : tensor<1024x32xf32>) -> tensor<1024x32xf32>
 // TILESIZE:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
 // TILESIZE-SAME:       "reduction"]} ins(%[[D10]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
@@ -86,8 +88,8 @@
 // TILESIZE-SAME:     {
 // TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // TILESIZE-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// TILESIZE:          %[[D8]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// TILESIZE:          %[[D9]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILESIZE:          %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
 // TILESIZE:          linalg.yield %[[D9]] : f32
 // TILESIZE:        } -> tensor<1024x64xf32>
 // TILESIZE:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
@@ -99,9 +101,10 @@
 // TILING-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 // TILING:      func.func @attention
 // TILING-SAME: (%[[QUERY:.+]]: tensor<1x1024x64xf32>, %[[KEY:.+]]: tensor<1x1024x64xf32>, %[[VALUE:.+]]: tensor<1x1024x64xf32>)
-// TILING:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
-// TILING:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILING-DAG:    %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
+// TILING-DAG:    %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
 // TILING-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILING-DAG:    %[[CST_1:.+]] = arith.constant 5.000000e-02 : f32
 // TILING:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
 // TILING-SAME:     tensor<1024x64xf32>
 // TILING-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
@@ -119,7 +122,7 @@
 // TILING-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
 // TILING:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[QUERY]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // TILING-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
-// TILING:          %[[TILED_ATTENTION:.+]]:3 = iree_linalg_ext.attention ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_1]] :
+// TILING:          %[[TILED_ATTENTION:.+]]:3 = iree_linalg_ext.attention ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_1]], %[[CST_1]] :
 // TILING-SAME:                                           outs(%[[ARG4]], %[[ARG5]], %[[ARG6]] :
 // TILING-SAME:                                           -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // TILING:          scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
@@ -143,9 +146,10 @@
 // CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
 // CHECK:      func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
 // CHECK-SAME:   tensor<1x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
-// CHECK:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// CHECK-DAG:    %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
+// CHECK-DAG:    %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
 // CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[CST_1:.+]] = arith.constant 5.000000e-02 : f32
 // CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
 // CHECK-SAME:     tensor<1024x64xf32>
 // CHECK-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
@@ -163,10 +167,19 @@
 // CHECK-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
 // CHECK:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // CHECK-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
+// CHECK:          %[[LOG2E:.+]] = arith.constant 1.44269502 : f32
+// CHECK:          %[[MUL:.+]] = arith.mulf %[[CST_1]], %[[LOG2E]] : f32
+// CHECK:          %[[FILL:.+]] = linalg.fill ins(%[[MUL]] : f32) outs(%1 : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// CHECK:          %[[SCALE_Q:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME:       ins(%[[EXTRACTED_SLICE_2]] : tensor<1024x64xf32>) outs(%[[FILL]] : tensor<1024x64xf32>) {
+// CHECK:          ^{{.+}}(%in: f32, %out: f32):
+// CHECK:            %[[MUL:.+]] = arith.mulf %in, %out : f32
+// CHECK:            linalg.yield %[[MUL]] : f32
+// CHECK:          } -> tensor<1024x64xf32>
 // CHECK:          %[[D8:.+]] = tensor.empty() : tensor<1024x1024xf32>
 // CHECK:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
-// CHECK:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]] :
+// CHECK:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE]] :
 // CHECK-SAME:       tensor<1024x64xf32>, tensor<1024x64xf32>) outs(%[[D9]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
 // CHECK:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
@@ -228,7 +241,8 @@
 
 func.func @attention(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<?x?x?xf32> {
   %0 = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %scale = arith.constant 0.05 : f32
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
 
@@ -263,9 +277,10 @@
 // TILESIZE:          %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1,
 // TILESIZE-SAME:       1] : tensor<?x?x?xf32> to tensor<?x?xf32>
 // TILESIZE:          %[[DIM_5:.+]] = tensor.dim %[[EXTRACTED_SLICE_4]], %[[C0]] : tensor<?x?xf32>
+// TILESIZE:          %[[SCALE_Q:.+]] = linalg.generic
 // TILESIZE:          %[[D8:.+]] = tensor.empty(%[[DIM_5]]) : tensor<?x32xf32>
 // TILESIZE:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<?x32xf32>) -> tensor<?x32xf32>
-// TILESIZE:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_4]], %[[EXTRACTED_SLICE]] :
+// TILESIZE:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE]] :
 // TILESIZE-SAME:       tensor<?x?xf32>, tensor<32x?xf32>) outs(%[[D9]] : tensor<?x32xf32>) -> tensor<?x32xf32>
 // TILESIZE:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
 // TILESIZE-SAME:       "reduction"]} ins(%[[D10]] : tensor<?x32xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
@@ -313,8 +328,8 @@
 // TILESIZE-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
 // TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // TILESIZE-DAG:      %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
-// TILESIZE:          %[[D8]] = arith.divf %[[CST_3]], %[[IN]] : f32
-// TILESIZE:          %[[D9]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILESIZE:          %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]] : f32
+// TILESIZE:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
 // TILESIZE:          linalg.yield %[[D9]] : f32
 // TILESIZE:        } -> tensor<?x?xf32>
 // TILESIZE:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
@@ -348,7 +363,7 @@
 // TILING-SAME:       [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?xf32>
 // TILING:          %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[QUERY]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1,
 // TILING-SAME:       1] : tensor<?x?x?xf32> to tensor<?x?xf32>
-// TILING:          %[[TILED_ATTENTION]]:3 = iree_linalg_ext.attention ins(%[[EXTRACTED_SLICE_4]], %[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_3]] :
+// TILING:          %[[TILED_ATTENTION]]:3 = iree_linalg_ext.attention ins(%[[EXTRACTED_SLICE_4]], %[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_3]], %{{[a-z0-1]+}} :
 // TILING-SAME:                      outs(%[[ARG7]], %[[ARG8]], %[[ARG9]] :
 // TILING-SAME:                      -> tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
 // TILING:          scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
@@ -396,9 +411,10 @@
 // CHECK:          %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1,
 // CHECK-SAME:       1] : tensor<?x?x?xf32> to tensor<?x?xf32>
 // CHECK:          %[[QUERY_SLICE_DIM_0:.+]] = tensor.dim %[[EXTRACTED_SLICE_4]], %[[C0]] : tensor<?x?xf32>
+// CHECK:          %[[SCALE_Q:.+]] = linalg.generic
 // CHECK:          %[[D7:.+]] = tensor.empty(%[[QUERY_SLICE_DIM_0]], %[[QUERY_SLICE_DIM_0]]) : tensor<?x?xf32>
 // CHECK:          %[[D8:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D7]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:          %[[D9:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_4]], %[[EXTRACTED_SLICE]] :
+// CHECK:          %[[D9:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE]] :
 // CHECK-SAME:       tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[D8]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK:          %[[D10:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
 // CHECK-SAME:       "reduction"]} ins(%[[D9]] : tensor<?x?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
@@ -459,7 +475,8 @@
 
 func.func @attention(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
   %0 = tensor.empty() : tensor<1x1024x64xf16>
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
+  %scale = arith.constant 0.05 : f16
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
   return %1 : tensor<1x1024x64xf16>
 }
 
@@ -491,10 +508,11 @@
 // TILESIZE-SAME:       tensor<1x1024x64xf16> to tensor<32x64xf16>
 // TILESIZE:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // TILESIZE-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE:          %[[SCALE_Q:.+]] = linalg.generic
 // TILESIZE:          %[[D9:.+]] = tensor.empty() : tensor<1024x32xf32>
 // TILESIZE:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x32xf32>) ->
 // TILESIZE-SAME:       tensor<1024x32xf32>
-// TILESIZE:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// TILESIZE:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE_1]] :
 // TILESIZE-SAME:       tensor<1024x64xf16>, tensor<32x64xf16>) outs(%[[D10]] : tensor<1024x32xf32>) ->
 // TILESIZE-SAME:       tensor<1024x32xf32>
 // TILESIZE:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
@@ -551,8 +569,8 @@
 // TILESIZE-SAME:     {
 // TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // TILESIZE-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// TILESIZE:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// TILESIZE:          %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILESIZE:          %[[D9:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE:          %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
 // TILESIZE:          linalg.yield %[[D10]] : f32
 // TILESIZE:        } -> tensor<1024x64xf32>
 // TILESIZE:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
@@ -591,7 +609,7 @@
 // TILING-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
 // TILING:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[QUERY]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // TILING-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
-// TILING:          %[[TILED_ATTENTION:.+]]:3 = iree_linalg_ext.attention ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]], %[[EXTRACTED_SLICE_2]] :
+// TILING:          %[[TILED_ATTENTION:.+]]:3 = iree_linalg_ext.attention ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]], %[[EXTRACTED_SLICE_2]], %{{[a-z0-9]+}} :
 // TILING-SAME:                                           outs(%[[ARG4]], %[[ARG5]], %[[ARG6]] :
 // TILING-SAME:                                           -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // TILING:          scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
@@ -643,10 +661,11 @@
 // CHECK-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
 // CHECK:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // CHECK-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// CHECK:          %[[SCALE_Q:.+]] = linalg.generic
 // CHECK:          %[[D9:.+]] = tensor.empty() : tensor<1024x1024xf32>
 // CHECK:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
-// CHECK:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// CHECK:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE_1]] :
 // CHECK-SAME:       tensor<1024x64xf16>, tensor<1024x64xf16>) outs(%[[D10]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
 // CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
@@ -703,8 +722,8 @@
 // CHECK-SAME:     {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// CHECK:          %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// CHECK:          %[[D9:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// CHECK:          %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
 // CHECK:          linalg.yield %[[D10]] : f32
 // CHECK:        } -> tensor<1024x64xf32>
 // CHECK:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
@@ -722,7 +741,8 @@
 
 func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
   %0 = tensor.empty() : tensor<1x1024x64xf16>
-  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
+  %scale = arith.constant 0.05 : f16
+  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
   return %1 : tensor<1x1024x64xf16>
 }
 
@@ -754,10 +774,11 @@
 // TILESIZE-SAME:       tensor<1x64x1024xf16> to tensor<64x32xf16>
 // TILESIZE:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // TILESIZE-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE:          %[[SCALE_Q:.+]] = linalg.generic
 // TILESIZE:          %[[D9:.+]] = tensor.empty() : tensor<1024x32xf32>
 // TILESIZE:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x32xf32>) ->
 // TILESIZE-SAME:       tensor<1024x32xf32>
-// TILESIZE:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// TILESIZE:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE_1]] :
 // TILESIZE-SAME:       tensor<1024x64xf16>, tensor<32x64xf16>) outs(%[[D10]] : tensor<1024x32xf32>) ->
 // TILESIZE-SAME:       tensor<1024x32xf32>
 // TILESIZE:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
@@ -814,8 +835,8 @@
 // TILESIZE-SAME:     {
 // TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // TILESIZE-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// TILESIZE:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// TILESIZE:          %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILESIZE:          %[[D9:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE:          %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
 // TILESIZE:          linalg.yield %[[D10]] : f32
 // TILESIZE:        } -> tensor<1024x64xf32>
 // TILESIZE:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
@@ -855,8 +876,8 @@
 // TILING:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // TILING-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
 // TILING:          %[[D9:.+]]:3 = iree_linalg_ext.attention {transpose_v = true} ins(%[[EXTRACTED_SLICE_3]],
-// TILING-SAME:       %[[EXTRACTED_SLICE_1]], %[[EXTRACTED_SLICE_2]] : tensor<1024x64xf16>, tensor<1024x64xf16>,
-// TILING-SAME:       tensor<64x1024xf16>) outs(%[[ARG4]], %[[ARG5]], %[[ARG6]] : tensor<1024x64xf32>, tensor<1024xf32>,
+// TILING-SAME:       %[[EXTRACTED_SLICE_1]], %[[EXTRACTED_SLICE_2]], %{{.+}} : tensor<1024x64xf16>, tensor<1024x64xf16>,
+// TILING-SAME:       tensor<64x1024xf16>, f16) outs(%[[ARG4]], %[[ARG5]], %[[ARG6]] : tensor<1024x64xf32>, tensor<1024xf32>,
 // TILING-SAME:       tensor<1024xf32>) -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // TILING:          scf.yield %[[D9]]#[[D0:.+]], %[[D9]]#[[D1:.+]], %[[D9]]#[[D2:.+]] : tensor<1024x64xf32>,
 // TILING-SAME:       tensor<1024xf32>, tensor<1024xf32>
@@ -906,10 +927,11 @@
 // CHECK-SAME:       tensor<1x64x1024xf16> to tensor<64x1024xf16>
 // CHECK:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // CHECK-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// CHECK:          %[[SCALE_Q:.+]] = linalg.generic
 // CHECK:          %[[D9:.+]] = tensor.empty() : tensor<1024x1024xf32>
 // CHECK:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
-// CHECK:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// CHECK:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE_1]] :
 // CHECK-SAME:       tensor<1024x64xf16>, tensor<1024x64xf16>) outs(%[[D10]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
 // CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
@@ -966,8 +988,8 @@
 // CHECK-SAME:     {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// CHECK:          %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// CHECK:          %[[D9:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// CHECK:          %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
 // CHECK:          linalg.yield %[[D10]] : f32
 // CHECK:        } -> tensor<1024x64xf32>
 // CHECK:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
diff --git a/tests/e2e/linalg_ext_ops/attention.mlir b/tests/e2e/linalg_ext_ops/attention.mlir
index cbb2ded..50fb804 100644
--- a/tests/e2e/linalg_ext_ops/attention.mlir
+++ b/tests/e2e/linalg_ext_ops/attention.mlir
@@ -3,8 +3,9 @@
   %query = util.unfoldable_constant dense<1.0> : tensor<1x4x4xf32>
   %key = util.unfoldable_constant dense<0.5> : tensor<1x4x4xf32>
   %value = util.unfoldable_constant dense<2.0> : tensor<1x4x4xf32>
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<1x4x4xf32>,
-        tensor<1x4x4xf32>, tensor<1x4x4xf32>) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x4x4xf32>,
+        tensor<1x4x4xf32>, tensor<1x4x4xf32>, f32) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
   check.expect_almost_eq_const(
       %1,
       dense<[[[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]]]> : tensor<1x4x4xf32>
diff --git a/tests/transform_dialect/cpu/attention.mlir b/tests/transform_dialect/cpu/attention.mlir
index b103ba6..51cbd8e 100644
--- a/tests/transform_dialect/cpu/attention.mlir
+++ b/tests/transform_dialect/cpu/attention.mlir
@@ -3,8 +3,9 @@
   %query = util.unfoldable_constant dense<1.0> : tensor<1x4x4xf32>
   %key = util.unfoldable_constant dense<0.5> : tensor<1x4x4xf32>
   %value = util.unfoldable_constant dense<2.0> : tensor<1x4x4xf32>
-  %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<1x4x4xf32>,
-        tensor<1x4x4xf32>, tensor<1x4x4xf32>) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+  %scale = arith.constant 1.0 : f32
+  %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x4x4xf32>,
+        tensor<1x4x4xf32>, tensor<1x4x4xf32>, f32) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
   return %1 : tensor<1x4x4xf32>
 }
 
diff --git a/tests/transform_dialect/cpu/attention_codegen_spec.mlir b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
index 1bfb3d3..d768936 100644
--- a/tests/transform_dialect/cpu/attention_codegen_spec.mlir
+++ b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
@@ -22,9 +22,10 @@
     %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %blocked_attention = transform.iree.tile_attention %attention4 :
       (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %scale_acc, %second_matmul
+    %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %scale_acc, %second_matmul
         = transform.iree.decompose_tiled_attention %blocked_attention :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
+                              !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 
     // Vectorize function
     // ==========================================