[LinalgExt] Use f32 for accumulation for online_attention (#18456)

This patch makes the ConvertAttentionToOnlineAttention conversion use
f32 for accumulation in online_attention. This removes the extra
arith.extf/arith.truncf ops inside the online attention loop, which
caused regressions.

Also converts the pass to a generic Pass, adds a lit test for the
conversion, and includes a folding fix in the attention op's iteration
domain computation (getDimValue -> getDim).
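
Roughly, the converted IR now looks like the sketch below (hedged: shapes
are borrowed from the new convert_to_online_attention.mlir test, and the
#mapQ/#mapK/#mapV/#mapO/#mapMax/#mapSum names are placeholders for the
indexing maps the pass actually builds):

  %zero = arith.constant 0.000000e+00 : f32
  %acc_empty = tensor.empty() : tensor<2x10x4096x128xf32>   // f32 accumulator, was f16
  %acc_fill = linalg.fill ins(%zero : f32) outs(%acc_empty : tensor<2x10x4096x128xf32>) -> tensor<2x10x4096x128xf32>
  // max/sum fills elided; both stay f32 as before.
  %res:3 = iree_linalg_ext.online_attention
             {indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapMax, #mapSum]}
             ins(%q, %k, %v, %scale : tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, f16)
             outs(%acc_fill, %max_fill, %sum_fill : tensor<2x10x4096x128xf32>, tensor<2x10x4096xf32>, tensor<2x10x4096xf32>)
             -> tensor<2x10x4096x128xf32>, tensor<2x10x4096xf32>, tensor<2x10x4096xf32>
  // The trailing linalg.generic rescales by 1/sum in f32 and performs the
  // only f32 -> f16 cast, so no extf/truncf remains inside the loop:
  //   %r = arith.divf %one, %sum_elem : f32
  //   %x = arith.mulf %r, %acc_elem : f32
  //   %o = arith.truncf %x : f32 to f16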
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 1400199..5313a0e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -715,7 +715,7 @@
 // CHECK-SAME:    translation_info = #[[$TRANSLATION]]
 
 // CHECK: scf.for %{{.*}} = %c0 to %c4096 step %c64
-// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x4x1x1x4x1xf16>)
+// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x4x1x1x4x1xf32>)
 // CHECK-COUNT-48:  amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK: scf.yield
 
@@ -767,6 +767,6 @@
 // CHECK-SAME:    translation_info = #[[$TRANSLATION]]
 
 // CHECK: scf.for %{{.*}} = %c0 to %c72 step %c1
-// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf16>)
+// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>)
 // CHECK-COUNT-96:  amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK: scf.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index ca06c4f..6454c5a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -104,8 +104,7 @@
 }
 
 def ConvertAttentionToOnlineAttentionPass :
-    InterfacePass<"iree-linalg-ext-convert-attention-to-online-attention",
-    "mlir::FunctionOpInterface"> {
+    Pass<"iree-linalg-ext-convert-attention-to-online-attention", ""> {
   let summary = "Converts attention op to online_attention op";
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index e69c513..d68a2eb 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -375,32 +375,37 @@
   }
   AffineMap sumMap = maxMap;
 
-  SmallVector<Range> sizes = attnOp.getIterationDomain(rewriter);
+  AffineMap accMap = attnOp.getOutputMap();
+
+  SmallVector<Range> domain = attnOp.getIterationDomain(rewriter);
 
   // Create fill for acc, max and sum.
   // TODO: Acc should not need a fill. The attention op should get a filled
   // input instead of an empty input.
-  Value zeroAcc = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getZeroAttr(attnOp.getOutputType().getElementType()));
-  Value accFill =
-      rewriter
-          .create<linalg::FillOp>(loc, ValueRange{zeroAcc}, attnOp.getOutput())
-          .result();
 
+  SmallVector<OpFoldResult> sizes =
+      llvm::map_to_vector(domain, [](Range x) { return x.size; });
+  SmallVector<OpFoldResult> accSize =
+      applyPermutationMap<OpFoldResult>(accMap, sizes);
   SmallVector<OpFoldResult> rowRedSize =
-      llvm::map_to_vector(sizes, [](Range x) { return x.size; });
-  rowRedSize = applyPermutationMap<OpFoldResult>(maxMap, rowRedSize);
+      applyPermutationMap<OpFoldResult>(maxMap, sizes);
 
   Type f32Type = rewriter.getF32Type();
+  Value acc = rewriter.create<tensor::EmptyOp>(loc, accSize, f32Type);
   Value rowRedEmpty =
       rewriter.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
 
+  Value accInit =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, rewriter,
+                              loc, /*useOnlyFiniteValue=*/true);
   Value maxInit =
       arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, rewriter,
                               loc, /*useOnlyFiniteValue=*/true);
   Value sumInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type,
                                           rewriter, loc);
 
+  Value accFill = rewriter.create<linalg::FillOp>(loc, ValueRange{accInit}, acc)
+                      .getResult(0);
   Value maxFill =
       rewriter.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
           .getResult(0);
@@ -427,21 +432,24 @@
 
   // Compress the indexing maps.
   SmallVector<AffineMap> compressedMaps =
-      compressUnusedDims(SmallVector<AffineMap>{sumMap, attnOp.getOutputMap()});
+      compressUnusedDims(SmallVector<AffineMap>{sumMap, accMap, accMap});
 
   SmallVector<utils::IteratorType> iteratorTypes(compressedMaps[0].getNumDims(),
                                                  utils::IteratorType::parallel);
 
   auto genericOp = rewriter.create<linalg::GenericOp>(
-      loc, x.getType(), sum, x, compressedMaps, iteratorTypes,
+      loc, attnOp.getOutputType(), ValueRange{sum, x}, attnOp.getOutput(),
+      compressedMaps, iteratorTypes,
       [&](OpBuilder &b, Location loc, ValueRange args) {
         Value one = b.create<arith::ConstantOp>(
             loc, b.getFloatAttr(args[0].getType(), 1.0));
         Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
-        // Convert sum to the same datatype as x.
-        reciprocal = convertScalarToDtype(b, loc, reciprocal, args[1].getType(),
-                                          /*isUnsignedCast=*/false);
+        // Both sum and x are in fp32, as created earlier, so we only need
+        // to cast after the mul.
         Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
+        // Cast result to the required type by attention output.
+        result = convertScalarToDtype(b, loc, result, args[2].getType(),
+                                      /*isUnsignedCast=*/false);
         b.create<linalg::YieldOp>(loc, result);
       });
   ops.push_back(genericOp);
@@ -465,7 +473,7 @@
 void ConvertAttentionToOnlineAttentionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   IRRewriter rewriter(context);
-  getOperation().walk([&](AttentionOp attnOp) {
+  getOperation()->walk([&](AttentionOp attnOp) {
     SmallVector<Operation *> ops;
     convertToOnlineAttention(attnOp, ops, rewriter);
   });
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index 81d2e32..afa3578 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -1765,7 +1765,7 @@
         continue;
       }
       dimsFound[pos] = true;
-      loopBounds[pos].size = getDimValue(b, loc, val, idx);
+      loopBounds[pos].size = getDim(b, loc, val, idx);
     }
   };
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index 4b9a377..aff4504 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -19,6 +19,7 @@
             "conv2d_to_im2col.mlir",
             "conv2d_to_winograd.mlir",
             "convert_to_loops.mlir",
+            "convert_to_online_attention.mlir",
             "decompose_attention.mlir",
             "decompose_im2col.mlir",
             "decompose_online_attention.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index e89312a..47108d2 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
     "conv2d_to_im2col.mlir"
     "conv2d_to_winograd.mlir"
     "convert_to_loops.mlir"
+    "convert_to_online_attention.mlir"
     "decompose_attention.mlir"
     "decompose_im2col.mlir"
     "decompose_online_attention.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
new file mode 100644
index 0000000..7eb4c0a
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt --split-input-file --iree-linalg-ext-convert-attention-to-online-attention %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+func.func @attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16>, %v: tensor<2x10x4096x128xf16>)
+                     -> tensor<2x10x4096x128xf16> {
+  %scale = arith.constant 0.125 : f16
+  %acc = tensor.empty() : tensor<2x10x4096x128xf16>
+  %out = iree_linalg_ext.attention
+         {indexing_maps = [#map, #map1, #map2, #map3]}
+         ins(%q, %k, %v, %scale : tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, f16)
+         outs(%acc : tensor<2x10x4096x128xf16>) -> tensor<2x10x4096x128xf16>
+  func.return %out : tensor<2x10x4096x128xf16>
+}
+
+// CHECK-LABEL: func.func @attention
+// CHECK-SAME: %[[Q:.+]]: tensor<2x10x4096x128xf16>, %[[K:.+]]: tensor<2x10x4096x128xf16>, %[[V:.+]]: tensor<2x10x4096x128xf16>
+// CHECK-DAG: %[[ACC_INIT:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[MAX_INIT:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG: %[[SUM_INIT:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[ACC_FILL:.+]] = linalg.fill ins(%[[ACC_INIT]]
+// CHECK-DAG: %[[MAX_FILL:.+]] = linalg.fill ins(%[[MAX_INIT]]
+// CHECK-DAG: %[[SUM_FILL:.+]] = linalg.fill ins(%[[SUM_INIT]]
+// CHECK: %[[OUT:.+]]:3 = iree_linalg_ext.online_attention
+// CHECK-SAME:         ins(%[[Q]], %[[K]], %[[V]]
+// CHECK-SAME:         outs(%[[ACC_FILL]], %[[MAX_FILL]], %[[SUM_FILL]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[OUT]]#2, %[[OUT]]#0
+// CHECK: arith.divf
+// CHECK: arith.mulf
+// CHECK: arith.truncf
+// CHECK: linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index c471937..df46f9f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -24,44 +24,55 @@
   return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
 }
 
-// We just want to check if we are using the correct algorithm.
+// We just want to check if we are using the correct algorithm and the
+// correct number of extf/truncfs are emitted.
 // CHECK-LABEL: @attention_f16
 // Q = Q * scale
 // CHECK: linalg.generic
 // CHECK:   arith.mulf
 // S = Q @ K
 // CHECK: linalg.generic
+// CHECK:   arith.extf
+// CHECK:   arith.extf
 // CHECK:   arith.mulf
 // CHECK:   arith.addf
 // CHECK:   linalg.yield
 // newMax = max(oldMax, rowMax(S))
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.maximumf
 // CHECK:   linalg.yield
 // norm = exp2(oldMax - newMax)
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.subf
 // CHECK:   math.exp2
 // CHECK:   linalg.yield
 // normSum = norm * oldSum
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.mulf
 // CHECK:   linalg.yield
 // P = exp2(S - newMax)
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.subf
 // CHECK:   math.exp2
 // CHECK:   linalg.yield
 // newSum = normSum + rowSum(P)
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.addf
 // CHECK:   linalg.yield
 // newAcc = norm * oldAcc
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.mulf
 // CHECK:   linalg.yield
 // newAcc = P @ V + newAcc
 // CHECK: linalg.generic
+// CHECK:   arith.extf
+// CHECK:   arith.extf
 // CHECK:   arith.mulf
 // CHECK:   arith.addf
 // CHECK:   linalg.yield
@@ -81,11 +92,11 @@
                          %max: tensor<192x1024xf32>,
                          %sum: tensor<192x1024xf32>)
                          -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
-  %scale = arith.constant 1.0 : f16
+  %scale = arith.constant 1.0 : f32
 
   %out:3 = iree_linalg_ext.online_attention
         { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
-        ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16)
+        ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32)
         outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
         -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
 
@@ -102,40 +113,49 @@
 // CHECK:   linalg.yield
 // S = S * scale
 // CHECK:   linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.mulf
 // CHECK-NEXT:   linalg.yield
 // S = S + F8_linear_offset
 // CHECK:   linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.addf
 // CHECK-NEXT:   linalg.yield
 // newMax = max(oldMax, rowMax(S))
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.maximumf
 // CHECK:   linalg.yield
 // norm = exp2(oldMax - newMax)
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.subf
 // CHECK:   math.exp2
 // CHECK:   linalg.yield
 // normSum = norm * oldSum
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.mulf
 // CHECK:   linalg.yield
 // P = exp2(S - newMax)
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.subf
 // CHECK:   math.exp2
 // CHECK:   linalg.yield
 // newSum = normSum + rowSum(P)
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.addf
 // CHECK:   linalg.yield
 // clamp = clamp(norm)
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.minimumf
 // CHECK:   arith.truncf
 // newAcc = norm * oldAcc
 // CHECK: linalg.generic
+// CHECK-NOT: arith.extf
 // CHECK:   arith.mulf
 // CHECK:   linalg.yield
 // newAcc = P @ V + newAcc
diff --git a/tests/e2e/attention/generate_e2e_attention_tests.py b/tests/e2e/attention/generate_e2e_attention_tests.py
index 8af76f0..d258dc1 100644
--- a/tests/e2e/attention/generate_e2e_attention_tests.py
+++ b/tests/e2e/attention/generate_e2e_attention_tests.py
@@ -79,7 +79,7 @@
         ]
     if shapes_id == ShapesId.LARGE:
         return [
-            TestShapeAndScale(batch=2, m=1024, k1=256, k2=128, n=64, scale=1.0),
+            TestShapeAndScale(batch=2, m=1024, k1=128, k2=128, n=64, scale=1.0),
         ]
 
     raise ValueError(shapes_id)