[tuner]: add c/python binding for querying mma intrinsic (#19218)
Following https://github.com/iree-org/iree/pull/19199, add C and Python
bindings for the two utility functions (`getExecutableVariantOps` and
`queryMMAIntrinsics`) used to query MMA intrinsic instructions from an
input module.
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
index 029e672..3adbc3e 100644
--- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
+++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
@@ -66,6 +66,14 @@
MLIR_CAPI_EXPORTED ireeCodegenCompilationInfoParameters
ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr);
+// Collects the executable variant ops found in `module`.
+// When `executableOps` is null, only the number of ops is stored in `*numOps`.
+// Otherwise `*numOps` must equal that number (checked with an assertion) and
+// one entry per op is written to `executableOps`.
+MLIR_CAPI_EXPORTED void
+ireeCodegenGetExecutableVariantOps(MlirModule module, size_t *numOps,
+ MlirOperation *executableOps);
+
+// Queries the MMA intrinsics for `op`, which must be an executable variant op.
+// When `mmaIntrinsics` is null, only the number of intrinsics is stored in
+// `*numIntrinsics`. Otherwise `*numIntrinsics` must equal that number (checked
+// with an assertion) and each intrinsic is written to `mmaIntrinsics` as its
+// enum value converted to uint32_t.
+MLIR_CAPI_EXPORTED void ireeCodegenQueryMMAIntrinsics(MlirOperation op,
+ size_t *numIntrinsics,
+ uint32_t *mmaIntrinsics);
+
#ifdef __cplusplus
}
#endif
diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp
index 7ece224..dec10fb 100644
--- a/compiler/bindings/python/IREECompilerDialectsModule.cpp
+++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp
@@ -21,6 +21,33 @@
namespace py = pybind11;
using namespace mlir::python::adaptors;
+// Returns the executable variant ops of `module` as a vector. The C API is
+// invoked twice: first with a null buffer to obtain the count, then with
+// storage of that size to receive the ops.
+static std::vector<MlirOperation>
+ireeCodegenGetExecutableVariantOpsBinding(MlirModule module) {
+  size_t variantCount = 0;
+  ireeCodegenGetExecutableVariantOps(module, &variantCount,
+                                     /*executableOps=*/nullptr);
+
+  std::vector<MlirOperation> variants(variantCount);
+  ireeCodegenGetExecutableVariantOps(module, &variantCount, variants.data());
+  return variants;
+}
+
+// Returns the MMA intrinsics reported for `op`, each wrapped in the
+// `MMAIntrinsic` object exposed by the module at `kGpuModuleImportPath`.
+static std::vector<py::object>
+ireeCodegenQueryMMAIntrinsicsBinding(MlirOperation op) {
+  // First call (null buffer) reports the count; second call fills the buffer.
+  size_t intrinsicCount = 0;
+  ireeCodegenQueryMMAIntrinsics(op, &intrinsicCount,
+                                /*mmaIntrinsics=*/nullptr);
+  std::vector<uint32_t> rawValues(intrinsicCount);
+  ireeCodegenQueryMMAIntrinsics(op, &intrinsicCount, rawValues.data());
+
+  // Convert every raw integer value through the Python-side `MMAIntrinsic`.
+  py::object enumClass =
+      py::module_::import(kGpuModuleImportPath).attr("MMAIntrinsic");
+  std::vector<py::object> wrapped;
+  wrapped.reserve(intrinsicCount);
+  for (size_t idx = 0; idx < intrinsicCount; ++idx) {
+    wrapped.push_back(enumClass(rawValues[idx]));
+  }
+  return wrapped;
+}
+
PYBIND11_MODULE(_ireeCompilerDialects, m) {
m.doc() = "iree-compiler dialects python extension";
@@ -326,4 +353,22 @@
"Gets an #iree_gpu.lowering_config from parameters.")
.def_property_readonly("attributes",
ireeGPULoweringConfigAttrGetAttributes);
+
+ //===-------------------------------------------------------------------===//
+ // Binding to utility function getExecutableVariantOps
+ //===-------------------------------------------------------------------===//
+
+ iree_codegen_module.def(
+ "get_executable_variant_ops", &ireeCodegenGetExecutableVariantOpsBinding,
+ "Gets the executable variant operations from a module.",
+ py::arg("module"));
+
+ //===-------------------------------------------------------------------===//
+ // Binding to utility function queryMMAIntrinsics
+ //===-------------------------------------------------------------------===//
+
+ iree_codegen_module.def(
+ "query_mma_intrinsics", &ireeCodegenQueryMMAIntrinsicsBinding,
+ "Queries the MMA intrinsics from an executable variant op.",
+ py::arg("op"));
}
diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
index 9b6d2b8..2d67e44 100644
--- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel
+++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
@@ -137,6 +137,7 @@
deps = [
"//compiler/bindings/c:headers",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
+ "//compiler/src/iree/compiler/Codegen/Utils",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:IR",
diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
index e0ec31d..871fbf3 100644
--- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
+++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
@@ -116,6 +116,7 @@
MLIRCAPIIR
MLIRIR
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+ iree::compiler::Codegen::Utils
iree::compiler::bindings::c::headers
PUBLIC
)
diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
index c295d48..82e2496 100644
--- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
@@ -10,6 +10,7 @@
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/dialects/iree_codegen.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
@@ -24,6 +25,8 @@
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface;
using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;
+using mlir::iree_compiler::IREE::GPU::MMAIntrinsic;
+using mlir::iree_compiler::IREE::HAL::ExecutableVariantOp;
bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
MlirAttribute attr) {
@@ -149,3 +152,49 @@
parameters.translationInfo = wrap(compilationInfo.getTranslationInfo());
return parameters;
}
+
+// Collects the HAL executable variant ops found by walking `module`.
+//
+// Two-call protocol:
+//   * `executableOps == nullptr`: stores the number of variant ops in
+//     `*numOps` and returns.
+//   * otherwise: `*numOps` must equal that number (asserted) and the ops are
+//     written to `executableOps`.
+//
+// The asserts disappear under NDEBUG, so precondition violations are also
+// handled explicitly: a null `numOps` is a no-op, a null module reports zero
+// ops, and never more than `*numOps` entries are written.
+void ireeCodegenGetExecutableVariantOps(MlirModule module, size_t *numOps,
+                                        MlirOperation *executableOps) {
+  assert(!mlirModuleIsNull(module) && "module cannot be nullptr");
+  assert(numOps && "numOps cannot be nullptr");
+
+  // Release-build guards for the two asserted preconditions above.
+  if (!numOps) {
+    return;
+  }
+  if (mlirModuleIsNull(module)) {
+    *numOps = 0;
+    return;
+  }
+
+  mlir::ModuleOp moduleOp = unwrap(module);
+  llvm::SmallVector<ExecutableVariantOp> executableVariantOps =
+      mlir::iree_compiler::getExecutableVariantOps(moduleOp);
+
+  // Query mode: report the count only.
+  if (!executableOps) {
+    *numOps = executableVariantOps.size();
+    return;
+  }
+
+  assert(
+      *numOps == executableVariantOps.size() &&
+      "*numOps must match the number of elements in the executableVariantOps");
+
+  // Fill mode: clamp to the caller-provided capacity so a mismatched count
+  // cannot write past the end of `executableOps` when asserts are disabled.
+  const size_t capacity = *numOps;
+  for (size_t i = 0, e = executableVariantOps.size(); i < e && i < capacity;
+       ++i) {
+    executableOps[i] = wrap(executableVariantOps[i]);
+  }
+}
+
+// Queries the MMA intrinsics for `op`, which must be an ExecutableVariantOp.
+//
+// Two-call protocol:
+//   * `mmaIntrinsics == nullptr`: stores the number of intrinsics in
+//     `*numIntrinsics` and returns.
+//   * otherwise: `*numIntrinsics` must equal that number (asserted) and each
+//     intrinsic is written to `mmaIntrinsics` as its uint32_t enum value.
+//
+// The asserts disappear under NDEBUG, so precondition violations are also
+// handled explicitly: a null `numIntrinsics` is a no-op, an `op` that is null
+// or not an ExecutableVariantOp reports zero intrinsics, and never more than
+// `*numIntrinsics` entries are written.
+void ireeCodegenQueryMMAIntrinsics(MlirOperation op, size_t *numIntrinsics,
+                                   uint32_t *mmaIntrinsics) {
+  assert(numIntrinsics && "numIntrinsics cannot be nullptr");
+  if (!numIntrinsics) {
+    return;
+  }
+
+  mlir::Operation *mlirOp = unwrap(op);
+  auto variantOp = llvm::dyn_cast_if_present<ExecutableVariantOp>(mlirOp);
+  assert(variantOp && "operation is not a ExecutableVariantOp");
+  // Release-build guard: do not pass a null variant op to the query helper
+  // (this path is reachable from the Python binding with an arbitrary op).
+  if (!variantOp) {
+    *numIntrinsics = 0;
+    return;
+  }
+
+  llvm::SmallVector<MMAIntrinsic> intrinsics =
+      mlir::iree_compiler::queryMMAIntrinsics(variantOp);
+  // Query mode: report the count only.
+  if (!mmaIntrinsics) {
+    *numIntrinsics = intrinsics.size();
+    return;
+  }
+
+  assert(*numIntrinsics == intrinsics.size() &&
+         "*numIntrinsics must match the number of elements in the intrinsics");
+
+  // Fill mode: clamp to the caller-provided capacity so a mismatched count
+  // cannot write past the end of `mmaIntrinsics` when asserts are disabled.
+  const size_t capacity = *numIntrinsics;
+  for (size_t i = 0, e = intrinsics.size(); i < e && i < capacity; ++i) {
+    mmaIntrinsics[i] = static_cast<uint32_t>(intrinsics[i]);
+  }
+}
diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c
index ffb8f08..7f1b550 100644
--- a/compiler/src/iree/compiler/API/api_exports.c
+++ b/compiler/src/iree/compiler/API/api_exports.c
@@ -24,6 +24,8 @@
extern void ireeCodegenDispatchLoweringPassPipelineAttrGet();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
+extern void ireeCodegenGetExecutableVariantOps();
+extern void ireeCodegenQueryMMAIntrinsics();
extern void ireeCodegenTranslationInfoAttrGet();
extern void ireeCodegenTranslationInfoAttrGetParameters();
extern void ireeCodegenTranslationInfoAttrGetTypeID();
diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def
index ed5e12c..2280a69 100644
--- a/compiler/src/iree/compiler/API/api_exports.def
+++ b/compiler/src/iree/compiler/API/api_exports.def
@@ -14,6 +14,8 @@
ireeCodegenDispatchLoweringPassPipelineAttrGet
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
ireeCodegenDispatchLoweringPassPipelineAttrGetValue
+ ireeCodegenGetExecutableVariantOps
+ ireeCodegenQueryMMAIntrinsics
ireeCodegenTranslationInfoAttrGet
ireeCodegenTranslationInfoAttrGetParameters
ireeCodegenTranslationInfoAttrGetTypeID
diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld
index 0808927..5bd3b25 100644
--- a/compiler/src/iree/compiler/API/api_exports.ld
+++ b/compiler/src/iree/compiler/API/api_exports.ld
@@ -15,6 +15,8 @@
ireeCodegenDispatchLoweringPassPipelineAttrGet;
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
+ ireeCodegenGetExecutableVariantOps;
+ ireeCodegenQueryMMAIntrinsics;
ireeCodegenTranslationInfoAttrGet;
ireeCodegenTranslationInfoAttrGetParameters;
ireeCodegenTranslationInfoAttrGetTypeID;
diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst
index 11169bf..f92e98f 100644
--- a/compiler/src/iree/compiler/API/api_exports.macos.lst
+++ b/compiler/src/iree/compiler/API/api_exports.macos.lst
@@ -13,6 +13,8 @@
_ireeCodegenDispatchLoweringPassPipelineAttrGet
_ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue
+_ireeCodegenGetExecutableVariantOps
+_ireeCodegenQueryMMAIntrinsics
_ireeCodegenTranslationInfoAttrGet
_ireeCodegenTranslationInfoAttrGetParameters
_ireeCodegenTranslationInfoAttrGetTypeID
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index f1eb776..612183d 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -1030,7 +1030,7 @@
SmallVector<IREE::HAL::ExecutableVariantOp>
getExecutableVariantOps(mlir::ModuleOp moduleOp) {
- llvm::SmallVector<IREE::HAL::ExecutableVariantOp> executableVariantOps;
+ SmallVector<IREE::HAL::ExecutableVariantOp> executableVariantOps;
moduleOp.walk([&](IREE::HAL::ExecutableVariantOp executableOp) {
executableVariantOps.push_back(executableOp);
});
@@ -1039,7 +1039,7 @@
SmallVector<IREE::GPU::MMAIntrinsic>
queryMMAIntrinsics(IREE::HAL::ExecutableVariantOp executableOp) {
- llvm::SmallVector<IREE::GPU::MMAIntrinsic> mmaIntrinsics;
+ SmallVector<IREE::GPU::MMAIntrinsic> mmaIntrinsics;
if (IREE::GPU::TargetAttr target = getGPUTargetAttr(executableOp)) {
mmaIntrinsics = llvm::map_to_vector(
target.getWgp().getMma(),