[LinalgExt] Generalize attribute setting for attention decomposition (#18780)

This PR teaches attention decomposition to set attributes for attention
matmuls by passing attribute dictionaries to
iree_linalg_ext.online_attention operation. This allows us to further
control codegen of matmuls (generally the root operations) after
decomposition (for example, setting lowering config on the decompose
matmuls).
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 4b64cda..0d9c7f9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -825,7 +825,25 @@
   attrs.emplace_back(StringAttr::get(context, "reduction"),
                      b.getI64ArrayAttr(reductionTileSizes));
 
-  auto configDict = DictionaryAttr::get(context, attrs);
+  SmallVector<NamedAttribute, 2> qkAttrs;
+  SmallVector<NamedAttribute, 2> pvAttrs;
+
+  qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr()));
+  pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr()));
+
+  auto qkAttrDict = b.getDictionaryAttr(qkAttrs);
+  auto pvAttrDict = b.getDictionaryAttr(pvAttrs);
+
+  SmallVector<NamedAttribute, 2> decompositionConfig;
+  decompositionConfig.emplace_back(
+      b.getNamedAttr(IREE::LinalgExt::AttentionOp::getQKAttrStr(), qkAttrDict));
+  decompositionConfig.emplace_back(
+      b.getNamedAttr(IREE::LinalgExt::AttentionOp::getPVAttrStr(), pvAttrDict));
+
+  DictionaryAttr decompositionConfigDict =
+      b.getDictionaryAttr(decompositionConfig);
+
+  auto configDict = b.getDictionaryAttr(attrs);
   auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
 
   // Attach the MMA schedule as an attribute to the entry point export function
@@ -843,6 +861,9 @@
 
   auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);
 
+  // Set attention decomposition control config.
+  op.setDecompositionConfigAttr(decompositionConfigDict);
+
   return setOpConfigAndEntryPointFnTranslation(
       entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
       workgroupSize, targetSubgroupSize, pipelineConfig);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index d21faf8..4334e79 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -688,7 +688,9 @@
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                      affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
-                     lowering_config = #config}
+                     lowering_config = #config,
+                     decomposition_config = {qk_attrs = {attention_qk_matmul},
+                                             pv_attrs = {attention_pv_matmul}}}
                      ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) {
                       ^bb0(%score: f32):
                         iree_linalg_ext.yield %score : f32
@@ -753,7 +755,15 @@
         %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
         %7 = tensor.empty() : tensor<64x4608x24x128xf16>
         %8 = tensor.empty() : tensor<24x64x4608x128xf16>
-        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
+        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+                                                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>,
+                                                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>,
+                                                         affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
+                                                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
+                                                         lowering_config = #config,
+                                                         decomposition_config = {qk_attrs = {attention_qk_matmul},
+                                                                                 pv_attrs = {attention_pv_matmul}}}
+        ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
               ^bb0(%score: f32):
                 iree_linalg_ext.yield %score : f32
              } -> tensor<24x64x4608x128xf16>
@@ -811,7 +821,15 @@
         %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
         %7 = tensor.empty() : tensor<64x4608x24x128xf16>
         %8 = tensor.empty() : tensor<24x64x4608x128xf16>
-        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
+        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+                                                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>,
+                                                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>,
+                                                         affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
+                                                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
+                                                         lowering_config = #config,
+                                                         decomposition_config = {qk_attrs = {attention_qk_matmul},
+                                                                                 pv_attrs = {attention_pv_matmul}}}
+        ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
               ^bb0(%score: f32):
                 iree_linalg_ext.yield %score : f32
              } -> tensor<24x64x4608x128xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index 02d1e71..204ae35 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -313,6 +313,13 @@
   Value oldMax = getMax();
   Value oldSum = getSum();
   Type elementType = getElementTypeOrSelf(getOutput().getType());
