Generalize overriding llvm func attr flags in translation info (#17365)
Previously, we were specifically querying for waves_per_eu attr in
translation info in ROCMTarget. This patch makes this general by
attaching a llvm function attribute override dictionary in the
translation info instead which can be set.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 270ca01..84c155b 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -352,19 +352,6 @@
}
}
- // Try to get waves-per-eu from the export-specific translation info in
- // cases where codegen decides to override the value.
- // Otherwise, fallback to the default option.
- int64_t wavesPerEu = 0;
- if (auto attr = func->getAttrOfType<IntegerAttr>("waves_per_eu")) {
- wavesPerEu = attr.getValue().getSExtValue();
- }
- if (wavesPerEu == 0) {
- if (std::optional<IntegerAttr> attr =
- getConfigIntegerAttr(targetAttr, "waves_per_eu"))
- wavesPerEu = attr->getValue().getSExtValue();
- }
-
// For GPU kernels,
// 1. Insert AMDGPU_KERNEL calling convention.
// 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
@@ -374,12 +361,30 @@
llvmFunc->addFnAttr(
"amdgpu-flat-work-group-size",
(llvm::Twine("1, ") + llvm::Twine(flatWgSize)).str());
- if (wavesPerEu > 0) {
- llvmFunc->addFnAttr("amdgpu-waves-per-eu",
- std::to_string(wavesPerEu));
- }
if (targetArch.starts_with("gfx9"))
addPreloadKernArgHint(llvmFunc);
+
+ // Set the amdgpu-waves-per-eu flag from config if given.
+ if (std::optional<IntegerAttr> attr =
+ getConfigIntegerAttr(targetAttr, "waves_per_eu")) {
+ llvmFunc->addFnAttr("amdgpu-waves-per-eu",
+ std::to_string(attr->getValue().getSExtValue()));
+ }
+
+ // Override flags as given by target func attrs.
+ if (auto funcAttrs =
+ func->getAttrOfType<DictionaryAttr>("llvm_func_attrs")) {
+ for (NamedAttribute funcAttr : funcAttrs) {
+ auto value = dyn_cast<StringAttr>(funcAttr.getValue());
+ if (!value) {
+ return variantOp->emitError("llvm_func_attrs attribute must be "
+ "adictionary of strings. Attribute " +
+ llvm::Twine(funcAttr.getName()) +
+ " is not a StringAttr.");
+ }
+ llvmFunc->addFnAttr(funcAttr.getName(), value.getValue());
+ }
+ }
}
std::unique_ptr<llvm::TargetMachine> targetMachine;
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
index 04cf6b2..a5d15cf 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
@@ -61,18 +61,19 @@
return subgroupSize;
}
-/// Helper function to retrieve the waves-per-eu value from translation info.
-static std::optional<int64_t>
-getWavesPerEu(IREE::Codegen::TranslationInfoAttr translationInfo) {
+/// Helper function to retrieve the target-func-attrs value from translation
+/// info.
+static DictionaryAttr
+getTargetFuncAttrs(IREE::Codegen::TranslationInfoAttr translationInfo) {
auto translationConfig = translationInfo.getConfiguration();
if (!translationConfig) {
- return std::nullopt;
+ return nullptr;
}
- auto attr = translationConfig.getAs<IntegerAttr>("waves_per_eu");
+ auto attr = translationConfig.getAs<DictionaryAttr>("llvm_func_attrs");
if (!attr) {
- return std::nullopt;
+ return nullptr;
}
- return attr.getValue().getSExtValue();
+ return attr;
}
void ReconcileTranslationInfoPass::runOnOperation() {
@@ -85,7 +86,6 @@
return signalPassFailure();
}
auto exportOp = *exportOps.begin();
- MLIRContext *context = &getContext();
Builder builder(&getContext());
SmallVector<IREE::Codegen::TranslationInfoAttr> translationInfos;
@@ -96,14 +96,13 @@
}
translationInfos.push_back(translationInfo);
- // The following is moving the waves-per-eu specification from
+ // The following is moving the target-func-attrs specification from
// translation info into the func-like op. This is not the best
// place to do this, but the intent is after this pass all the
// lowering configs and translation infos will be deleted.
- std::optional<int64_t> wavesPerEu = getWavesPerEu(translationInfo);
- if (wavesPerEu) {
- funcOp->setAttr("waves_per_eu", IntegerAttr::get(IndexType::get(context),
- wavesPerEu.value()));
+ DictionaryAttr targetFuncAttrs = getTargetFuncAttrs(translationInfo);
+ if (targetFuncAttrs) {
+ funcOp->setAttr("llvm_func_attrs", targetFuncAttrs);
}
});
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
index 5f53d4b..50a28e3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
@@ -128,19 +128,19 @@
// -----
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>]>]>
-hal.executable private @waves_per_eu {
- hal.executable.variant public @waves_per_eu target(#hal.executable.target<"", "", {}>) {
+hal.executable private @llvm_func_attrs {
+ hal.executable.variant public @llvm_func_attrs target(#hal.executable.target<"", "", {}>) {
hal.executable.export public @entry_point layout(#pipeline_layout)
builtin.module {
- func.func @fn1() attributes {translation_info = #iree_codegen.translation_info<None, {waves_per_eu = 2}>} {
+ func.func @fn1() attributes {translation_info = #iree_codegen.translation_info<None, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>} {
return
}
- func.func @fn2() attributes {translation_info = #iree_codegen.translation_info<None, {waves_per_eu = 4}>} {
+ func.func @fn2() attributes {translation_info = #iree_codegen.translation_info<None, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}>} {
return
}
}
}
}
-// CHECK-LABEL: hal.executable private @waves_per_eu
-// CHECK: func.func @fn1() attributes {waves_per_eu = 2 : index}
-// CHECK: func.func @fn2() attributes {waves_per_eu = 4 : index}
+// CHECK-LABEL: hal.executable private @llvm_func_attrs
+// CHECK: func.func @fn1() attributes {llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}
+// CHECK: func.func @fn2() attributes {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}