[LinalgExt] Implement AggregateOpInterface for AttentionOp (#18890)
- Adds AggregateOpInterface for AttentionOp
- Move all aggregate interface tests to IR/test/decompose_aggregate_op
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index 204ae35..7fc985b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -299,6 +299,212 @@
}
//===----------------------------------------------------------------------===//
+// Attention Helpers
+//===----------------------------------------------------------------------===//
+
+Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
+ Value key, Value scale, std::optional<Value> mask,
+ AffineMap qMap, AffineMap kMap, AffineMap sMap,
+ std::optional<AffineMap> maskMap,
+ SmallVector<OpFoldResult> iterationDomain,
+ Type sElementType, Region &elementwiseRegion,
+ DictionaryAttr qkAttrs, bool lowPrecision) {
+ MLIRContext *ctx = b.getContext();
+ // Since we use exp2 for attention instead of the original exp, we have to
+ // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
+ // have better support for exp2 (we verified that we gain some speedup on
+ // some GPUs).
+ Value log2e = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(scale.getType(), M_LOG2E));
+ scale = b.create<arith::MulFOp>(loc, scale, log2e);
+
+ auto qETy = getElementTypeOrSelf(query.getType());
+
+ AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
+ /*symbolCount=*/0, ctx);
+
+ // In the original algorithm, the scaling is done after the softmax:
+ // softmax(Q @ K.T * scale) @ V
+ //
+ // But, it is mathematically equivalent to do it on Q first and then multiply
+ // it by K.T. This just allows us to do the scaling once, instead of each
+ // iteration of the loop. This is only valid for f16 or f32 types as f8
+ // is extremely limited on its dynamic range therefore this would
+ // significantly affect numerics.
+ if (!lowPrecision) {
+ query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
+ query, scale);
+ }
+
+ // ---- QK Matmul ----
+
+ // Get sizes for S.
+ SmallVector<OpFoldResult> sSizes;
+ for (AffineExpr dimExpr : sMap.getResults()) {
+ int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+ sSizes.push_back(iterationDomain[dim]);
+ }
+
+ // S = Q @ K
+ // SMap = QMap @ KMap
+ Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, sElementType);
+ Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(sElementType));
+ Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
+
+ s = computeMatmul(b, loc, qMap, kMap, sMap, query, key, s);
+ if (qkAttrs) {
+ s.getDefiningOp()->setAttrs(qkAttrs);
+ }
+
+ s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, s);
+
+ if (lowPrecision) {
+ // For low bit-depth types we perform post Q @ K scaling. This is to avoid
+ // losing numerical precision due to the low dynamic range of fp8 types when
+ // pre applying the sclaing.
+ AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
+ AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
+ /*symbolCount=*/0, ctx);
+ s = elementwiseValueInPlace<arith::MulFOp>(b, loc, sMap, scaleMap, s,
+ scale);
+
+ // If we need to truncate to fp8 post softmax we apply a scaling to use the
+ // full fp8 range. We can do this with a offset as post `exp2` this equates
+ // to multiplying by a static value. We are able to do this as `max` and
+ // `sum` are scaled by the same value so the end result is the same.
+ auto fpTy = cast<FloatType>(qETy);
+ double mx =
+ APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
+ .convertToDouble();
+ Value offset = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(sElementType, clAttentionSoftmaxMax / mx));
+ s = elementwiseValueInPlace<arith::AddFOp>(b, loc, sMap, scaleMap, s,
+ offset);
+ }
+
+ // S += mask
+ if (mask != nullptr) {
+ s = applyMask(b, loc, sMap, *maskMap, s, mask.value());
+ }
+
+ return s;
+}
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
+ Location loc = getLoc();
+ Value query = getQuery();
+ Value key = getKey();
+ Value value = getValue();
+ std::optional<Value> mask = getMask();
+ DictionaryAttr config = getDecompositionConfigAttr();
+
+ DictionaryAttr qkAttrs, pvAttrs;
+ if (config) {
+ qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
+ pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+ }
+ Value output = getOutput();
+
+ FailureOr<AttentionOpDetail> maybeOpInfo =
+ AttentionOpDetail::get(getIndexingMapsArray());
+ assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
+ AttentionOpDetail opInfo = maybeOpInfo.value();
+
+ SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
+ getIterationDomain(b), [](Range x) { return x.size; });
+
+ AffineMap qMap = getQueryMap();
+ AffineMap kMap = getKeyMap();
+ AffineMap sMap = opInfo.getSMap();
+
+ auto qETy = getElementTypeOrSelf(query.getType());
+ bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
+
+ // We compute output of first matmul in f32.
+ Type f32Type = b.getF32Type();
+
+ // ---- QK Matmul + elementwise math ----
+ Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap,
+ kMap, sMap, getMaskMap(), sizes, f32Type,
+ getRegion(), qkAttrs, lowPrecision);
+
+ // ---- Softmax ----
+
+ AffineMap accMap = getOutputMap();
+
+ llvm::SmallBitVector projectedK2Dims(opInfo.getDomainRank(), false);
+ for (auto dim : opInfo.getK2Dims()) {
+ projectedK2Dims.set(dim);
+ }
+
+ AffineMap maxMap = projectDims(sMap, projectedK2Dims).dropZeroResults();
+ AffineMap sumMap = maxMap;
+
+ SmallVector<OpFoldResult> rowRedSize =
+ applyPermutationMap<OpFoldResult>(maxMap, sizes);
+
+ Value rowRedEmpty = b.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
+
+ Value accInit = arith::getIdentityValue(arith::AtomicRMWKind::addf,
+ getElementTypeOrSelf(output), b, loc,
+ /*useOnlyFiniteValue=*/true);
+ Value maxInit =
+ arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, b, loc,
+ /*useOnlyFiniteValue=*/true);
+ Value sumInit =
+ arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, b, loc);
+
+ Value accFill =
+ b.create<linalg::FillOp>(loc, ValueRange{accInit}, output).getResult(0);
+ Value maxFill =
+ b.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
+ .getResult(0);
+ Value sumFill =
+ b.create<linalg::FillOp>(loc, ValueRange{sumInit}, rowRedEmpty)
+ .getResult(0);
+
+ // max = rowMax(S)
+ Value max = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, maxFill);
+
+ // P = exp2(S - max)
+ AffineMap pMap = sMap;
+ Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s);
+
+ // sum = rowSum(P)
+ Value sum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, sumFill);
+
+ // P = P / sum
+ p = elementwiseValueInPlace<arith::DivFOp>(b, loc, pMap, sumMap, p, sum);
+
+ // ---- Scale and truncate LHS to match RHS ----
+ SmallVector<OpFoldResult> sSizes;
+ for (AffineExpr dimExpr : sMap.getResults()) {
+ int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+ sSizes.push_back(sizes[dim]);
+ }
+
+ auto pETy = getElementTypeOrSelf(p.getType());
+ auto vETy = getElementTypeOrSelf(value.getType());
+ if (pETy != vETy && isa<FloatType>(vETy)) {
+ Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
+ p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
+ }
+
+ // result = P @ V + acc
+ Value result =
+ computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, accFill);
+ if (pvAttrs) {
+ result.getDefiningOp()->setAttrs(pvAttrs);
+ }
+
+ return SmallVector<Value>{result};
+}
+
+//===----------------------------------------------------------------------===//
// OnlineAttentionOp
//===----------------------------------------------------------------------===//
@@ -329,87 +535,17 @@
SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
getIterationDomain(b), [](Range x) { return x.size; });
- // Since we use exp2 for attention instead of the original exp, we have to
- // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
- // have better support for exp2 (we verified that we gain some speedup on
- // some GPUs).
- Value scale = getScale();
- Value log2e = b.create<arith::ConstantOp>(
- loc, b.getFloatAttr(scale.getType(), M_LOG2E));
- scale = b.create<arith::MulFOp>(loc, scale, log2e);
+ AffineMap qMap = getQueryMap();
+ AffineMap kMap = getKeyMap();
+ AffineMap sMap = opInfo.getSMap();
auto qETy = getElementTypeOrSelf(query.getType());
- auto vETy = getElementTypeOrSelf(value.getType());
-
- AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(),
- /*symbolCount=*/0, getContext());
-
- // In the original algorithm, the scaling is done after the softmax:
- // softmax(Q @ K.T * scale) @ V
- //
- // But, it is mathematically equivalent to do it on Q first and then multiply
- // it by K.T. This just allows us to do the scaling once, instead of each
- // iteration of the loop. This is only valid for f16 or f32 types as f8
- // is extremely limited on its dynamic range therefore this would
- // significantly affect numerics.
- if (qETy.getIntOrFloatBitWidth() > 8) {
- AffineMap qMap = getQueryMap();
- query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
- query, scale);
- }
-
- // ---- Matmul 1 ----
-
- // Get sizes for S.
- AffineMap sMap = opInfo.getSMap();
- SmallVector<OpFoldResult> sSizes;
- for (AffineExpr dimExpr : sMap.getResults()) {
- int dim = cast<AffineDimExpr>(dimExpr).getPosition();
- sSizes.push_back(sizes[dim]);
- }
-
- // S = Q @ K
- // SMap = QMap @ KMap
- Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
- Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
- Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
-
- s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
- if (qkAttrs) {
- s.getDefiningOp()->setDiscardableAttrs(qkAttrs);
- }
-
- s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);
-
bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
- if (lowPrecision) {
- // For low bit-depth types we perform post Q @ K scaling. This is to avoid
- // losing numerical precision due to the low dynamic range of fp8 types when
- // pre applying the sclaing.
- AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
- AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
- /*symbolCount=*/0, getContext());
- s = elementwiseValueInPlace<arith::MulFOp>(b, loc, sMap, scaleMap, s,
- scale);
- // If we need to truncate to fp8 post softmax we apply a scaling to use the
- // full fp8 range. We can do this with a offset as post `exp2` this equates
- // to multiplying by a static value. We are able to do this as `max` and
- // `sum` are scaled by the same value so the end result is the same.
- auto fpTy = cast<FloatType>(qETy);
- double mx =
- APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
- .convertToDouble();
- Value offset = b.create<arith::ConstantOp>(
- loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / mx));
- s = elementwiseValueInPlace<arith::AddFOp>(b, loc, sMap, scaleMap, s,
- offset);
- }
-
- // S += mask
- if (mask != nullptr) {
- s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value());
- }
+ // ---- QK Matmul + elementwise math ----
+ Value s = computeQKAndElementwise(
+ loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
+ sizes, elementType, getRegion(), qkAttrs, lowPrecision);
// TODO: This decomposition should be in a seperate op called
// "online softmax".
@@ -441,7 +577,14 @@
AffineMap accMap = getOutputMap();
// ---- Scale and truncate LHS to match RHS ----
+ SmallVector<OpFoldResult> sSizes;
+ for (AffineExpr dimExpr : sMap.getResults()) {
+ int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+ sSizes.push_back(sizes[dim]);
+ }
+
auto pETy = getElementTypeOrSelf(p.getType());
+ auto vETy = getElementTypeOrSelf(value.getType());
if (pETy != vETy && isa<FloatType>(vETy)) {
Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 329c79c..3b46114 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -475,6 +475,7 @@
["getIndexingMapsForResults", "getIndexingMapsForOperands",
"getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel
index ade6135..aff5921 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel
@@ -17,6 +17,7 @@
srcs = enforce_glob(
[
"canonicalize.mlir",
+ "decompose_aggregate_op.mlir",
"invalid.mlir",
"roundtrip.mlir",
],
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt
index f6d6730..36bdf43 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"canonicalize.mlir"
+ "decompose_aggregate_op.mlir"
"invalid.mlir"
"roundtrip.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir
new file mode 100644
index 0000000..fae9e5b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir
@@ -0,0 +1,418 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s
+
+// Spec to decompose custom op.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>,
+ %rhs1 : tensor<?x?xf32>, %rhs2 : tensor<?x?xf32>, %scalar : f32,
+ %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>)
+ -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
+ %0:2 = iree_linalg_ext.custom_op {
+ indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
+ affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+ affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
+ affine_map<(d0, d1)[s0, s1] -> ()>,
+ affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
+ affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+ #iree_linalg_ext.iterator_type<parallel>]}
+ ins(%lhs1, %rhs1, %rhs2, %scalar
+ : tensor<1000000x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, f32)
+ outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
+ ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>,
+ %s : f32, %t3 : tensor<?x?xf32>, %t4 : tensor<?x?xf32>) :
+ %0 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.matmul ins(%0, %t2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%t4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %2 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> ()>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%1, %s : tensor<?x?xf32>, f32) outs(%1 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 :f32):
+ %3 = arith.addf %b0, %b2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<?x?xf32>
+ iree_linalg_ext.yield %0, %2 : tensor<?x?xf32>, tensor<?x?xf32>
+ } -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
+ return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
+}
+
+// CHECK-LABEL: func @custom_op_decomposition(
+// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+// CHECK: %[[MATMUL1:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] :
+// CHECK-SAME: outs(%[[INIT1]] :
+// CHECK: %[[MATMUL2:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] :
+// CHECK-SAME: outs(%[[INIT2]] :
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] :
+// CHECK-SAME: outs(%[[MATMUL2]] :
+// CHECK: return %[[MATMUL1]], %[[GENERIC]]
+
+// -----
+
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @attention_f16(%query: tensor<192x1024x64xf16>,
+ %key: tensor<192x1024x64xf16>,
+ %value: tensor<192x1024x64xf16>,
+ %output: tensor<192x1024x64xf32>)
+ -> (tensor<192x1024x64xf32>) {
+ %scale = arith.constant 1.0 : f16
+
+ %out = iree_linalg_ext.attention
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO] }
+ ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
+ outs(%output : tensor<192x1024x64xf32>) {
+ ^bb0(%score: f32):
+ iree_linalg_ext.yield %score: f32
+ }
+ -> tensor<192x1024x64xf32>
+
+ return %out : tensor<192x1024x64xf32>
+}
+
+// We just want to check if we are using the correct algorithm
+// CHECK-LABEL: @attention_f16
+// Q = Q * scale
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK: arith.extf
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// max = rowMax(S)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.maximumf
+// CHECK: linalg.yield
+// P = exp2(S - max)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// sum = rowSum(P)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// P = P /= sum
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.divf
+// CHECK: linalg.yield
+// truncf P : f32 to f16
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.truncf
+// CHECK: linalg.yield
+// newAcc = P @ V
+// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK: arith.extf
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+
+// -----
+
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_f16(%query: tensor<192x1024x64xf16>,
+ %key: tensor<192x1024x64xf16>,
+ %value: tensor<192x1024x64xf16>,
+ %output: tensor<192x1024x64xf32>,
+ %max: tensor<192x1024xf32>,
+ %sum: tensor<192x1024xf32>)
+ -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+ %scale = arith.constant 1.0 : f16
+
+ %out:3 = iree_linalg_ext.online_attention
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
+ ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
+ outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
+ ^bb0(%score: f32):
+ iree_linalg_ext.yield %score: f32
+ }
+ -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+ return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+
+// We just want to check if we are using the correct algorithm and the
+// correct number of extf/truncfs are emitted.
+// CHECK-LABEL: @online_attention_f16
+// Q = Q * scale
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK: arith.extf
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.maximumf
+// CHECK: linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// newSum = normSum + rowSum(P)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// newAcc = norm * oldAcc
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// newAcc = P @ V + newAcc
+// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK: arith.extf
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+
+// -----
+
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+ %key: tensor<192x1024x64xf8E4M3FNUZ>,
+ %value: tensor<192x1024x64xf8E4M3FNUZ>,
+ %output: tensor<192x1024x64xf32>,
+ %max: tensor<192x1024xf32>,
+ %sum: tensor<192x1024xf32>)
+ -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+ %scale = arith.constant 1.0 : f32
+
+ %out:3 = iree_linalg_ext.online_attention
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
+ ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32)
+ outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
+ ^bb0(%score: f32):
+ iree_linalg_ext.yield %score: f32
+ }
+ -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+ return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+
+// CHECK-LABEL: @online_attention_f8
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// S = S * scale
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.mulf
+// CHECK-NEXT: linalg.yield
+// S = S + F8_linear_offset
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.addf
+// CHECK-NEXT: linalg.yield
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.maximumf
+// CHECK: linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// newSum = normSum + rowSum(P)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// clamp = clamp(norm)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.minimumf
+// CHECK: arith.truncf
+// newAcc = norm * oldAcc
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// newAcc = P @ V + newAcc
+// CHECK: linalg.generic
+// CHECK: arith.extf [[A:.+]] f8E4M3FNUZ to f32
+// CHECK: arith.extf [[A:.+]] f8E4M3FNUZ to f32
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+
+// -----
+
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+ %key: tensor<192x1024x64xf8E4M3FNUZ>,
+ %value: tensor<192x1024x64xf8E4M3FNUZ>,
+ %mask: tensor<192x1024x1024xf8E4M3FNUZ>,
+ %output: tensor<192x1024x64xf32>,
+ %max: tensor<192x1024xf32>,
+ %sum: tensor<192x1024xf32>)
+ -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+ %scale = arith.constant 1.0 : f16
+
+ %out:3 = iree_linalg_ext.online_attention
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
+ ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16, tensor<192x1024x1024xf8E4M3FNUZ>)
+ outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
+ ^bb0(%score: f32):
+ iree_linalg_ext.yield %score: f32
+ }
+ -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+ return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+// CHECK-LABEL: @online_attention_f8_masked
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// S = S * scale
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// S = S + mask
+// CHECK: arith.addf
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK: arith.maximumf
+// CHECK: linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// newSum = normSum + rowMax(P)
+// CHECK: linalg.generic
+// CHECK: arith.addf
+// CHECK: linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index efe463a6..6ba9d5c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -20,9 +20,7 @@
"conv2d_to_winograd.mlir",
"convert_to_loops.mlir",
"convert_to_online_attention.mlir",
- "decompose_aggregate_op.mlir",
"decompose_im2col.mlir",
- "decompose_online_attention.mlir",
"decompose_winograd.mlir",
"distribution.mlir",
"pad_contraction_to_block_size.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 3288c14..a912973 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -18,9 +18,7 @@
"conv2d_to_winograd.mlir"
"convert_to_loops.mlir"
"convert_to_online_attention.mlir"
- "decompose_aggregate_op.mlir"
"decompose_im2col.mlir"
- "decompose_online_attention.mlir"
"decompose_winograd.mlir"
"distribution.mlir"
"pad_contraction_to_block_size.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir
deleted file mode 100644
index 80b0b7a..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir
+++ /dev/null
@@ -1,62 +0,0 @@
-// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s
-
-func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>,
- %rhs1 : tensor<?x?xf32>, %rhs2 : tensor<?x?xf32>, %scalar : f32,
- %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>)
- -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
- %0:2 = iree_linalg_ext.custom_op {
- indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
- affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
- affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
- affine_map<(d0, d1)[s0, s1] -> ()>,
- affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
- iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
- #iree_linalg_ext.iterator_type<parallel>]}
- ins(%lhs1, %rhs1, %rhs2, %scalar
- : tensor<1000000x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, f32)
- outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
- ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>,
- %s : f32, %t3 : tensor<?x?xf32>, %t4 : tensor<?x?xf32>) :
- %0 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = linalg.matmul ins(%0, %t2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%t4 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %2 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> ()>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%1, %s : tensor<?x?xf32>, f32) outs(%1 : tensor<?x?xf32>) {
- ^bb0(%b0 : f32, %b1 : f32, %b2 :f32):
- %3 = arith.addf %b0, %b2 : f32
- linalg.yield %3 : f32
- } -> tensor<?x?xf32>
- iree_linalg_ext.yield %0, %2 : tensor<?x?xf32>, tensor<?x?xf32>
- } -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
- return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
-}
-module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op
- transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
- transform.yield
- }
-}
-// CHECK-LABEL: func @custom_op_decomposition(
-// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32
-// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-// CHECK: %[[MATMUL1:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] :
-// CHECK-SAME: outs(%[[INIT1]] :
-// CHECK: %[[MATMUL2:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] :
-// CHECK-SAME: outs(%[[INIT2]] :
-// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] :
-// CHECK-SAME: outs(%[[MATMUL2]] :
-// CHECK: return %[[MATMUL1]], %[[GENERIC]]
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
deleted file mode 100644
index 19bd2ca..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ /dev/null
@@ -1,242 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-attention),canonicalize,cse)" %s | FileCheck %s
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @attention_f16(%query: tensor<192x1024x64xf16>,
- %key: tensor<192x1024x64xf16>,
- %value: tensor<192x1024x64xf16>,
- %output: tensor<192x1024x64xf32>,
- %max: tensor<192x1024xf32>,
- %sum: tensor<192x1024xf32>)
- -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
- %scale = arith.constant 1.0 : f16
-
- %out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
- ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
- outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
- ^bb0(%score: f32):
- iree_linalg_ext.yield %score: f32
- }
- -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
- return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
-}
-
-// We just want to check if we are using the correct algorithm and the
-// correct number of extf/truncfs are emitted.
-// CHECK-LABEL: @attention_f16
-// Q = Q * scale
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// S = Q @ K
-// CHECK: linalg.generic
-// CHECK: arith.extf
-// CHECK: arith.extf
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-// newMax = max(oldMax, rowMax(S))
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.maximumf
-// CHECK: linalg.yield
-// norm = exp2(oldMax - newMax)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
-// newSum = normSum + rowSum(P)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-// newAcc = norm * oldAcc
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// newAcc = P @ V + newAcc
-// CHECK: linalg.generic
-// CHECK: arith.extf
-// CHECK: arith.extf
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-
-// -----
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
- %key: tensor<192x1024x64xf8E4M3FNUZ>,
- %value: tensor<192x1024x64xf8E4M3FNUZ>,
- %output: tensor<192x1024x64xf32>,
- %max: tensor<192x1024xf32>,
- %sum: tensor<192x1024xf32>)
- -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
- %scale = arith.constant 1.0 : f32
-
- %out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
- ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32)
- outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
- ^bb0(%score: f32):
- iree_linalg_ext.yield %score: f32
- }
- -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
- return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
-}
-
-// CHECK-LABEL: @attention_f8
-// S = Q @ K
-// CHECK: linalg.generic
-// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
-// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-// S = S * scale
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.mulf
-// CHECK-NEXT: linalg.yield
-// S = S + F8_linear_offset
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.addf
-// CHECK-NEXT: linalg.yield
-// newMax = max(oldMax, rowMax(S))
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.maximumf
-// CHECK: linalg.yield
-// norm = exp2(oldMax - newMax)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
-// newSum = normSum + rowSum(P)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-// clamp = clamp(norm)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.minimumf
-// CHECK: arith.truncf
-// newAcc = norm * oldAcc
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// newAcc = P @ V + newAcc
-// CHECK: linalg.generic
-// CHECK: arith.extf [[A:.+]] f8E4M3FNUZ to f32
-// CHECK: arith.extf [[A:.+]] f8E4M3FNUZ to f32
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-
-// -----
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
-#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
- %key: tensor<192x1024x64xf8E4M3FNUZ>,
- %value: tensor<192x1024x64xf8E4M3FNUZ>,
- %mask: tensor<192x1024x1024xf8E4M3FNUZ>,
- %output: tensor<192x1024x64xf32>,
- %max: tensor<192x1024xf32>,
- %sum: tensor<192x1024xf32>)
- -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
- %scale = arith.constant 1.0 : f16
-
- %out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
- ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16, tensor<192x1024x1024xf8E4M3FNUZ>)
- outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
- ^bb0(%score: f32):
- iree_linalg_ext.yield %score: f32
- }
- -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
- return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
-}
-// CHECK-LABEL: @attention_f8_masked
-// S = Q @ K
-// CHECK: linalg.generic
-// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
-// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-// S = S * scale
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// S = S + mask
-// CHECK: arith.addf
-// newMax = max(oldMax, rowMax(S))
-// CHECK: linalg.generic
-// CHECK: arith.maximumf
-// CHECK: linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
-// norm = exp2(oldMax - newMax)
-// CHECK: linalg.generic
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// newSum = normSum + rowMax(P)
-// CHECK: linalg.generic
-// CHECK: arith.addf
-// CHECK: linalg.yield