[LinalgExt] Use f32 for accumulation for online_attention (#18456)
This patch makes the ConvertAttentionToOnlineAttention conversion use
f32 accumulators for online_attention. This removes the extra
extf/truncf ops inside the online attention loop, which caused regressions.
Also adds tests for the conversion pass and a folding fix.
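For context, online attention carries a per-row running max `m`, a running sum `s`, and a partial accumulator `acc` across the K/V loop, roughly as below (a sketch; the implementation works in exp2, as the decompose tests further down show):

```latex
% Online-softmax recurrence (sketch); S is the current Q*K^T tile.
\begin{aligned}
m'   &= \max(m,\ \mathrm{rowMax}(S)) \\
P    &= \operatorname{exp2}(S - m') \\
s'   &= \operatorname{exp2}(m - m')\, s + \mathrm{rowSum}(P) \\
acc' &= \operatorname{exp2}(m - m')\, acc + P\, V
\end{aligned}
```

The division by `s` happens once, in an epilogue after the loop, so there is no reason to carry `acc` in the narrow input type; carrying it in f32 lets the loop body drop the per-iteration casts.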
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 1400199..5313a0e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -715,7 +715,7 @@
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: scf.for %{{.*}} = %c0 to %c4096 step %c64
-// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x4x1x1x4x1xf16>)
+// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x4x1x1x4x1xf32>)
// CHECK-COUNT-48: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield
@@ -767,6 +767,6 @@
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: scf.for %{{.*}} = %c0 to %c72 step %c1
-// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf16>)
+// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>)
// CHECK-COUNT-96: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield
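The pipeline tests above now expect the third loop-carried value (the accumulator) in f32. A minimal sketch of the round trip this removes — the names, shapes, and loop bounds are illustrative, not taken from the pipeline output:

```mlir
// Sketch: carrying the accumulator as f16 forces a widen/narrow round trip
// on every iteration of the K/V loop; carrying it as f32 deletes both casts.
func.func @f16_acc_roundtrip(%init: vector<4xf16>, %update: vector<4xf32>) -> vector<4xf16> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c64 = arith.constant 64 : index
  %r = scf.for %i = %c0 to %c64 step %c1 iter_args(%acc = %init) -> (vector<4xf16>) {
    %wide = arith.extf %acc : vector<4xf16> to vector<4xf32>      // re-widen
    %sum = arith.addf %wide, %update : vector<4xf32>              // accumulate in f32
    %narrow = arith.truncf %sum : vector<4xf32> to vector<4xf16>  // re-narrow to carry
    scf.yield %narrow : vector<4xf16>
  }
  return %r : vector<4xf16>
}
```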
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index ca06c4f..6454c5a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -104,8 +104,7 @@
}
def ConvertAttentionToOnlineAttentionPass :
- InterfacePass<"iree-linalg-ext-convert-attention-to-online-attention",
- "mlir::FunctionOpInterface"> {
+ Pass<"iree-linalg-ext-convert-attention-to-online-attention", ""> {
let summary = "Converts attention op to online_attention op";
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index e69c513..d68a2eb 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -375,32 +375,37 @@
}
AffineMap sumMap = maxMap;
- SmallVector<Range> sizes = attnOp.getIterationDomain(rewriter);
+ AffineMap accMap = attnOp.getOutputMap();
+
+ SmallVector<Range> domain = attnOp.getIterationDomain(rewriter);
// Create fill for acc, max and sum.
// TODO: Acc should not need a fill. The attention op should get a filled
// input instead of an empty input.
- Value zeroAcc = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(attnOp.getOutputType().getElementType()));
- Value accFill =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{zeroAcc}, attnOp.getOutput())
- .result();
+ SmallVector<OpFoldResult> sizes =
+ llvm::map_to_vector(domain, [](Range x) { return x.size; });
+ SmallVector<OpFoldResult> accSize =
+ applyPermutationMap<OpFoldResult>(accMap, sizes);
SmallVector<OpFoldResult> rowRedSize =
- llvm::map_to_vector(sizes, [](Range x) { return x.size; });
- rowRedSize = applyPermutationMap<OpFoldResult>(maxMap, rowRedSize);
+ applyPermutationMap<OpFoldResult>(maxMap, sizes);
Type f32Type = rewriter.getF32Type();
+ Value acc = rewriter.create<tensor::EmptyOp>(loc, accSize, f32Type);
Value rowRedEmpty =
rewriter.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
+ Value accInit =
+ arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, rewriter,
+ loc, /*useOnlyFiniteValue=*/true);
Value maxInit =
arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, rewriter,
loc, /*useOnlyFiniteValue=*/true);
Value sumInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type,
rewriter, loc);
+ Value accFill = rewriter.create<linalg::FillOp>(loc, ValueRange{accInit}, acc)
+ .getResult(0);
Value maxFill =
rewriter.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
.getResult(0);
@@ -427,21 +432,24 @@
// Compress the indexing maps.
SmallVector<AffineMap> compressedMaps =
- compressUnusedDims(SmallVector<AffineMap>{sumMap, attnOp.getOutputMap()});
+ compressUnusedDims(SmallVector<AffineMap>{sumMap, accMap, accMap});
SmallVector<utils::IteratorType> iteratorTypes(compressedMaps[0].getNumDims(),
utils::IteratorType::parallel);
auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, x.getType(), sum, x, compressedMaps, iteratorTypes,
+ loc, attnOp.getOutputType(), ValueRange{sum, x}, attnOp.getOutput(),
+ compressedMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value one = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(args[0].getType(), 1.0));
Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
- // Convert sum to the same datatype as x.
- reciprocal = convertScalarToDtype(b, loc, reciprocal, args[1].getType(),
- /*isUnsignedCast=*/false);
+ // Both sum and x are in fp32, as created earlier, so we only need
+ // to cast after the mul.
Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
+ // Cast the result to the type required by the attention output.
+ result = convertScalarToDtype(b, loc, result, args[2].getType(),
+ /*isUnsignedCast=*/false);
b.create<linalg::YieldOp>(loc, result);
});
ops.push_back(genericOp);
@@ -465,7 +473,7 @@
void ConvertAttentionToOnlineAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
- getOperation().walk([&](AttentionOp attnOp) {
+ getOperation()->walk([&](AttentionOp attnOp) {
SmallVector<Operation *> ops;
convertToOnlineAttention(attnOp, ops, rewriter);
});
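With the changes above, both `sum` and `acc` reach the epilogue in f32, so the normalization needs just one truncf after the multiply; and since the pass is no longer an InterfacePass, `getOperation()` now returns an `Operation*`, hence the `->walk`. A hedged sketch of the generic the epilogue builds, with illustrative shapes and names:

```mlir
// Sketch of the epilogue: divide acc by the row sum in f32, then truncate
// once to the attention output element type.
#id2 = affine_map<(d0, d1) -> (d0, d1)>
#row = affine_map<(d0, d1) -> (d0)>
func.func @epilogue(%sum: tensor<128xf32>, %acc: tensor<128x64xf32>,
                    %out: tensor<128x64xf16>) -> tensor<128x64xf16> {
  %0 = linalg.generic
      {indexing_maps = [#row, #id2, #id2],
       iterator_types = ["parallel", "parallel"]}
      ins(%sum, %acc : tensor<128xf32>, tensor<128x64xf32>)
      outs(%out : tensor<128x64xf16>) {
  ^bb0(%s: f32, %a: f32, %o: f16):
    %one = arith.constant 1.0 : f32
    %recip = arith.divf %one, %s : f32     // 1 / rowSum
    %mul = arith.mulf %recip, %a : f32     // normalize in f32
    %t = arith.truncf %mul : f32 to f16    // single cast to the output type
    linalg.yield %t : f16
  } -> tensor<128x64xf16>
  return %0 : tensor<128x64xf16>
}
```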
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index 81d2e32..afa3578 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -1765,7 +1765,7 @@
continue;
}
dimsFound[pos] = true;
- loopBounds[pos].size = getDimValue(b, loc, val, idx);
+ loopBounds[pos].size = getDim(b, loc, val, idx);
}
};
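The one-line `getDim` change is the folding fix from the commit message: `Range::size` is an `OpFoldResult`, so using `getDim`, which can return a static extent as an attribute, avoids the `tensor.dim` ops that `getDimValue` materializes even for static shapes. Illustrative only — the kind of IR this avoids for an already-static dimension:

```mlir
// Sketch of what getDimValue materializes for a static dimension; with
// getDim the size stays an attribute and no IR is created.
func.func @materialized_static_dim(%q: tensor<2x10x4096x128xf16>) -> index {
  %c2 = arith.constant 2 : index
  %d = tensor.dim %q, %c2 : tensor<2x10x4096x128xf16>  // only folds to 4096 after canonicalization
  return %d : index
}
```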
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index 4b9a377..aff4504 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -19,6 +19,7 @@
"conv2d_to_im2col.mlir",
"conv2d_to_winograd.mlir",
"convert_to_loops.mlir",
+ "convert_to_online_attention.mlir",
"decompose_attention.mlir",
"decompose_im2col.mlir",
"decompose_online_attention.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index e89312a..47108d2 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
"conv2d_to_im2col.mlir"
"conv2d_to_winograd.mlir"
"convert_to_loops.mlir"
+ "convert_to_online_attention.mlir"
"decompose_attention.mlir"
"decompose_im2col.mlir"
"decompose_online_attention.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
new file mode 100644
index 0000000..7eb4c0a
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt --split-input-file --iree-linalg-ext-convert-attention-to-online-attention %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+func.func @attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16>, %v: tensor<2x10x4096x128xf16>)
+ -> tensor<2x10x4096x128xf16> {
+ %scale = arith.constant 0.125 : f16
+ %acc = tensor.empty() : tensor<2x10x4096x128xf16>
+ %out = iree_linalg_ext.attention
+ {indexing_maps = [#map, #map1, #map2, #map3]}
+ ins(%q, %k, %v, %scale : tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, f16)
+ outs(%acc : tensor<2x10x4096x128xf16>) -> tensor<2x10x4096x128xf16>
+ func.return %out : tensor<2x10x4096x128xf16>
+}
+
+// CHECK-LABEL: func.func @attention
+// CHECK-SAME: %[[Q:.+]]: tensor<2x10x4096x128xf16>, %[[K:.+]]: tensor<2x10x4096x128xf16>, %[[V:.+]]: tensor<2x10x4096x128xf16>
+// CHECK-DAG: %[[ACC_INIT:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[MAX_INIT:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG: %[[SUM_INIT:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[ACC_FILL:.+]] = linalg.fill ins(%[[ACC_INIT]]
+// CHECK-DAG: %[[MAX_FILL:.+]] = linalg.fill ins(%[[MAX_INIT]]
+// CHECK-DAG: %[[SUM_FILL:.+]] = linalg.fill ins(%[[SUM_INIT]]
+// CHECK: %[[OUT:.+]]:3 = iree_linalg_ext.online_attention
+// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]]
+// CHECK-SAME: outs(%[[ACC_FILL]], %[[MAX_FILL]], %[[SUM_FILL]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[OUT]]#2, %[[OUT]]#0
+// CHECK: arith.divf
+// CHECK: arith.mulf
+// CHECK: arith.truncf
+// CHECK: linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index c471937..df46f9f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -24,44 +24,55 @@
return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
}
-// We just want to check if we are using the correct algorithm.
+// We just want to check that we are using the correct algorithm and that
+// the correct number of extf/truncf ops is emitted.
// CHECK-LABEL: @attention_f16
// Q = Q * scale
// CHECK: linalg.generic
// CHECK: arith.mulf
// S = Q @ K
// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK: arith.extf
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
// newMax = max(oldMax, rowMax(S))
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.maximumf
// CHECK: linalg.yield
// norm = exp2(oldMax - newMax)
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// normSum = norm * oldSum
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.mulf
// CHECK: linalg.yield
// P = exp2(S - newMax)
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// newSum = normSum + rowSum(P)
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.addf
// CHECK: linalg.yield
// newAcc = norm * oldAcc
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.mulf
// CHECK: linalg.yield
// newAcc = P @ V + newAcc
// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK: arith.extf
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
@@ -81,11 +92,11 @@
%max: tensor<192x1024xf32>,
%sum: tensor<192x1024xf32>)
-> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
- %scale = arith.constant 1.0 : f16
+ %scale = arith.constant 1.0 : f32
%out:3 = iree_linalg_ext.online_attention
{ indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
- ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16)
+ ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
-> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
@@ -102,40 +113,49 @@
// CHECK: linalg.yield
// S = S * scale
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.mulf
// CHECK-NEXT: linalg.yield
// S = S + F8_linear_offset
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.addf
// CHECK-NEXT: linalg.yield
// newMax = max(oldMax, rowMax(S))
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.maximumf
// CHECK: linalg.yield
// norm = exp2(oldMax - newMax)
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// normSum = norm * oldSum
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.mulf
// CHECK: linalg.yield
// P = exp2(S - newMax)
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// newSum = normSum + rowSum(P)
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.addf
// CHECK: linalg.yield
// clamp = clamp(norm)
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.minimumf
// CHECK: arith.truncf
// newAcc = norm * oldAcc
// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
// CHECK: arith.mulf
// CHECK: linalg.yield
// newAcc = P @ V + newAcc
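The updated CHECK-NOT lines pin down where widening is allowed to happen: only the two matmul-style generics (S = Q @ K and newAcc = P @ V + newAcc) extend their narrow operands, while the softmax bookkeeping stays cast-free in f32. A sketch of such a mixed-precision generic, with illustrative shapes:

```mlir
// Sketch: the only remaining extf ops sit inside the matmul generics,
// widening the f16 operands before the f32 multiply-accumulate.
#mA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mC = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @mixed_matmul(%a: tensor<8x16xf16>, %b: tensor<16x8xf16>,
                        %c: tensor<8x8xf32>) -> tensor<8x8xf32> {
  %0 = linalg.generic
      {indexing_maps = [#mA, #mB, #mC],
       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %b : tensor<8x16xf16>, tensor<16x8xf16>)
      outs(%c : tensor<8x8xf32>) {
  ^bb0(%x: f16, %y: f16, %acc: f32):
    %xw = arith.extf %x : f16 to f32
    %yw = arith.extf %y : f16 to f32
    %m = arith.mulf %xw, %yw : f32
    %s = arith.addf %acc, %m : f32
    linalg.yield %s : f32
  } -> tensor<8x8xf32>
  return %0 : tensor<8x8xf32>
}
```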
diff --git a/tests/e2e/attention/generate_e2e_attention_tests.py b/tests/e2e/attention/generate_e2e_attention_tests.py
index 8af76f0..d258dc1 100644
--- a/tests/e2e/attention/generate_e2e_attention_tests.py
+++ b/tests/e2e/attention/generate_e2e_attention_tests.py
@@ -79,7 +79,7 @@
]
if shapes_id == ShapesId.LARGE:
return [
- TestShapeAndScale(batch=2, m=1024, k1=256, k2=128, n=64, scale=1.0),
+ TestShapeAndScale(batch=2, m=1024, k1=128, k2=128, n=64, scale=1.0),
]
raise ValueError(shapes_id)