Cleanup: Use upstream TransformInterpreterPassBase (#13633)
Transform dialect interpreter passes must still be defined in IREE, but
they can use the upstream
`mlir::transform::TransformInterpreterPassBase` implementation.
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index f355025..82be1cd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -68,6 +68,7 @@
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:VectorDialect",
# IR
"@llvm-project//mlir:Analysis",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index c67a34a..ea43378 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -62,6 +62,7 @@
MLIRTensorDialect
MLIRTensorTransforms
MLIRTransformDialect
+ MLIRTransformDialectTransforms
MLIRVectorDialect
MLIRVectorTransformOps
MLIRVectorTransforms
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
index 88c6bba..8523385 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
@@ -8,7 +8,6 @@
#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
-#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h"
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
#include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h"
#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
@@ -40,6 +39,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
@@ -54,7 +54,7 @@
/// This needs to be its own pass because the registration mechanism and ops
/// available are different than for other interpreters.
class TransformDialectInterpreterPass
- : public transform::iree_dialects::TransformInterpreterPassBase<
+ : public mlir::transform::TransformInterpreterPassBase<
TransformDialectInterpreterPass,
iree_compiler::TransformDialectInterpreterBase> {
public:
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/repeated_matcher_use.mlir b/compiler/src/iree/compiler/Codegen/Common/test/repeated_matcher_use.mlir
index 4c63f9a..b085a7c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/repeated_matcher_use.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/repeated_matcher_use.mlir
@@ -52,7 +52,6 @@
// -----
-// expected-error @below {{transform dialect interpreter failed}}
module {
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
@@ -107,7 +106,6 @@
// -----
-// expected-error @below {{transform dialect interpreter failed}}
module {
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_ops_invalid.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_ops_invalid.mlir
index d46eeef..8a5deaf 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/transform_ops_invalid.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_ops_invalid.mlir
@@ -1,7 +1,6 @@
// RUN: iree-opt %s --split-input-file --iree-transform-dialect-interpreter --verify-diagnostics
module {
-// expected-error @above {{transform dialect interpreter failed}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{match registry not available}}
@@ -12,7 +11,6 @@
// -----
module {
-// expected-error @above {{transform dialect interpreter failed}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.iree.register_match_callbacks
@@ -24,7 +22,6 @@
// -----
module {
-// expected-error @above {{transform dialect interpreter failed}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.iree.register_match_callbacks
@@ -47,7 +44,6 @@
// -----
module attributes {test.iree_transform_do_not_match} {
-// expected-error @above {{transform dialect interpreter failed}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.iree.register_match_callbacks
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/create_async_groups.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/create_async_groups.mlir
index 6c20ba9..012d007 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/create_async_groups.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/create_async_groups.mlir
@@ -66,7 +66,6 @@
// Check that we reject constructs that try to apply create_async_groups
// on non-func op.
-// expected-error@below {{transform dialect interpreter failed}}
builtin.module {
func.func @copies_to_asyncs_invalid_op_input(%a: memref<1024x1024xf32>) {
// expected-note@below {{when applied to this op}}
diff --git a/compiler/src/iree/compiler/Codegen/Passes.td b/compiler/src/iree/compiler/Codegen/Passes.td
index e3fcb83..69cf778 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Passes.td
@@ -393,9 +393,16 @@
"Optional filename containing a transform dialect specification to "
"apply. If left empty, the IR is assumed to contain one top-level "
"transform dialect operation somewhere in the module.">,
+ Option<"transformLibraryFileName",
+ "transform-library-file-name",
+ "std::string",
+ /*default=*/"\"\"",
+ "If non-empty, the name of the file containing definitions of "
+ "external symbols referenced in the transform script. "
+ "These definitions will be used to replace declarations.">,
Option<"debugPayloadRootTag", "debug-payload-root-tag", "std::string",
/*default=*/"\"\"",
- "Select the operation with 'transform.iree_tag' attribute having "
+ "Select the operation with 'transform.target_tag' attribute having "
"the given value as payload IR root. This allows user control on "
"what operation to transform in debug mode, without requiring "
"intimate knowledge of the IREE nested pass pipeline.\\n"
@@ -403,7 +410,7 @@
"operation in the IREE pipeline, as the payload IR root.">,
Option<"debugTransformRootTag", "debug-transform-root-tag", "std::string",
/*default=*/"\"\"",
- "Select the operation with 'transform.iree_tag' attribute having "
+ "Select the operation with 'transform.target_tag' attribute having "
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply in debug "
"mode, without requiring intimate knowledge of the IREE nested "
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 2771cda..f96d85b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -93,7 +93,6 @@
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgExtUtils",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
- "//llvm-external-projects/iree-dialects:IREELinalgTransformDialectPasses",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:Analysis",
@@ -121,6 +120,7 @@
"@llvm-project//mlir:TilingInterface",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 02ffe5d..5b295a5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -75,7 +75,6 @@
IREELinalgExtTransforms
IREELinalgExtUtils
IREELinalgTransformDialect
- IREELinalgTransformDialectPasses
LLVMSupport
MLIRAffineDialect
MLIRAnalysis
@@ -102,6 +101,7 @@
MLIRTilingInterface
MLIRTosaDialect
MLIRTransformDialect
+ MLIRTransformDialectTransforms
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::Flow::Conversion::TensorToFlow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchWithTransformDialect.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchWithTransformDialect.cpp
index ccaeacb..4acc8e3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchWithTransformDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchWithTransformDialect.cpp
@@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
@@ -18,6 +17,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -30,8 +30,8 @@
/// Interpreter pass that applies transform dialect ops for dispatch region
/// formation. This needs to be its own pass because the registration mechanism
/// and ops available are different than for other interpreters.
-struct DispatchWithTransformDialect
- : public transform::iree_dialects::TransformInterpreterPassBase<
+class DispatchWithTransformDialect
+ : public mlir::transform::TransformInterpreterPassBase<
DispatchWithTransformDialect, DispatchWithTransformDialectBase> {
void getDependentDialects(DialectRegistry ®istry) const override {
// clang-format off
@@ -49,6 +49,7 @@
// clang-format on
}
+ public:
DispatchWithTransformDialect(StringRef transformFileName,
StringRef debugPayloadRootTag = StringRef(),
StringRef debugTransformRootTag = StringRef()) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 8f294c8..7205502 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -111,9 +111,16 @@
"Optional filename containing a transform dialect specification to "
"apply. If left empty, the IR is assumed to contain one top-level "
"transform dialect operation somewhere in the module.">,
+ Option<"transformLibraryFileName",
+ "transform-library-file-name",
+ "std::string",
+ /*default=*/"\"\"",
+ "If non-empty, the name of the file containing definitions of "
+ "external symbols referenced in the transform script. "
+ "These definitions will be used to replace declarations.">,
Option<"debugPayloadRootTag", "debug-payload-root-tag", "std::string",
/*default=*/"\"\"",
- "Select the operation with 'transform.iree_tag' attribute having "
+ "Select the operation with 'transform.target_tag' attribute having "
"the given value as payload IR root. This allows user control on "
"what operation to transform in debug mode, without requiring "
"intimate knowledge of the IREE nested pass pipeline.\\n"
@@ -121,7 +128,7 @@
"operation in the IREE pipeline, as the payload IR root.">,
Option<"debugTransformRootTag", "debug-transform-root-tag", "std::string",
/*default=*/"\"\"",
- "Select the operation with 'transform.iree_tag' attribute having "
+ "Select the operation with 'transform.target_tag' attribute having "
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply in debug "
"mode, without requiring intimate knowledge of the IREE nested "
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index a010d69..6ecc4e4 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -610,6 +610,7 @@
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TensorUtils",
"@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h
deleted file mode 100644
index ae97eb4..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h
+++ /dev/null
@@ -1,152 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_DIALECTS_LINALG_TRANSFORM_TRANSFORM_INTERPRETER_UTILS_H
-#define IREE_DIALECTS_LINALG_TRANSFORM_TRANSFORM_INTERPRETER_UTILS_H
-
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include <memory>
-
-namespace mlir {
-class LogicalResult;
-class MLIRContext;
-class ModuleOp;
-class Operation;
-template <typename>
-class OwningOpRef;
-class Region;
-
-namespace transform {
-namespace iree_dialects {
-namespace detail {
-/// Template-free implementation of TransformInterpreterPassBase::initialize.
-LogicalResult
-interpreterBaseInitializeImpl(MLIRContext *context, StringRef transformFileName,
- std::shared_ptr<OwningOpRef<ModuleOp>> &module);
-
-/// Template-free implementation of
-/// TransformInterpreterPassBase::runOnOperation.
-LogicalResult interpreterBaseRunOnOperationImpl(
- Operation *target, StringRef passName,
- const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
- const Pass::Option<std::string> &transformFileName,
- const Pass::Option<std::string> &debugPayloadRootTag,
- const Pass::Option<std::string> &debugTransformRootTag);
-} // namespace detail
-
-/// Base class for transform dialect interpreter passes that can consume and
-/// dump transform dialect scripts in separate files. The pass is controlled by
-/// three string options:
-///
-/// - transformFileName: if non-empty, the name of the file containing the
-/// transform script. If empty, `debugTransformRootTag` is considered or the
-/// pass root operation must contain a single top-level transform op that
-/// will be interpreted.
-/// - debugPayloadRootTag: if non-empty, the value of the attribute named
-/// `kTransformIreeTagAttrName` indicating the single op that is considered
-/// the payload root of the transform interpreter; otherwise, the root
-/// operation of the pass is used.
-/// - debugTransformRootTag: if non-empty, the value of the attribute named
-/// `kTransformIreeTagAttrName` indicating the single top-level transform
-/// op contained in the payload root to be used as the entry point by the
-/// transform interpreter; mutually exclusive with `transformFileName`.
-///
-/// The pass runs the transform dialect interpreter as directed by the options.
-/// It also provides the mechanism to dump reproducers into stderr
-/// (-debug-only=iree-transform-dialect-dump-repro) or into a temporary file
-/// (-debug-only=iree-transform-dialect-save-repro) that can be used with this
-/// pass in a standalone mode.
-///
-/// Concrete passes must derive from this class instead of their generated base
-/// class (or PassWrapper), and supply themselves and the generated base class
-/// as template arguments. They are *not* expected to to implement `initialize`
-/// or `runOnOperation`. They *are* expected to call the copy constructor of
-/// this class in their copy constructors, short of which the file-based
-/// transform dialect script injection facility will become nonoperational.
-///
-/// Concrete passes may implement the `runBeforeInterpreter` and
-/// `runAfterInterpreter` to customize the behavior of the pass.
-template <typename Concrete, template <typename> typename GeneratedBase>
-class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
-public:
- TransformInterpreterPassBase() = default;
-
- TransformInterpreterPassBase(const TransformInterpreterPassBase &pass) {
- // TODO: if we really don't like shared_ptr, we could also clone the
- // transformModule here.
- sharedTransformModule = pass.sharedTransformModule;
- }
-
- LogicalResult initialize(MLIRContext *context) final {
-
-#define REQUIRE_PASS_OPTION(NAME) \
- static_assert( \
- std::is_same_v< \
- std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>, \
- Pass::Option<std::string>>, \
- "required " #NAME " string pass option is missing")
-
- REQUIRE_PASS_OPTION(transformFileName);
- REQUIRE_PASS_OPTION(debugPayloadRootTag);
- REQUIRE_PASS_OPTION(debugTransformRootTag);
-
-#undef REQUIRE_PASS_OPTION
-
- StringRef transformFileName =
- static_cast<Concrete *>(this)->transformFileName;
- return detail::interpreterBaseInitializeImpl(context, transformFileName,
- sharedTransformModule);
- }
-
- /// Hook for passes to run additional logic in the pass before the
- /// interpreter. If failure is returned, the pass fails and the interpreter is
- /// not run.
- LogicalResult runBeforeInterpreter(Operation *) { return success(); }
-
- /// Hook for passes to run additional logic in the pass after the interpreter.
- /// Only runs if everything succeeded before. If failure is returned, the pass
- /// fails.
- LogicalResult runAfterInterpreter(Operation *) { return success(); }
-
- void runOnOperation() final {
- auto *pass = static_cast<Concrete *>(this);
- Operation *op = pass->getOperation();
- if (failed(pass->runBeforeInterpreter(op)) ||
- failed(detail::interpreterBaseRunOnOperationImpl(
- op, pass->getArgument(), sharedTransformModule,
- pass->transformFileName, pass->debugPayloadRootTag,
- pass->debugTransformRootTag)) ||
- failed(pass->runAfterInterpreter(op))) {
- return pass->signalPassFailure();
- }
- }
-
-private:
- // The parsed transform module to be used for transformations.
- // TODO: Figure a better way to build a transform module and transport it in
- // the proper places in the IR as it is transformed by IREE so that it is
- // available with better ownership semantics.
- // Note: we wrap the OwningOpRef to get the desired destruction mechanism.
- // Note: shared_ptr is not great but we know the sharedTransformModule is
- // readonly.
- // Alternatives comprise:
- // 1. no shared_ptr but copying the module with every pass clone that the
- // OpPassManager decides to perform.
- // 2. lifting ownership of the parsed transform module higher up in the
- // IREE stack. This may be only shift the problem as we have passes
- // building pass managers in IREE.
- // 3. build better support to embed the transformation module in the
- // input IR and transport it to the place of use in IREE. This is deemed
- // too intrusive atm.
- // 4. (future) config/resources mechanism that is being proposed in core?
- std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule = nullptr;
-};
-
-} // namespace iree_dialects
-} // namespace transform
-} // namespace mlir
-#endif // IREE_DIALECTS_LINALG_TRANSFORM_TRANSFORM_INTERPRETER_UTILS_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/CMakeLists.txt
index d10ff32..9c27d11 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_library(IREELinalgTransformDialectPasses
ExpertExpansion.cpp
TransformInterpreter.cpp
- TransformInterpreterPassBase.cpp
DEPENDS
mlir-headers
@@ -19,6 +18,7 @@
MLIRMemRefToLLVM
MLIRPass
MLIRTensorDialect
+ MLIRTransformDialectTransforms
MLIRTransforms
MLIRVectorDialect
MLIRVectorToLLVM
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreter.cpp
index c9f12e9..141ba9b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreter.cpp
@@ -8,7 +8,6 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
-#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
@@ -25,6 +24,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Pass/Pass.h"
@@ -43,7 +43,7 @@
/// This needs to be its own pass because the registration mechanism and ops
/// available are different than for other interpreters.
class TransformDialectInterpreter
- : public transform::iree_dialects::TransformInterpreterPassBase<
+ : public mlir::transform::TransformInterpreterPassBase<
TransformDialectInterpreter, PassWrapperStub> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformDialectInterpreter)
@@ -112,15 +112,21 @@
::llvm::cl::init("")};
Pass::Option<std::string> debugPayloadRootTag{
*this, "debug-payload-root-tag",
- ::llvm::cl::desc("Select the operation with 'transform.iree_tag' "
+ ::llvm::cl::desc("Select the operation with 'transform.target_tag' "
"attribute having the given value as payload IR root."),
::llvm::cl::init("")};
Pass::Option<std::string> debugTransformRootTag{
*this, "debug-transform-root-tag",
::llvm::cl::desc(
- "Select the operation with 'transform.iree_tag' attribute having the "
- "given value as container IR for top-level transform ops."),
+ "Select the operation with 'transform.target_tag' attribute having "
+ "the given value as container IR for top-level transform ops."),
::llvm::cl::init("")};
+ Pass::Option<std::string> transformLibraryFileName{
+ *this, "transform-library-file-name",
+ llvm::cl::desc(
+ "Optional name of the file containing transform dialect symbol "
+ "definitions to be injected into the transform module."),
+ llvm::cl::init("")};
};
struct DropSchedulePass : public PassWrapper<DropSchedulePass, Pass> {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreterPassBase.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreterPassBase.cpp
deleted file mode 100644
index 53494a1..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreterPassBase.cpp
+++ /dev/null
@@ -1,483 +0,0 @@
-// Copyright 2023 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h"
-#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Parser/Parser.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/FileUtilities.h"
-#include "llvm/ADT/ScopeExit.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/FileSystem.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/Path.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-
-#define DEBUG_TYPE "transform-dialect-interpreter"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
-#define DEBUG_TYPE_DUMP_STDERR "iree-transform-dialect-dump-repro"
-#define DEBUG_TYPE_DUMP_FILE "iree-transform-dialect-save-repro"
-
-/// Name of the attribute used for targeting the transform dialect interpreter
-/// at specific operations.
-constexpr static llvm::StringLiteral kTransformIreeTagAttrName =
- "transform.iree_tag";
-/// Value of the attribute indicating the root payload operation.
-constexpr static llvm::StringLiteral kTransformIreeTagPayloadRootValue =
- "iree_payload_root";
-/// Value of the attribute indicating the container of transform operations
-/// (containing the top-level transform operation).
-constexpr static llvm::StringLiteral kTransformIreeTagTransformContainerValue =
- "iree_transform_container";
-
-/// Utility to parse the content of a `transformFileName` mlir file containing
-/// a transform dialect specification.
-static LogicalResult
-parseTransformModuleFromFile(MLIRContext *context,
- llvm::StringRef transformFileName,
- OwningOpRef<ModuleOp> &transformModule) {
- if (transformFileName.empty()) {
- LLVM_DEBUG(
- DBGS() << "no transform file name specified, assuming the transform "
- "module is embedded in the IR next to the top-level\n");
- return success();
- }
- // Parse transformFileName content into a ModuleOp.
- std::string errorMessage;
- auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
- if (!memoryBuffer) {
- llvm::errs() << "failed to parse transform file: " << transformFileName
- << "\n";
- return failure();
- }
- // Tell sourceMgr about this buffer, the parser will pick it up.
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
- transformModule =
- OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
- return success();
-}
-
-/// Utility to extract the `TransformOpInterface` ops that have the trait
-/// `PossibleTopLevelTransformOpTrait`.
-static LogicalResult extractTopLevelTransformOps(
- Region &r, SmallVectorImpl<transform::TransformOpInterface> &res) {
- assert(r.getBlocks().size() == 1 &&
- "Expected single-block region to extract transform ops from");
- r.walk<WalkOrder::PreOrder>([&](transform::TransformOpInterface transform) {
- if (transform->hasTrait<transform::PossibleTopLevelTransformOpTrait>()) {
- assert(llvm::none_of(res, [&](transform::TransformOpInterface seen) {
- return seen->isAncestor(transform);
- }));
- res.push_back(transform);
- return WalkResult::skip();
- }
- return WalkResult::advance();
- });
- return success();
-}
-
-/// Utility to run a transform dialect specification contained in a
-/// `transformRegion`, on a `target` op.
-/// Since the transform dialect may use PDL which may modify the IR, the
-/// underlying implementation clones the transform dialect operations before
-/// applying them.
-static LogicalResult applyTransformsInRegion(Region &transformRegion,
- Operation *target) {
- SmallVector<transform::TransformOpInterface> transforms;
- if (failed(extractTopLevelTransformOps(transformRegion, transforms)))
- return failure();
-
- for (transform::TransformOpInterface transform : transforms) {
- // TransformState::applyTransform requires that the parent region is a
- // proper ancestor of the transform op to perform SSA liveness assertions.
- // In multithreaded state however, we cannot clone into `transformRegion` so
- // we build a new single-block region and clone the transform op into it.
- Region r;
- OpBuilder b(target->getContext());
- b.createBlock(&r);
- transform::TransformOptions options;
-#ifndef NDEBUG
- options = options.enableExpensiveChecks();
-#endif
- auto xform = cast<transform::TransformOpInterface>(b.clone(*transform));
- auto g = llvm::make_scope_exit([&]() { xform->erase(); });
- if (failed(transform::applyTransforms(target, xform, {}, options)))
- return failure();
- }
- return success();
-}
-
-/// Finds the single top-level transform operation with `root` as ancestor.
-/// Reports an error if there is more than one such operation and returns the
-/// first one found. Reports an error returns nullptr if no such operation
-/// found.
-static Operation *findTopLevelTransform(Operation *root, StringRef debugStr) {
- ::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
- WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
- [&](::mlir::transform::TransformOpInterface transformOp) {
- if (!topLevelTransform) {
- topLevelTransform = transformOp;
- return WalkResult::skip();
- }
- auto diag = transformOp.emitError()
- << "more than one top-level transform op";
- diag.attachNote(topLevelTransform.getLoc())
- << "previous top-level transform op";
- return WalkResult::interrupt();
- });
- if (walkResult.wasInterrupted())
- return nullptr;
- if (!topLevelTransform) {
- auto diag = root->emitError()
- << "could not find a nested top-level transform op";
- diag.attachNote() << "use the '" << debugStr
- << "' option to provide transform as external file";
- return nullptr;
- }
- return topLevelTransform;
-}
-
-/// Finds an operation nested in `root` that has the transform dialect tag
-/// attribute with the value specified as `tag`. Assumes only one operation
-/// may have the tag. Returns nullptr if there is no such operation.
-static Operation *findOpWithTag(Operation *root, StringRef tagKey,
- StringRef tagValue) {
- Operation *found = nullptr;
- root->walk<WalkOrder::PreOrder>([tagKey, tagValue, &found](Operation *op) {
- auto attr = op->getAttrOfType<StringAttr>(tagKey);
- if (!attr || attr.getValue() != tagValue)
- return WalkResult::advance();
-
- assert(found == nullptr && "more than one op with the same tag");
- found = op;
-
- // In debug mode, continue the traversal to see if the tag is not
- // duplicated. This is only necessary to ensure that the assert above is
- // triggered. In the non-debug mode, assert is not performed and we can
- // sparse some cycles by not iterating further.
-#ifndef NDEBUG
- return WalkResult::advance();
-#else
- return WalkResult::interrupt();
-#endif // NDEBUG
- });
- return found;
-}
-
-/// Returns the ancestor of `target` that doesn't have a parent.
-static Operation *getRootOperation(Operation *target) {
- Operation *root = target;
- while (root->getParentOp())
- root = root->getParentOp();
- return root;
-}
-
-/// Prints the CLI command running the repro with the current path.
-static llvm::raw_ostream &
-printIreeOptReproCall(llvm::raw_ostream &os, StringRef rootOpName,
- StringRef passName,
- const Pass::Option<std::string> &debugPayloadRootTag,
- const Pass::Option<std::string> &debugTransformRootTag) {
- os << llvm::formatv("iree-opt --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"",
- rootOpName, passName, debugPayloadRootTag.getArgStr(),
- debugPayloadRootTag.empty()
- ? StringRef(kTransformIreeTagPayloadRootValue)
- : debugPayloadRootTag,
- debugTransformRootTag.getArgStr(),
- debugTransformRootTag.empty()
- ? StringRef(kTransformIreeTagTransformContainerValue)
- : debugTransformRootTag);
- return os;
-}
-
-/// Prints the module rooted at `root` to `os` and appends
-/// `transformContainer` if it is not nested in `root`.
-static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os,
- Operation *root,
- Operation *transformContainer) {
- root->print(os);
- if (!root->isAncestor(transformContainer)) {
- transformContainer->print(os);
- }
- return os;
-}
-
-/// Saves the payload and the transform IR into a temporary file and reports
-/// the file name to `os`.
-static void
-saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
- Operation *transformContainer, StringRef passName,
- const Pass::Option<std::string> &debugPayloadRootTag,
- const Pass::Option<std::string> &debugTransformRootTag) {
- using llvm::sys::fs::TempFile;
- Operation *root = getRootOperation(target);
-
- SmallVector<char, 128> tmpPath;
- llvm::sys::path::system_temp_directory(/*erasedOnReboot=*/true, tmpPath);
- llvm::sys::path::append(tmpPath, "iree_transform_dialect_%%%%%%.mlir");
- llvm::Expected<TempFile> tempFile = TempFile::create(tmpPath);
- if (!tempFile) {
- os << "could not open temporary file to save the repro\n";
- return;
- }
-
- llvm::raw_fd_ostream fout(tempFile->FD, /*shouldClose=*/false);
- printModuleForRepro(fout, root, transformContainer);
- fout.flush();
- std::string filename = tempFile->TmpName;
-
- if (tempFile->keep()) {
- os << "could not preserve the temporary file with the repro\n";
- return;
- }
-
- os << "=== Transform Interpreter Repro ===\n";
- printIreeOptReproCall(os, root->getName().getStringRef(), passName,
- debugPayloadRootTag, debugTransformRootTag)
- << " " << filename << "\n";
- os << "===================================\n";
-}
-
-namespace {
-/// Position of an op within a containing op.
-struct OpPosition {
- /// Number of the containing region in the parent op.
- size_t regionNumber;
- /// Offset of the containing block in the list of blocks of the parent region.
- size_t blockOffset;
- /// Offset of the operation in the list of operations of the parent block.
- size_t opOffset;
-};
-} // namespace
-
-/// Populates `positions` with the relative positions of `target` in its
-/// ancestors.
-static void findOpPositionInRoot(Operation *target,
- SmallVectorImpl<OpPosition> &positions) {
- // Root operation may or may not have a parent block and has no parent
- // region. Even if it has a parent block, we don't need its position in the
- // block because we will have a pointer to root.
- for (; target->getParentOp() != nullptr; target = target->getParentOp()) {
- size_t posInBlock =
- std::distance(target->getBlock()->begin(), target->getIterator());
- size_t blockPos = std::distance(target->getParentRegion()->begin(),
- target->getBlock()->getIterator());
- size_t regionNo = target->getParentRegion()->getRegionNumber();
- positions.emplace_back(OpPosition{regionNo, blockPos, posInBlock});
- }
-}
-
-/// Finds an op located at the given stack of positions (e.g., the last position
-/// is at the root, and the first position is at the immediate parent) relative
-/// to the given root operation. Expects the location to exist.
-static Operation *getOpAtPosition(Operation *root,
- ArrayRef<OpPosition> positions) {
- Operation *op = root;
- for (const OpPosition &pos : llvm::reverse(positions)) {
- Region ®ion = op->getRegion(pos.regionNumber);
- Block &block = *std::next(region.begin(), pos.blockOffset);
- op = &*std::next(block.begin(), pos.opOffset);
- }
- return op;
-}
-
-/// Finds the clone of `original` in the cloned copy of the root operation, i.e.
-/// the operation with no parent, using its relative offsets in the parent
-/// parent lists of blocks and regions. Note that this performs linear
-/// traversals of blocks and regions along the path to root, but is arguably
-/// preferable to storing the entire mapping between all cloned operations.
-static Operation *findCloned(Operation *original, Operation *cloneRoot) {
- SmallVector<OpPosition> opPositions;
- findOpPositionInRoot(original, opPositions);
- return getOpAtPosition(cloneRoot, opPositions);
-}
-
-// Optionally perform debug actions requested by the user to dump IR and a
-// repro to stderr and/or a file.
-static void performOptionalDebugActions(
- Operation *target, Region *transformRegion, StringRef passName,
- const Pass::Option<std::string> &debugPayloadRootTag,
- const Pass::Option<std::string> &debugTransformRootTag) {
- MLIRContext *context = target->getContext();
-
- // If we are not planning to print, bail before we start doing expensive
- // copying.
- bool hasDebugFlags = false;
- DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { hasDebugFlags = true; });
- DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { hasDebugFlags = true; });
- if (!hasDebugFlags)
- return;
-
- // Since the pass may be running in parallel on multiple parts of the same
- // root operation, clone it before attaching debug attributes as it would
- // otherwise create a race. While we could avoid cloning when multithreading
- // is disabled or when the attributes are already present, this is only
- // necessary in debug builds that are not performance-critical and can afford
- // an extra copy.
- Operation *root = getRootOperation(target);
- OwningOpRef<Operation *> rootClone(root->clone());
- Operation *debugRoot = rootClone.get();
- Operation *debugTarget = findCloned(target, debugRoot);
-
- Operation *transformContainer = transformRegion->getParentOp();
- assert(transformContainer && "unexpected detached transform region");
- OwningOpRef<Operation *> maybeTransformContainerClone;
- Operation *debugTransformContainer;
- if (root->isAncestor(transformContainer)) {
- debugTransformContainer = findCloned(transformContainer, debugRoot);
- } else {
- maybeTransformContainerClone =
- OwningOpRef<Operation *>(transformContainer->clone());
- debugTransformContainer = maybeTransformContainerClone.get();
- }
-
- // Add temporary debug / repro attributes, these must never leak out.
- if (debugPayloadRootTag.empty()) {
- debugTarget->setAttr(
- kTransformIreeTagAttrName,
- StringAttr::get(context, kTransformIreeTagPayloadRootValue));
- }
- if (debugTransformRootTag.empty()) {
- debugTransformContainer->setAttr(
- kTransformIreeTagAttrName,
- StringAttr::get(context, kTransformIreeTagTransformContainerValue));
- }
-
- DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
- llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
- printIreeOptReproCall(llvm::dbgs() << "cat <<EOF | ",
- debugRoot->getName().getStringRef(), passName,
- debugPayloadRootTag, debugTransformRootTag)
- << "\n";
- printModuleForRepro(llvm::dbgs(), debugRoot, debugTransformContainer);
- llvm::dbgs() << "\nEOF\n";
- llvm::dbgs() << "===================================\n";
- });
- DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
- saveReproToTempFile(llvm::dbgs(), debugTarget, debugTransformContainer,
- passName, debugPayloadRootTag, debugTransformRootTag);
- });
-}
-
-LogicalResult
-transform::iree_dialects::detail::interpreterBaseRunOnOperationImpl(
- Operation *target, StringRef passName,
- const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
- const Pass::Option<std::string> &transformFileName,
- const Pass::Option<std::string> &debugPayloadRootTag,
- const Pass::Option<std::string> &debugTransformRootTag) {
- bool parsedTransform = (sharedTransformModule && *sharedTransformModule);
-
- // Step 1
- // ------
- // Get the default payloadRoot and transformRegion that one expects
- // when running the IREE nested pass pipeline or the interpreter.
- Operation *payloadRoot = target;
- Region *transformRegion = nullptr;
- // If a parsed transform was specified separately, use it immediately.
- // Otherwise, the transform is embedded in the IR: go inspect the IR and
- // get the first top-level transform we find.
- if (parsedTransform) {
- transformRegion = &(*sharedTransformModule)->getRegion();
- } else {
- // TODO: In large IR we will likely want more control in selecting a
- // particular transform to focus on, this may warrant a user-specified
- // attribute that one would manually injected in the IR when operating in
- // interpreted mode.
- Operation *topLevelTransform =
- findTopLevelTransform(target, transformFileName.getArgStr());
- if (!topLevelTransform)
- return failure();
- transformRegion = topLevelTransform->getParentRegion();
- }
- assert(transformRegion && "unexpected detached root transform op");
-
- // Step 2
- // ------
- // Optionally override payloadRoot if the debugPayloadRootTag was passed.
- //
- // If debugPayloadRootTag was passed, then we are in user-specified selection
- // of the transformed IR. This corresponds to REPL debug mode.
- // Otherwise, just apply to `target`, which is what the IREE nested
- // pipeline wants to operate on.
- if (!debugPayloadRootTag.empty()) {
- payloadRoot =
- findOpWithTag(target, kTransformIreeTagAttrName, debugPayloadRootTag);
- if (!payloadRoot) {
- target->emitError() << "couldn't find the root payload op with "
- << kTransformIreeTagAttrName << "=\""
- << kTransformIreeTagPayloadRootValue
- << "\" attribute";
- return failure();
- }
- }
-
- // Step 3
- // ------
- // Optionally override transformRegion if the debugTransformRootTag was
- // passed.
- //
- // If debugTransformRootTag was passed, then we are in user-specified
- // selection of the transforming IR. This corresponds to REPL debug mode.
- // Otherwise, just apply to the existing `transformRegion`, which is what
- // the IREE nested pipeline wants to operate on.
- if (!debugTransformRootTag.empty()) {
- Operation *transformRoot =
- findOpWithTag(transformRegion->getParentOp(), kTransformIreeTagAttrName,
- kTransformIreeTagTransformContainerValue);
- if (!transformRoot) {
- transformRegion->getParentOp()->emitError()
- << "couldn't find the transform container op with "
- << kTransformIreeTagAttrName << "=\""
- << kTransformIreeTagTransformContainerValue << "\" attribute";
- return failure();
- }
- if (transformRoot->getNumRegions() != 1 ||
- !transformRoot->getRegion(0).hasOneBlock()) {
- transformRoot->emitError() << "expected transform container op to have "
- "one single-block region";
- return failure();
- }
- transformRegion = &transformRoot->getRegion(0);
- }
-
- // Step 4
- // ------
- // Optionally perform debug actions requested by the user to dump IR and a
- // repro to stderr and/or a file.
- performOptionalDebugActions(target, transformRegion, passName,
- debugPayloadRootTag, debugTransformRootTag);
-
- // Step 5
- // ------
- // Apply the transform to the IR
- // TODO: lift this assertion.
- assert(transformRegion->getBlocks().size() == 1 &&
- "expected single-region block");
- if (failed(applyTransformsInRegion(*transformRegion, payloadRoot))) {
- payloadRoot->emitError() << "transform dialect interpreter failed";
- return failure();
- }
-
- return success();
-}
-
-LogicalResult transform::iree_dialects::detail::interpreterBaseInitializeImpl(
- MLIRContext *context, StringRef transformFileName,
- std::shared_ptr<OwningOpRef<ModuleOp>> &module) {
- OwningOpRef<ModuleOp> parsed;
- if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
- return failure();
-
- module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
- return success();
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
index 2b02c8c..950351b 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
@@ -1,6 +1,5 @@
// RUN: iree-dialects-opt --transform-dialect-interpreter --split-input-file --verify-diagnostics --allow-unregistered-dialect %s
-// expected-error @below {{transform dialect interpreter failed}}
module {
func.func public @no_outlining() {
// expected-note @below {{target op}}