[Codegen][Tuner] attr verifier for tuning specs (#19486)
This PR is relevant to task in
https://github.com/iree-org/iree/issues/19214: add [a discardable attr
verifier](https://mlir.llvm.org/docs/DefiningDialects/#discardable-attribute-verification)
for entry points iree_codegen.tuning_spec_entrypoint
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
index fdbadfe..1cdbfb1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/Verifier.h"
#define DEBUG_TYPE "iree-codegen-link-tuning-specs"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -53,27 +54,6 @@
});
}
-// Returns true iff the entrypoint has the following signature:
-// ```
-// transform.named_sequence @name(%arg0: !transform.any_op) ->
-// (!transform.any_op)
-// ```
-static LogicalResult validateTuningSpec(NamedSequenceOp op) {
- ArrayRef<Type> resTypes = op.getFunctionType().getResults();
- if (resTypes.size() != 1 || !isa<transform::AnyOpType>(resTypes[0])) {
- return op.emitWarning()
- << "Tuning spec entry point expected to return any_op";
- }
-
- ArrayRef<Type> argTypes = op.getArgumentTypes();
- if (argTypes.size() != 1 || !isa<transform::AnyOpType>(argTypes[0])) {
- return op.emitWarning() << "Tuning spec entry point expected to have a "
- "single any_op argument";
- }
-
- return success();
-}
-
static bool consumesInputOp(NamedSequenceOp op) {
if (op.getArgAttr(0, kArgConsumedAttrName)) {
return true;
@@ -81,7 +61,7 @@
return false;
}
-static NamedSequenceOp
+static FailureOr<NamedSequenceOp>
emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
OpBuilder builder(module->getContext());
builder.setInsertionPointToEnd(module.getBody());
@@ -144,6 +124,11 @@
}
builder.create<transform::YieldOp>(loc, operand);
+
+ if (failed(mlir::verify(module))) {
+ return module.emitError("Linked tuning spec failed to verify");
+ }
+
return newSpec;
}
@@ -169,13 +154,6 @@
llvm::append_range(tuningSpecs, findTuningSpecs(nested));
}
- for (NamedSequenceOp spec : tuningSpecs) {
- LDBG("Found tuning spec: " << spec.getSymName());
- if (failed(validateTuningSpec(spec))) {
- return failure();
- }
- }
-
size_t numConsumedSpecs = llvm::count_if(tuningSpecs, consumesInputOp);
if (numConsumedSpecs > 0 && numConsumedSpecs != tuningSpecs.size()) {
LDBG("Only " << numConsumedSpecs << " tuning specs out of "
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp
index 7495f43..db36cbf 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/Verifier.h"
#include "mlir/Support/FileUtilities.h"
#define DEBUG_TYPE "iree-codegen-materialize-tuning-specs"
@@ -138,8 +139,19 @@
// Load the library through the codegen dialect so that we cache the parsed
// module.
- return dialect.getOrParseTransformLibraryModule(defaultTuningSpecName,
- *defaultTuningSpecSource);
+ FailureOr<ModuleOp> defaultTransformLibrary =
+ dialect.getOrParseTransformLibraryModule(defaultTuningSpecName,
+ *defaultTuningSpecSource);
+
+#ifndef NDEBUG
+ if (succeeded(defaultTransformLibrary) &&
+ failed(mlir::verify(*defaultTransformLibrary)))
+ return (*defaultTransformLibrary).emitError()
+ << "Default tuning spec " << defaultTuningSpecName
+ << " failed to verify";
+#endif
+
+ return defaultTransformLibrary;
}
static FailureOr<DenseElementsAttr>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index f0652d2..3834d0a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -96,6 +96,7 @@
"vector_layout_analysis.mlir",
"vectorize_memref_copy.mlir",
"vectorize_tensor_pad.mlir",
+ "verify_tuning_specs.mlir",
"verify_workgroup_distribution.mlir",
"vmvx_materialize_encoding.mlir",
],
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 2d707f6..2ef1b63 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -92,6 +92,7 @@
"vector_layout_analysis.mlir"
"vectorize_memref_copy.mlir"
"vectorize_tensor_pad.mlir"
+ "verify_tuning_specs.mlir"
"verify_workgroup_distribution.mlir"
"vmvx_materialize_encoding.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/verify_tuning_specs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/verify_tuning_specs.mlir
new file mode 100644
index 0000000..aede375
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/verify_tuning_specs.mlir
@@ -0,0 +1,53 @@
+// RUN: iree-opt --verify-diagnostics --split-input-file %s
+
+module @foo_module attributes { transform.with_named_sequence } {
+ func.func @baz(%arg0: i32) -> () {
+ return
+ }
+ transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+ attributes { iree_codegen.something } {
+ transform.yield %arg0 : !transform.any_op
+ }
+ // expected-error @+1{{'iree_codegen.tuning_spec_entrypoint' attribute must be a UnitAttr}}
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+ attributes { iree_codegen.tuning_spec_entrypoint = "foo" } {
+ transform.yield %arg0 : !transform.any_op
+ }
+}
+
+// -----
+
+module @foo_module attributes { transform.with_named_sequence } {
+ // expected-error @+1{{Tuning spec entry point expected to have a single any_op argument}}
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}, %arg1: !transform.any_op {transform.readonly}) -> !transform.any_op
+ attributes { iree_codegen.tuning_spec_entrypoint } {
+ transform.yield %arg0 : !transform.any_op
+ }
+}
+
+// -----
+
+module @foo_module attributes { transform.with_named_sequence } {
+ // expected-error @+1{{Tuning spec entry point expected to have a single any_op argument}}
+ transform.named_sequence @foo(%arg0: i32) -> !transform.any_op
+ attributes { iree_codegen.tuning_spec_entrypoint } {}
+}
+
+// -----
+
+module @foo_module attributes { transform.with_named_sequence } {
+ // expected-error @+1{{Tuning spec entry point expected to return any_op}}
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> i32
+ attributes { iree_codegen.tuning_spec_entrypoint } {
+ %0 = arith.constant 0 : i32
+ transform.yield %0 : i32
+ }
+}
+
+// -----
+
+module @foo_module attributes { transform.with_named_sequence } {
+ // expected-error @+1{{Tuning spec entry point expected to return any_op}}
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly})
+ attributes { iree_codegen.tuning_spec_entrypoint } {}
+}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp
index 9116691..4a2281e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp.inc"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/DialectImplementation.h"
namespace mlir::iree_compiler::IREE::Codegen {
@@ -45,4 +46,46 @@
>();
}
+LogicalResult
+IREECodegenDialect::verifyOperationAttribute(Operation *op,
+ NamedAttribute attribute) {
+ StringRef symbol = attribute.getName().strref();
+ Attribute attr = attribute.getValue();
+
+ // This function verifies the validity of a specific operation attribute.
+ // - If the attribute's name matches `kTuningSpecEntrypointAttrName`
+ // ("iree_codegen.tuning_spec_entrypoint"):
+ // 1. The attribute value must be a UnitAttr.
+ // 2. If the operation is a transform::NamedSequenceOp:
+ // - The operation's function signature must satisfy the following:
+ // a. It must have exactly one result type, and the result must be of
+ // type `transform::AnyOpType`.
+ // b. It must have exactly one argument type, and the argument must be
+ // of type `transform::AnyOpType`.
+
+ if (symbol != kTuningSpecEntrypointAttrName)
+ return success();
+
+ if (!isa<UnitAttr>(attr)) {
+ return op->emitError("'") << symbol << "' attribute must be a UnitAttr";
+ }
+
+ if (auto namedSeqOp = dyn_cast<transform::NamedSequenceOp>(op)) {
+ ArrayRef<Type> resTypes = namedSeqOp.getFunctionType().getResults();
+ if (resTypes.size() != 1 || !isa<transform::AnyOpType>(resTypes[0])) {
+ return namedSeqOp.emitError()
+ << "Tuning spec entry point expected to return any_op";
+ }
+
+ ArrayRef<Type> argTypes = namedSeqOp.getArgumentTypes();
+ if (argTypes.size() != 1 || !isa<transform::AnyOpType>(argTypes[0])) {
+ return namedSeqOp.emitError()
+ << "Tuning spec entry point expected to have a "
+ "single any_op argument";
+ }
+ }
+
+ return success();
+}
+
} // namespace mlir::iree_compiler::IREE::Codegen
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td
index 35775d0..a51ff09 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td
@@ -68,6 +68,7 @@
std::mutex libraryMutex;
}];
let useDefaultAttributePrinterParser = 1;
+ let hasOperationAttrVerify = 1;
}
def AnyRankedTensorOrMemRefType : AnyTypeOf<[AnyRankedTensor, AnyMemRef]>;