[LinalgExt] Generalize attribute setting for attention decomposition (#18780)
This PR teaches attention decomposition to set attributes for attention
matmuls by passing attribute dictionaries to the
iree_linalg_ext.online_attention operation. This allows us to further
control codegen of matmuls (generally the root operations) after
decomposition (for example, setting lowering config on the decomposed
matmuls).
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 4b64cda..0d9c7f9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -825,7 +825,25 @@
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
- auto configDict = DictionaryAttr::get(context, attrs);
+ SmallVector<NamedAttribute, 2> qkAttrs;
+ SmallVector<NamedAttribute, 2> pvAttrs;
+
+ qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr()));
+ pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr()));
+
+ auto qkAttrDict = b.getDictionaryAttr(qkAttrs);
+ auto pvAttrDict = b.getDictionaryAttr(pvAttrs);
+
+ SmallVector<NamedAttribute, 2> decompositionConfig;
+ decompositionConfig.emplace_back(
+ b.getNamedAttr(IREE::LinalgExt::AttentionOp::getQKAttrStr(), qkAttrDict));
+ decompositionConfig.emplace_back(
+ b.getNamedAttr(IREE::LinalgExt::AttentionOp::getPVAttrStr(), pvAttrDict));
+
+ DictionaryAttr decompositionConfigDict =
+ b.getDictionaryAttr(decompositionConfig);
+
+ auto configDict = b.getDictionaryAttr(attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
// Attach the MMA schedule as an attribute to the entry point export function
@@ -843,6 +861,9 @@
auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);
+ // Set attention decomposition control config.
+ op.setDecompositionConfigAttr(decompositionConfigDict);
+
return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
workgroupSize, targetSubgroupSize, pipelineConfig);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index d21faf8..4334e79 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -688,7 +688,9 @@
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
- lowering_config = #config}
+ lowering_config = #config,
+ decomposition_config = {qk_attrs = {attention_qk_matmul},
+ pv_attrs = {attention_pv_matmul}}}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
@@ -753,7 +755,15 @@
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
- %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
+ %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
+ lowering_config = #config,
+ decomposition_config = {qk_attrs = {attention_qk_matmul},
+ pv_attrs = {attention_pv_matmul}}}
+ ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<24x64x4608x128xf16>
@@ -811,7 +821,15 @@
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
- %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
+ %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
+ lowering_config = #config,
+ decomposition_config = {qk_attrs = {attention_qk_matmul},
+ pv_attrs = {attention_pv_matmul}}}
+ ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<24x64x4608x128xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index 02d1e71..204ae35 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -313,6 +313,13 @@
Value oldMax = getMax();
Value oldSum = getSum();
Type elementType = getElementTypeOrSelf(getOutput().getType());
+ DictionaryAttr config = getDecompositionConfigAttr();
+
+ DictionaryAttr qkAttrs, pvAttrs;
+ if (config) {
+ qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
+ pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+ }
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(getIndexingMapsArray());
@@ -368,10 +375,9 @@
Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
-
- // TODO: We shouldn't be relying on such attributes. We need a better
- // mechanism to identify attention matmuls.
- s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr());
+ if (qkAttrs) {
+ s.getDefiningOp()->setDiscardableAttrs(qkAttrs);
+ }
s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);
@@ -448,9 +454,9 @@
// newAcc = P @ V + newAcc
newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
- // TODO: We shouldn't be relying on such attributes. We need a better
- // mechanism to identify attention matmuls.
- newAcc.getDefiningOp()->setAttr("attention_pv_matmul", b.getUnitAttr());
+ if (pvAttrs) {
+ newAcc.getDefiningOp()->setDiscardableAttrs(pvAttrs);
+ }
return SmallVector<Value>{newAcc, newMax, newSum};
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 6abaec4..77a2d51 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1213,7 +1213,7 @@
std::optional<Value> mask) {
Value maskIn = mask.value_or(Value());
build(odsBuilder, odsState, results, query, key, value, scale, maskIn, output,
- indexingMaps);
+ indexingMaps, DictionaryAttr());
}
LogicalResult AttentionOp::verify() {
@@ -1388,7 +1388,7 @@
std::optional<Value> mask) {
Value maskIn = mask.value_or(Value());
build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output,
- max, sum, indexingMaps);
+ max, sum, indexingMaps, DictionaryAttr());
}
LogicalResult OnlineAttentionOp::verify() {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index e097ce5..329c79c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -501,7 +501,8 @@
AnyFloat:$scale,
Optional<AnyShaped>:$mask,
AnyShaped:$output,
- AffineMapArrayAttr:$indexing_maps
+ AffineMapArrayAttr:$indexing_maps,
+ OptionalAttr<DictionaryAttr>:$decomposition_config
);
let regions = (region SizedRegion<1>:$region);
@@ -558,6 +559,12 @@
int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}
+
+ /* Decomposition control attributes */
+
+ // Attributes to set on QK and PV matmul after decomposition.
+ static StringRef getQKAttrStr() { return "qk_attrs"; }
+ static StringRef getPVAttrStr() { return "pv_attrs"; }
}];
}
@@ -612,7 +619,8 @@
AnyShaped:$output,
AnyShaped:$max,
AnyShaped:$sum,
- AffineMapArrayAttr:$indexing_maps
+ AffineMapArrayAttr:$indexing_maps,
+ OptionalAttr<DictionaryAttr>:$decomposition_config
);
let regions = (region SizedRegion<1>:$region);
@@ -679,6 +687,12 @@
int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}
+
+ /* Decomposition control attributes */
+
+ // Attributes to set on QK and PV matmul after decomposition.
+ static StringRef getQKAttrStr() { return "qk_attrs"; }
+ static StringRef getPVAttrStr() { return "pv_attrs"; }
}];
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index 0aa3a37..d9a4873 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -106,7 +106,8 @@
loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()},
attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(),
mask, accFill, maxFill, sumFill,
- rewriter.getAffineMapArrayAttr(indexingMaps));
+ rewriter.getAffineMapArrayAttr(indexingMaps),
+ attnOp.getDecompositionConfigAttr());
rewriter.cloneRegionBefore(attnOp.getRegion(), onlineAttn.getRegion(),
onlineAttn.getRegion().begin());