+  DictionaryAttr config = getDecompositionConfigAttr();
+
+  DictionaryAttr qkAttrs, pvAttrs;
+  if (config) {
+    qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
+    pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+  }
 
   FailureOr<AttentionOpDetail> maybeOpInfo =
       AttentionOpDetail::get(getIndexingMapsArray());
@@ -368,10 +375,9 @@
   Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
 
   s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
-
-  // TODO: We shouldn't be relying on such attributes. We need a better
-  // mechanism to identify attention matmuls.
-  s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr());
+  if (qkAttrs) {
+    s.getDefiningOp()->setDiscardableAttrs(qkAttrs);
+  }
 
   s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);
 
@@ -448,9 +454,9 @@
 
   // newAcc = P @ V + newAcc
   newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
-  // TODO: We shouldn't be relying on such attributes. We need a better
-  // mechanism to identify attention matmuls.
-  newAcc.getDefiningOp()->setAttr("attention_pv_matmul", b.getUnitAttr());
+  if (pvAttrs) {
+    newAcc.getDefiningOp()->setDiscardableAttrs(pvAttrs);
+  }
 
   return SmallVector<Value>{newAcc, newMax, newSum};
 }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 6abaec4..77a2d51 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1213,7 +1213,7 @@
                         std::optional<Value> mask) {
   Value maskIn = mask.value_or(Value());
   build(odsBuilder, odsState, results, query, key, value, scale, maskIn, output,
-        indexingMaps);
+        indexingMaps, DictionaryAttr());
 }
 
 LogicalResult AttentionOp::verify() {
@@ -1388,7 +1388,7 @@
                               std::optional<Value> mask) {
   Value maskIn = mask.value_or(Value());
   build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output,
-        max, sum, indexingMaps);
+        max, sum, indexingMaps, DictionaryAttr());
 }
 
 LogicalResult OnlineAttentionOp::verify() {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index e097ce5..329c79c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -501,7 +501,8 @@
                        AnyFloat:$scale,
                        Optional<AnyShaped>:$mask,
                        AnyShaped:$output,
-                       AffineMapArrayAttr:$indexing_maps
+                       AffineMapArrayAttr:$indexing_maps,
+                       OptionalAttr<DictionaryAttr>:$decomposition_config
   );
   let regions = (region SizedRegion<1>:$region);
 
@@ -558,6 +559,12 @@
     int64_t getIterationDomainRank() {
       return getQueryMap().getNumDims();
     }
+
+    /* Decomposition control attributes */
+
+    // Attributes to set on QK and PV matmul after decomposition.
+    static StringRef getQKAttrStr() { return "qk_attrs"; }
+    static StringRef getPVAttrStr() { return "pv_attrs"; }
   }];
 }
 
@@ -612,7 +619,8 @@
                        AnyShaped:$output,
                        AnyShaped:$max,
                        AnyShaped:$sum,
-                       AffineMapArrayAttr:$indexing_maps
+                       AffineMapArrayAttr:$indexing_maps,
+                       OptionalAttr<DictionaryAttr>:$decomposition_config
   );
   let regions = (region SizedRegion<1>:$region);
 
@@ -679,6 +687,12 @@
     int64_t getIterationDomainRank() {
       return getQueryMap().getNumDims();
     }
+
+    /* Decomposition control attributes */
+
+    // Attributes to set on QK and PV matmul after decomposition.
+    static StringRef getQKAttrStr() { return "qk_attrs"; }
+    static StringRef getPVAttrStr() { return "pv_attrs"; }
   }];
 }
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index 0aa3a37..d9a4873 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -106,7 +106,8 @@
       loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()},
       attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(),
       mask, accFill, maxFill, sumFill,
-      rewriter.getAffineMapArrayAttr(indexingMaps));
+      rewriter.getAffineMapArrayAttr(indexingMaps),
+      attnOp.getDecompositionConfigAttr());
 
   rewriter.cloneRegionBefore(attnOp.getRegion(), onlineAttn.getRegion(),
                              onlineAttn.getRegion().begin());