[Codegen] Add ability to specify transform dialect libraries (#14788)
This adds the ability to specify transform dialect strategies through
a symbol pointing to a library call. This is currently available through
two flags:
`--iree-codegen-use-transform-dialect-strategy=[@<symbol_name>|filename]`
`--iree-codegen-transform-dialect-library=filename`
The transform library is loaded and cached in the IREE codegen dialect
for subsequent invocations within the MaterializeUserConfigs pass. Then,
the loaded dialect is immediately used with the symbol name referenced
by the transform dialect usage flag. If a filename is specified instead,
that is broadcasted to the transform dialect interpreter (intended for
microbenchmarking).
If the symbol applies successfully, this will send the result through
normal IREE codegen. This gives users the option to replace the
translation info on the export op with a `<None>` pipeline to send it
through <BACKEND>LowerExecutableTarget unperterbed (thereby skipping
the initial tile + distribute and bufferization).
Additionally this unifies the way the transform dialect testing flags
currently duplicated across backends.
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index f53f616..2f3d2f6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -64,84 +64,6 @@
],
)
-# TODO: If the layering causes concerns then the transform dialect interpreter
-# should be one level above everything: it is a mechanism by which
-# transformations are applied to any IR and needs to register all the dialects
-# that may be produced.
-# In particular, a single IREE-side transform interpreter is enough to perform
-# all kind of transformations and not just codegen.
-# This is an opportunity to retire the specific interpreter that is used for
-# creating dispatch regions with the transform dialect, but only once the
-# layering is correct.
-iree_compiler_cc_library(
- name = "TransformDialectInterpreterPass",
- srcs = [
- "TransformDialectInterpreterPass.cpp",
- ],
- deps = [
- ":PassHeaders",
- ":PassesIncGen",
- # Dialects
- "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
- "//compiler/src/iree/compiler/Dialect/Flow/IR",
- "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
- "//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",
- "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
- "@llvm-project//mlir:AffineDialect",
- "@llvm-project//mlir:AffineUtils",
- "@llvm-project//mlir:AsyncDialect",
- "@llvm-project//mlir:ArithDialect",
- "@llvm-project//mlir:ArithUtils",
- "@llvm-project//mlir:BufferizationDialect",
- "@llvm-project//mlir:BufferizationTransforms",
- "@llvm-project//mlir:FuncDialect",
- "@llvm-project//mlir:GPUDialect",
- "@llvm-project//mlir:LinalgDialect",
- "@llvm-project//mlir:LLVMDialect",
- "@llvm-project//mlir:PDLDialect",
- "@llvm-project//mlir:PDLInterpDialect",
- "@llvm-project//mlir:SCFDialect",
- "@llvm-project//mlir:SCFUtils",
- "@llvm-project//mlir:TensorDialect",
- "@llvm-project//mlir:TransformDialect",
- "@llvm-project//mlir:TransformDialectTransforms",
- "@llvm-project//mlir:VectorDialect",
- # IR
- "@llvm-project//mlir:Analysis",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:Rewrite",
- # Interfaces
- # Transforms (needed mostly for the BufferizableOpInterfaceImpl)
- "@llvm-project//mlir:ArithTransforms",
- "@llvm-project//mlir:LinalgTransforms",
- "@llvm-project//mlir:SCFTransforms",
- "@llvm-project//mlir:TensorTransforms",
- "@llvm-project//mlir:VectorTransforms",
- # Other Stuff
- "//compiler/src/iree/compiler/Utils",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:DialectUtils",
- # TransformStrategies
- "//compiler/src/iree/compiler/Codegen/TransformStrategies/Common:TransformStrategies",
- # TransformExtensions (needed for registration in the pass)
- "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
- "//compiler/src/iree/compiler/Codegen/Common/TransformExtensions:CommonExtensions",
- "//compiler/src/iree/compiler/Dialect/Flow/TransformExtensions:FlowExtensions",
- "//compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions:LLVMCPUExtensions",
- "//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions",
- "@llvm-project//mlir:AffineTransformOps",
- "@llvm-project//mlir:BufferizationTransformOps",
- "@llvm-project//mlir:GPUTransformOps",
- "@llvm-project//mlir:LinalgTransformOps",
- "@llvm-project//mlir:MemRefTransformOps",
- "@llvm-project//mlir:SCFTransformOps",
- "@llvm-project//mlir:TensorTransformOps",
- "@llvm-project//mlir:VectorTransformOps",
- ],
-)
-
iree_compiler_cc_library(
name = "Common",
srcs = [
@@ -271,3 +193,85 @@
"@llvm-project//mlir:ViewLikeInterface",
],
)
+
+# TODO: If the layering causes concerns then the transform dialect interpreter
+# should be one level above everything: it is a mechanism by which
+# transformations are applied to any IR and needs to register all the dialects
+# that may be produced.
+# In particular, a single IREE-side transform interpreter is enough to perform
+# all kind of transformations and not just codegen.
+# This is an opportunity to retire the specific interpreter that is used for
+# creating dispatch regions with the transform dialect, but only once the
+# layering is correct.
+iree_compiler_cc_library(
+ name = "TransformDialectInterpreterPass",
+ srcs = [
+ "CommonDialectRegistration.cpp",
+ "MaterializeUserConfigs.cpp",
+ "TransformDialectInterpreterPass.cpp",
+ ],
+ deps = [
+ ":Common",
+ ":PassHeaders",
+ ":PassesIncGen",
+ # Dialects
+ "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
+ "//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
+ "//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",
+ "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+ "@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:AffineUtils",
+ "@llvm-project//mlir:AsyncDialect",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:ArithUtils",
+ "@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:BufferizationTransforms",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:GPUDialect",
+ "@llvm-project//mlir:LinalgDialect",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:PDLDialect",
+ "@llvm-project//mlir:PDLInterpDialect",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:SCFUtils",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformDialectTransforms",
+ "@llvm-project//mlir:VectorDialect",
+ # IR
+ "@llvm-project//mlir:Analysis",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Rewrite",
+ # Interfaces
+ # Transforms (needed mostly for the BufferizableOpInterfaceImpl)
+ "@llvm-project//mlir:ArithTransforms",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:SCFTransforms",
+ "@llvm-project//mlir:TensorTransforms",
+ "@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorTransforms",
+ # Other Stuff
+ "//compiler/src/iree/compiler/Utils",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:DialectUtils",
+ # TransformStrategies
+ "//compiler/src/iree/compiler/Codegen/TransformStrategies/Common:TransformStrategies",
+ # TransformExtensions (needed for registration in the pass)
+ "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
+ "//compiler/src/iree/compiler/Codegen/Common/TransformExtensions:CommonExtensions",
+ "//compiler/src/iree/compiler/Dialect/Flow/TransformExtensions:FlowExtensions",
+ "//compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions:LLVMCPUExtensions",
+ "//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions",
+ "@llvm-project//mlir:AffineTransformOps",
+ "@llvm-project//mlir:BufferizationTransformOps",
+ "@llvm-project//mlir:GPUTransformOps",
+ "@llvm-project//mlir:LinalgTransformOps",
+ "@llvm-project//mlir:MemRefTransformOps",
+ "@llvm-project//mlir:SCFTransformOps",
+ "@llvm-project//mlir:TensorTransformOps",
+ "@llvm-project//mlir:VectorTransformOps",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index f047fe1..f23fc30 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -49,67 +49,6 @@
iree_cc_library(
NAME
- TransformDialectInterpreterPass
- SRCS
- "TransformDialectInterpreterPass.cpp"
- DEPS
- ::PassHeaders
- ::PassesIncGen
- IREEDialectsTransforms
- IREELinalgExtDialect
- IREELinalgExtTransformOps
- IREELinalgTransformDialect
- LLVMSupport
- MLIRAffineDialect
- MLIRAffineTransformOps
- MLIRAffineUtils
- MLIRAnalysis
- MLIRArithDialect
- MLIRArithTransforms
- MLIRArithUtils
- MLIRAsyncDialect
- MLIRBufferizationDialect
- MLIRBufferizationTransformOps
- MLIRBufferizationTransforms
- MLIRFuncDialect
- MLIRGPUDialect
- MLIRGPUTransformOps
- MLIRIR
- MLIRLLVMDialect
- MLIRLinalgDialect
- MLIRLinalgTransformOps
- MLIRLinalgTransforms
- MLIRMemRefTransformOps
- MLIRPDLDialect
- MLIRPDLInterpDialect
- MLIRPass
- MLIRRewrite
- MLIRSCFDialect
- MLIRSCFTransformOps
- MLIRSCFTransforms
- MLIRSCFUtils
- MLIRSupport
- MLIRTensorDialect
- MLIRTensorTransformOps
- MLIRTensorTransforms
- MLIRTransformDialect
- MLIRTransformDialectTransforms
- MLIRVectorDialect
- MLIRVectorTransformOps
- MLIRVectorTransforms
- iree::compiler::Codegen::Common::TransformExtensions::CommonExtensions
- iree::compiler::Codegen::Dialect::IREECodegenDialect
- iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
- iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
- iree::compiler::Codegen::TransformStrategies::Common::TransformStrategies
- iree::compiler::Dialect::Flow::IR
- iree::compiler::Dialect::Flow::TransformExtensions::FlowExtensions
- iree::compiler::Utils
- PUBLIC
-)
-
-iree_cc_library(
- NAME
Common
HDRS
"BufferizationAnalysis.h"
@@ -236,4 +175,69 @@
PUBLIC
)
+iree_cc_library(
+ NAME
+ TransformDialectInterpreterPass
+ SRCS
+ "CommonDialectRegistration.cpp"
+ "MaterializeUserConfigs.cpp"
+ "TransformDialectInterpreterPass.cpp"
+ DEPS
+ ::Common
+ ::PassHeaders
+ ::PassesIncGen
+ IREEDialectsTransforms
+ IREELinalgExtDialect
+ IREELinalgExtTransformOps
+ IREELinalgTransformDialect
+ LLVMSupport
+ MLIRAffineDialect
+ MLIRAffineTransformOps
+ MLIRAffineUtils
+ MLIRAnalysis
+ MLIRArithDialect
+ MLIRArithTransforms
+ MLIRArithUtils
+ MLIRAsyncDialect
+ MLIRBufferizationDialect
+ MLIRBufferizationTransformOps
+ MLIRBufferizationTransforms
+ MLIRFuncDialect
+ MLIRGPUDialect
+ MLIRGPUTransformOps
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRLinalgDialect
+ MLIRLinalgTransformOps
+ MLIRLinalgTransforms
+ MLIRMemRefTransformOps
+ MLIRPDLDialect
+ MLIRPDLInterpDialect
+ MLIRPass
+ MLIRRewrite
+ MLIRSCFDialect
+ MLIRSCFTransformOps
+ MLIRSCFTransforms
+ MLIRSCFUtils
+ MLIRSupport
+ MLIRTensorDialect
+ MLIRTensorTransformOps
+ MLIRTensorTransforms
+ MLIRTransformDialect
+ MLIRTransformDialectTransforms
+ MLIRTransforms
+ MLIRVectorDialect
+ MLIRVectorTransformOps
+ MLIRVectorTransforms
+ iree::compiler::Codegen::Common::TransformExtensions::CommonExtensions
+ iree::compiler::Codegen::Dialect::IREECodegenDialect
+ iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
+ iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
+ iree::compiler::Codegen::TransformStrategies::Common::TransformStrategies
+ iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::Flow::TransformExtensions::FlowExtensions
+ iree::compiler::Utils
+ PUBLIC
+)
+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp b/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp
new file mode 100644
index 0000000..ec663c6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp
@@ -0,0 +1,108 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree/compiler/Codegen/Common/PassDetail.h"
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
+#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h"
+#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+void registerTransformDialectTranslationDependentDialects(
+ DialectRegistry ®istry) {
+ // TODO: this is only necessary to make registry subset happy when running
+ // the lowering to LLVM. The lowering should be changed to stop using the
+ // nested pass manager and this will go away.
+
+ // clang-format off
+ registry.insert<mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
+ mlir::iree_compiler::IREE::Flow::FlowDialect,
+ mlir::iree_compiler::IREE::Codegen::IREECodegenDialect,
+ arith::ArithDialect,
+ affine::AffineDialect,
+ bufferization::BufferizationDialect,
+ func::FuncDialect,
+ gpu::GPUDialect,
+ linalg::LinalgDialect,
+ LLVM::LLVMDialect,
+ pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect,
+ scf::SCFDialect,
+ tensor::TensorDialect,
+ transform::TransformDialect,
+ vector::VectorDialect
+ // clang-format on
+ >();
+
+ // TODO: these should be registered by the extension instead, but there is
+ // no support for it in core currently.
+ arith::registerBufferizableOpInterfaceExternalModels(registry);
+ linalg::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerBufferizableOpInterfaceExternalModels(registry);
+ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+ registry);
+ tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
+ vector::registerBufferizableOpInterfaceExternalModels(registry);
+
+ registry.addExtensions<
+ mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension,
+ transform_ext::StructuredTransformOpsExtension>();
+ iree_compiler::registerTransformDialectCommonExtension(registry);
+ iree_compiler::registerTransformDialectFlowExtension(registry);
+ iree_compiler::registerTransformDialectLLVMCPUExtension(registry);
+ iree_compiler::registerTransformDialectLLVMGPUExtension(registry);
+ affine::registerTransformDialectExtension(registry);
+ bufferization::registerTransformDialectExtension(registry);
+ gpu::registerTransformDialectExtension(registry);
+ linalg::registerTransformDialectExtension(registry);
+ memref::registerTransformDialectExtension(registry);
+ scf::registerTransformDialectExtension(registry);
+ tensor::registerTransformDialectExtension(registry);
+ vector::registerTransformDialectExtension(registry);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
new file mode 100644
index 0000000..da0f5eb
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
@@ -0,0 +1,226 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/PassDetail.h"
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Common/UserConfig.h"
+#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-codegen-materialize-library-calls"
+
+namespace mlir {
+namespace iree_compiler {
+
+llvm::cl::opt<std::string> clCodegenTransformDialectTestName(
+ "iree-codegen-use-transform-dialect-strategy",
+ llvm::cl::desc(
+ "Broadcasts the given transform dialect strategy specification to all"
+ "dispatches. Supports two modes; a path to the MLIR file containing a"
+ "transform dialect specification to apply, and a symbol reference to"
+ "load from a library of transform specs (@library_call)"),
+ llvm::cl::init(""));
+
+llvm::cl::opt<std::string> clCodegenTransformDialectLibraryFileName(
+ "iree-codegen-transform-dialect-library",
+ llvm::cl::desc(
+ "File path to a module containing a library of transform dialect"
+ "strategies"),
+ llvm::cl::init(""));
+
+namespace {
+
+static const char kTranslationInfoAttrName[] = "translation_info";
+
+static void createEmptyTransformStrategy(ModuleOp innerModule) {
+ Location loc = innerModule.getLoc();
+ OpBuilder b = OpBuilder::atBlockEnd(innerModule.getBody());
+ auto topLevelTransformModule = b.create<ModuleOp>(loc);
+ Region &topLevelTransformRegion = topLevelTransformModule.getBodyRegion();
+ b.setInsertionPointToStart(&topLevelTransformRegion.front());
+ auto anyOpType = transform::AnyOpType::get(b.getContext());
+
+ // Create the include for the named sequence with the expectation that the
+ // external definition will be linked in later.
+ auto sequence = b.create<transform::SequenceOp>(
+ loc, TypeRange{}, transform::FailurePropagationMode::Propagate, anyOpType,
+ [&](OpBuilder &b, Location loc, Value variantH) {
+ b.create<transform::PrintOp>(loc, variantH);
+ b.create<transform::YieldOp>(loc);
+ });
+ (void)sequence;
+}
+
+struct MaterializeUserConfigsPass
+ : public MaterializeUserConfigsBase<MaterializeUserConfigsPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registerTransformDialectTranslationDependentDialects(registry);
+ }
+
+ void runOnOperation() override {
+ IREE::HAL::ExecutableVariantOp variantOp = getOperation();
+ ModuleOp moduleOp = variantOp.getInnerModule();
+ llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOps =
+ getAllEntryPoints(moduleOp);
+ MLIRContext *context = moduleOp.getContext();
+
+ std::optional<ModuleOp> transformLibrary = std::nullopt;
+ if (!clCodegenTransformDialectLibraryFileName.empty()) {
+ auto dialect =
+ context->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
+ auto maybeTransformLibrary = dialect->getOrLoadTransformLibraryModule(
+ clCodegenTransformDialectLibraryFileName);
+ if (failed(maybeTransformLibrary)) {
+ return signalPassFailure();
+ }
+ transformLibrary = *maybeTransformLibrary;
+ }
+
+ IREE::Codegen::DispatchLoweringPassPipeline tdPipeline =
+ IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen;
+ std::optional<IREE::Codegen::TranslationInfoAttr> clTranslationInfo;
+ // Here we always set the pipeline strategy to transform dialect if the
+ // flag is non-empty to ensure we pick the right lowering pipeline in the
+ // event a file path is given.
+ if (!clCodegenTransformDialectTestName.empty()) {
+ clTranslationInfo = IREE::Codegen::TranslationInfoAttr::get(
+ context, tdPipeline,
+ /*softwarePipelineDepth=*/0,
+ /*softwarePipelineStoreStage=*/1,
+ /*codegenSpec=*/clCodegenTransformDialectTestName[0] == '@'
+ ? SymbolRefAttr::get(
+ context, llvm::StringRef(
+ clCodegenTransformDialectTestName.substr(1)))
+ : SymbolRefAttr());
+ }
+
+ std::optional<IREE::Codegen::TranslationInfoAttr> translationInfo;
+ for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+ auto exportOp = exportOps.lookup(funcOp.getName());
+ if (!exportOp) {
+ continue;
+ }
+
+ /// First, apply all user configs.
+ auto res = funcOp.walk([&](Operation *op) {
+ if (auto compilationInfo = getCompilationInfo(op)) {
+ if (failed(setUserConfig(funcOp, op, compilationInfo))) {
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ });
+
+ if (res.wasInterrupted()) {
+ moduleOp.emitOpError("error in setting user configuration");
+ return signalPassFailure();
+ }
+
+ /// Let user configs take priority over the global strategy flag.
+ if (IREE::Codegen::TranslationInfoAttr exportedTranslationInfo =
+ getTranslationInfo(exportOp)) {
+ if (translationInfo) {
+ /// Currently codegen is rooted on the variant, meaning every entry
+ /// must go through the same codegen pipeline. For multi-targeting we
+ /// will want to have multiple functions per variant, as well as
+ /// multple exports per variant, meaning eventually the nesting of
+ /// the translation pipeline will need to change to the function, or
+ /// we'll need another level of module op nesting.
+ if (exportedTranslationInfo != translationInfo.value()) {
+ moduleOp.emitOpError(
+ "unhandled compilation of entry point functions with different "
+ "translation info");
+ return signalPassFailure();
+ }
+ } else {
+ translationInfo = exportedTranslationInfo;
+ }
+ } else {
+ if (translationInfo && translationInfo != clTranslationInfo) {
+ moduleOp.emitOpError(
+ "unhandled compilation of entry point functions with translation "
+ "info optionality");
+ return signalPassFailure();
+ }
+ if (clTranslationInfo) {
+ translationInfo = clTranslationInfo;
+ if (failed(setTranslationInfo(funcOp, translationInfo.value()))) {
+ moduleOp.emitOpError("failed to set command line translation info");
+ return signalPassFailure();
+ }
+ }
+ }
+ }
+
+ /// We only need to resolve symbols for transform dialect based strategies.
+ if (!translationInfo ||
+ translationInfo.value().getDispatchLoweringPassPipeline() !=
+ tdPipeline) {
+ return;
+ }
+
+ std::optional<SymbolRefAttr> libraryFunc =
+ translationInfo.value().getCodegenSpec();
+ if (!libraryFunc || *libraryFunc == SymbolRefAttr()) {
+ return;
+ }
+
+ /// If we have a symbol, verify the existence of the symbol within the
+ /// transform library.
+ if (!transformLibrary || !(*transformLibrary) ||
+ !transform::detail::findTransformEntryPoint(
+ variantOp, *transformLibrary, libraryFunc->getLeafReference())) {
+ moduleOp.emitOpError("failed to find transform strategy symbol");
+ return signalPassFailure();
+ }
+
+ // TODO: At this point we could allow the user to (optionally) return a
+ // translation info attribute to use, however there currently isn't a way
+ // upstream to retrieve the results of the named sequence.
+
+ /// Attempt to execute the strategy. symbol (from the flag or otherwise) at
+ /// the same time. Because the strategy is rooted on the variant op, the
+ /// strategy can change the translation info on the exports if needed, else
+ /// back to default IREE codegen.
+ if (failed(transform::applyTransformNamedSequence(
+ variantOp, *transformLibrary, options.enableExpensiveChecks(true),
+ libraryFunc->getLeafReference()))) {
+ return signalPassFailure();
+ }
+
+ // Re-retrieve the export ops and mark all exports with unchanged
+ // translation info as un-translated.
+ // TODO: Currently this is the only way to "fall back" to codegen. If the
+ // user wants to do all of codegen themselves they can set a `None`
+ // pipeline.
+ exportOps = getAllEntryPoints(variantOp.getInnerModule());
+ for (auto &it : exportOps) {
+ auto exportOp = it.second;
+ if (getTranslationInfo(exportOp) == translationInfo) {
+ exportOp->removeAttr(kTranslationInfoAttrName);
+ }
+ }
+ }
+
+private:
+ /// Transform interpreter options.
+ transform::TransformOptions options;
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
+createMaterializeUserConfigsPass() {
+ return std::make_unique<MaterializeUserConfigsPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/PassDetail.h b/compiler/src/iree/compiler/Codegen/Common/PassDetail.h
index 72e1466..4b45a23 100644
--- a/compiler/src/iree/compiler/Codegen/Common/PassDetail.h
+++ b/compiler/src/iree/compiler/Codegen/Common/PassDetail.h
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.cpp b/compiler/src/iree/compiler/Codegen/Common/Passes.cpp
index 6254af0..f8e9571 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.cpp
@@ -12,11 +12,13 @@
namespace iree_compiler {
void addCommonTargetExecutablePreprocessingPasses(OpPassManager &passManager) {
- passManager.addNestedPass<func::FuncOp>(createTypePropagationPass());
- passManager.addPass(createBubbleUpOrdinalOpsPass());
- passManager.addPass(createBufferizeCopyOnlyDispatchesPass());
- passManager.addNestedPass<func::FuncOp>(
+ OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
+ nestedModulePM.addNestedPass<func::FuncOp>(createTypePropagationPass());
+ nestedModulePM.addPass(createBubbleUpOrdinalOpsPass());
+ nestedModulePM.addPass(createBufferizeCopyOnlyDispatchesPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(
IREE::LinalgExt::createDecomposeSoftmaxPass());
+ passManager.addPass(createMaterializeUserConfigsPass());
}
//===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index 47c7bde..eedd9e3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -22,6 +22,11 @@
namespace mlir {
namespace iree_compiler {
+/// Function to register all dependent dialects for Transform Dialect based
+/// passes.
+void registerTransformDialectTranslationDependentDialects(
+ DialectRegistry ®istry);
+
/// Passes that are done on all backends before target-specific code-generation
/// kicks in.
void addCommonTargetExecutablePreprocessingPasses(OpPassManager &passManager);
@@ -192,6 +197,10 @@
/// Creates a pass to convert memref.copy to linalg op.
std::unique_ptr<OperationPass<func::FuncOp>> createMemrefCopyToLinalgPass();
+/// Extracts lowering configs and translation info from user configs.
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
+createMaterializeUserConfigsPass();
+
/// Pass to optimize vector transfer_read and transfer_write.
std::unique_ptr<OperationPass<func::FuncOp>>
createOptimizeVectorTransferPass(bool flatten = false,
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 40edf1c..9abfe85 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -331,6 +331,15 @@
let constructor = "mlir::iree_compiler::createLowerUKernelOpsToCallsPass()";
}
+def MaterializeUserConfigs :
+ Pass<"iree-codegen-materialize-user-configs", "IREE::HAL::ExecutableVariantOp"> {
+ let summary = "Sets the lowering configs and translation info from user configs";
+ let constructor = "mlir::iree_compiler::createMaterializeUserConfigsPass()";
+ let dependentDialects = [
+ "transform::TransformDialect"
+ ];
+}
+
def MemrefCopyToLinalgPass :
Pass<"iree-codegen-memrefcopy-to-linalg", "func::FuncOp"> {
let summary = "Convert memref.copy to linalg op";
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
index dea6786..1b95c7d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
@@ -4,45 +4,11 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
-#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree/compiler/Codegen/Common/PassDetail.h"
-#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
-#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h"
-#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
-#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
-#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
-#include "mlir/Dialect/PDL/IR/PDL.h"
-#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
-#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
-#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "iree/compiler/Codegen/Common/Passes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
-#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
@@ -59,58 +25,12 @@
iree_compiler::TransformDialectInterpreterBase> {
public:
void getDependentDialects(DialectRegistry ®istry) const override {
- // TODO: this is only necessary to make registry subset happy when running
- // the lowering to LLVM. The lowering should be changed to stop using the
- // nested pass manager and this will go away.
-
- // clang-format off
- registry.insert<mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
- mlir::iree_compiler::IREE::Flow::FlowDialect,
- mlir::iree_compiler::IREE::Codegen::IREECodegenDialect,
- arith::ArithDialect,
- affine::AffineDialect,
- bufferization::BufferizationDialect,
- func::FuncDialect,
- gpu::GPUDialect,
- linalg::LinalgDialect,
- LLVM::LLVMDialect,
- pdl::PDLDialect,
- pdl_interp::PDLInterpDialect,
- scf::SCFDialect,
- tensor::TensorDialect,
- transform::TransformDialect,
- vector::VectorDialect
- // clang-format on
- >();
-
- // TODO: these should be registered by the extension instead, but there is
- // no support for it in core currently.
- arith::registerBufferizableOpInterfaceExternalModels(registry);
- linalg::registerBufferizableOpInterfaceExternalModels(registry);
- scf::registerBufferizableOpInterfaceExternalModels(registry);
- bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+ mlir::iree_compiler::registerTransformDialectTranslationDependentDialects(
registry);
- tensor::registerBufferizableOpInterfaceExternalModels(registry);
- tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
- vector::registerBufferizableOpInterfaceExternalModels(registry);
-
- registry.addExtensions<
- mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension,
- transform_ext::StructuredTransformOpsExtension>();
- iree_compiler::registerTransformDialectCommonExtension(registry);
- iree_compiler::registerTransformDialectFlowExtension(registry);
- iree_compiler::registerTransformDialectLLVMCPUExtension(registry);
- iree_compiler::registerTransformDialectLLVMGPUExtension(registry);
- affine::registerTransformDialectExtension(registry);
- bufferization::registerTransformDialectExtension(registry);
- gpu::registerTransformDialectExtension(registry);
- linalg::registerTransformDialectExtension(registry);
- memref::registerTransformDialectExtension(registry);
- scf::registerTransformDialectExtension(registry);
- tensor::registerTransformDialectExtension(registry);
- vector::registerTransformDialectExtension(registry);
}
+ // We don't register libraries here because we expect them to be pre-loaded
+ // much earlier on in the compiler pipeline.
TransformDialectInterpreterPass(
StringRef transformFileName = StringRef(),
StringRef debugPayloadRootTag = StringRef(),
@@ -126,13 +46,36 @@
namespace mlir {
namespace iree_compiler {
+
+extern llvm::cl::opt<std::string> clCodegenTransformDialectTestName;
+static llvm::cl::opt<std::string> clCodegenTransformDialectDebugPayloadTag(
+ "iree-codegen-transform-dialect-debug-payload-tag",
+ llvm::cl::desc("tag attribute value for the transform dialect interpreter "
+ "payload root operation"),
+ llvm::cl::init(""));
+static llvm::cl::opt<std::string> clCodegenTransformDialectDebugTransformTag(
+ "iree-codegen-transform-dialect-debug-transform-tag",
+ llvm::cl::desc(
+ "tag attribute value for the transform dialect transform op container"),
+ llvm::cl::init(""));
+
/// Create a Transform dialect interpreter pass.
std::unique_ptr<Pass>
createTransformDialectInterpreterPass(llvm::StringRef transformFileName,
llvm::StringRef debugPayloadRootTag,
llvm::StringRef debugTransformRootTag) {
+ // If the strategy filename is prefixed with `@`, it refers to a library
+ // call.
+ std::string clFileName = !clCodegenTransformDialectTestName.empty() &&
+ clCodegenTransformDialectTestName[0] != '@'
+ ? clCodegenTransformDialectTestName
+ : std::string();
return std::make_unique<TransformDialectInterpreterPass>(
- transformFileName, debugPayloadRootTag, debugTransformRootTag);
+ transformFileName.empty() ? clFileName : transformFileName,
+ debugPayloadRootTag.empty() ? clCodegenTransformDialectDebugPayloadTag
+ : debugPayloadRootTag,
+ debugTransformRootTag.empty() ? clCodegenTransformDialectDebugTransformTag
+ : debugTransformRootTag);
}
} // namespace iree_compiler
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/BUILD.bazel
index 0d43fae..2cf1239 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/BUILD.bazel
@@ -44,6 +44,7 @@
srcs = [
"IREECodegenAttrs.cpp",
"IREECodegenDialect.cpp",
+ "IREECodegenLibraryManager.cpp",
"IREECodegenOps.cpp",
"UKernelOps.cpp",
],
@@ -84,6 +85,8 @@
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformDialectTransforms",
],
)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/CMakeLists.txt
index 4f15431..b1f58e9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CMakeLists.txt
@@ -32,6 +32,7 @@
SRCS
"IREECodegenAttrs.cpp"
"IREECodegenDialect.cpp"
+ "IREECodegenLibraryManager.cpp"
"IREECodegenOps.cpp"
"UKernelOps.cpp"
DEPS
@@ -48,6 +49,8 @@
MLIRMemRefDialect
MLIRParser
MLIRSupport
+ MLIRTransformDialect
+ MLIRTransformDialectTransforms
iree::builtins::ukernel::exported_bits
iree::compiler::Codegen::Interfaces::UKernelOpInterface
iree::compiler::Codegen::Utils
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.cpp
index 602a2be..2ae0317 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.cpp
@@ -10,6 +10,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/DialectImplementation.h"
#define GET_ATTRDEF_CLASSES
@@ -59,11 +60,12 @@
TranslationInfoAttr TranslationInfoAttr::get(
MLIRContext *context, DispatchLoweringPassPipeline passPipeline,
- unsigned softwarePipelineDepth, unsigned softwarePipelineStoreStage) {
+ unsigned softwarePipelineDepth, unsigned softwarePipelineStoreStage,
+ SymbolRefAttr codegenSpec) {
auto pipelineAttr =
DispatchLoweringPassPipelineAttr::get(context, passPipeline);
return get(context, pipelineAttr, softwarePipelineDepth,
- softwarePipelineStoreStage);
+ softwarePipelineStoreStage, codegenSpec);
}
DispatchLoweringPassPipeline
@@ -74,7 +76,8 @@
LogicalResult TranslationInfoAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
IREE::Codegen::DispatchLoweringPassPipelineAttr passPipeline,
- unsigned softwarePipelineDepth, unsigned softwarePipelineStoreStage) {
+ unsigned softwarePipelineDepth, unsigned softwarePipelineStoreStage,
+ SymbolRefAttr codegenSpec) {
if (!passPipeline) {
return emitError() << "missing pass pipeline specification";
}
@@ -83,6 +86,13 @@
return emitError() << "invalid pass pipeline value : "
<< stringifyEnum(passPipeline.getValue());
}
+ auto tdPassPipeline =
+ IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen;
+ if (codegenSpec && passPipelineValue != tdPassPipeline) {
+ return emitError()
+ << "transform dialect codegen spec requires pass pipeline : "
+ << stringifyEnum(tdPassPipeline);
+ }
return success();
}
@@ -291,7 +301,8 @@
if (failed(TranslationInfoAttr::verify(
emitError, translationInfo.getPassPipeline(),
translationInfo.getSoftwarePipelineDepth(),
- translationInfo.getSoftwarePipelineStoreStage()))) {
+ translationInfo.getSoftwarePipelineStoreStage(),
+ translationInfo.getCodegenSpec()))) {
return failure();
}
return success();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
index edb49df..40a722d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
@@ -132,7 +132,8 @@
let assemblyFormat = [{
`<` `` $passPipeline
(`pipeline_depth` `=` $softwarePipelineDepth^)?
- (`store_stage` `=` $softwarePipelineStoreStage^)? `>`
+ (`store_stage` `=` $softwarePipelineStoreStage^)?
+ (`codegen_spec` `=` $codegenSpec^)? `>`
}];
let parameters = (ins
@@ -141,12 +142,15 @@
OptionalParameter<"unsigned",
"The software pipeline depth to be used">:$softwarePipelineDepth,
DefaultValuedParameter<"unsigned", "1",
- "The software pipeline stage to place stores">:$softwarePipelineStoreStage
+ "The software pipeline stage to place stores">:$softwarePipelineStoreStage,
+ OptionalParameter<"SymbolRefAttr",
+ "The symbol pointing to the transform dialect codegen spec to be used">:$codegenSpec
);
let builders = [
AttrBuilder<(ins "DispatchLoweringPassPipeline":$passPipeline,
CArg<"unsigned", "0">:$softwarePipelineDepth,
- CArg<"unsigned", "1">:$softwarePipelineStoreStage)>
+ CArg<"unsigned", "1">:$softwarePipelineStoreStage,
+ CArg<"SymbolRefAttr", "{}">:$codegenSpec)>
];
let extraClassDeclaration = [{
// Returns the lowering pass pipeline set.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenDialect.h b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenDialect.h
index c0f7db1..abeac9a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenDialect.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenDialect.h
@@ -7,8 +7,15 @@
#ifndef IREE_COMPILER_CODEGEN_DIALECT_IREECODEGEN_DIALECT_H_
#define IREE_COMPILER_CODEGEN_DIALECT_IREECODEGEN_DIALECT_H_
+#include <mutex>
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringMap.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
// clang-format off: must be included after all LLVM/MLIR eaders
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h.inc" // IWYU pragma: keep
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenDialect.td b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenDialect.td
index b04bb39..1977f0a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenDialect.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenDialect.td
@@ -35,6 +35,22 @@
}];
let extraClassDeclaration = [{
void initializeCodegenAttrs();
+
+ FailureOr<::mlir::ModuleOp>
+ getOrLoadTransformLibraryModule(std::string libraryPath);
+
+ private:
+
+ /// Map containing modules containing symbols, e.g. named sequences, that
+ /// will be executed by the interpreter when used. This is a reflection of the
+ /// library module storage upstream on the transform dialect, but instead we
+ /// manage it here to ensure all required dialects are registered, and so that
+ /// we can handle the loading/caching ourselves.
+ ::llvm::StringMap<::mlir::OwningOpRef<::mlir::ModuleOp>> libraryModules;
+
+ /// Lock to control the updating of the library modules such that we only load
+ /// the module once and can reuse it across all invocations.
+ std::mutex libraryMutex;
}];
let useDefaultAttributePrinterParser = 1;
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenLibraryManager.cpp b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenLibraryManager.cpp
new file mode 100644
index 0000000..c691700
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenLibraryManager.cpp
@@ -0,0 +1,47 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Codegen {
+
+FailureOr<ModuleOp>
+IREECodegenDialect::getOrLoadTransformLibraryModule(std::string libraryPath) {
+ // Acquire a lock on the map that will release once out of scope.
+ std::lock_guard<std::mutex> guard(libraryMutex);
+
+ auto loadedLibrary = libraryModules.find(libraryPath);
+ if (loadedLibrary != libraryModules.end()) {
+ // Check whether the library already failed to load.
+ if (!(loadedLibrary->second) || !(*(loadedLibrary->second))) {
+ return failure();
+ }
+ return *(loadedLibrary->second);
+ }
+
+ OwningOpRef<ModuleOp> mergedParsedLibraries;
+ if (failed(transform::detail::assembleTransformLibraryFromPaths(
+ getContext(), SmallVector<std::string>{libraryPath},
+ mergedParsedLibraries))) {
+ // We update the storage for the library regardless of whether parsing
+ // succeeds so that other threads don't have to retry.
+ OwningOpRef<ModuleOp> emptyLibrary;
+ libraryModules[libraryPath] = std::move(emptyLibrary);
+ return failure();
+ }
+
+ libraryModules[libraryPath] = std::move(mergedParsedLibraries);
+ return *libraryModules[libraryPath];
+}
+
+} // namespace Codegen
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 4a708eb..a48b470 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -10,7 +10,6 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Codegen/Common/TileSizeSelection.h"
-#include "iree/compiler/Codegen/Common/UserConfig.h"
#include "iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h"
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
#include "iree/compiler/Codegen/TransformStrategies/CPU/Common.h"
@@ -89,26 +88,10 @@
llvm::cl::init(true));
// Non-static options are used in other places.
-llvm::cl::opt<std::string> clCPUCodegenTransformDialectFileName(
- "iree-codegen-llvmcpu-use-transform-dialect",
- llvm::cl::desc(
- "MLIR file containing a transform dialect specification to apply"),
- llvm::cl::init(""));
llvm::cl::opt<bool> clCPUEnableTransformDialectJit(
"iree-codegen-llvmcpu-enable-transform-dialect-jit",
llvm::cl::desc("enable the usage of the transform dialect JIT"),
llvm::cl::init(false));
-llvm::cl::opt<std::string> clCPUCodegenTransformDialectDebugPayloadTag(
- "iree-codegen-llvmcpu-transform-dialect-debug-payload-tag",
- llvm::cl::desc("tag attribute value for the transform dialect interpreter "
- "payload root operation"),
- llvm::cl::init(""));
-
-llvm::cl::opt<std::string> clCPUCodegenTransformDialectDebugTransformTag(
- "iree-codegen-llvmcpu-transform-dialect-debug-transform-tag",
- llvm::cl::desc(
- "tag attribute value for the transform dialect transform op container"),
- llvm::cl::init(""));
using IREE::Codegen::DispatchLoweringPassPipeline;
@@ -2391,15 +2374,6 @@
static LogicalResult
setTranslationInfoAndRootConfig(func::FuncOp entryPointFn,
ArrayRef<Operation *> computeOps) {
- // First check if the operations have a preset pipeline. If the config is
- // preset, do not overwrite it.
- for (auto computeOp : computeOps) {
- if (IREE::Codegen::CompilationInfoAttr compilationInfo =
- getCompilationInfo(computeOp)) {
- return setUserConfig(entryPointFn, computeOp, compilationInfo);
- }
- }
-
// Make sure that lowering_config is not preset on any compute ops.
for (auto computeOp : computeOps) {
if (getLoweringConfig(computeOp))
@@ -2459,18 +2433,6 @@
if (getTranslationInfo(exportOp))
continue;
- // If using the transform dialect with a script file, intercept early.
- if (!clCPUCodegenTransformDialectFileName.empty()) {
- assert(!clCPUEnableTransformDialectJit &&
- "Can't use both transform dialect interpreted and jitted modes");
- auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
- moduleOp.getContext(),
- IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen);
- if (failed(setTranslationInfo(funcOp, translationInfo)))
- return failure();
- continue;
- }
-
// For now pick the default for functions with control flow, cause
// the currently built pipelines dont work so well with control flow.
if (funcOp.getBody().empty() || !llvm::hasSingleElement(funcOp.getBody())) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index f092a47..5007f9a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -244,8 +244,10 @@
hasSMEFeature(target);
if (!testLoweringConfiguration) {
switch (translationInfo.value().getDispatchLoweringPassPipeline()) {
- case IREE::Codegen::DispatchLoweringPassPipeline::CPUDefault:
+ // No pipleline specified, nothing to do.
case IREE::Codegen::DispatchLoweringPassPipeline::None:
+ return;
+ case IREE::Codegen::DispatchLoweringPassPipeline::CPUDefault:
addCPUDefaultPassPipeline(executableLoweringPipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index b563dba..2efcd8d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -86,14 +86,6 @@
"instrumentation is enabled."),
llvm::cl::init(false)};
-// MLIR file containing a top-level module that specifies the transformations to
-// apply to form dispatch regions.
-// Defined externally in KernelDispatch.cpp to control the codegen pass
-// pipeline.
-extern llvm::cl::opt<std::string> clCPUCodegenTransformDialectFileName;
-extern llvm::cl::opt<std::string> clCPUCodegenTransformDialectDebugPayloadTag;
-extern llvm::cl::opt<std::string> clCPUCodegenTransformDialectDebugTransformTag;
-
//===---------------------------------------------------------------------===//
// Default allocation functions for CPU backend
//===---------------------------------------------------------------------===//
@@ -670,10 +662,7 @@
void addTransformDialectPasses(OpPassManager &passManager) {
// Give control to the transform dialect.
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass(
- clCPUCodegenTransformDialectFileName,
- clCPUCodegenTransformDialectDebugPayloadTag,
- clCPUCodegenTransformDialectDebugTransformTag));
+ mlir::iree_compiler::createTransformDialectInterpreterPass());
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
// schedule once applied.
@@ -764,8 +753,8 @@
void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager) {
{
+ addCommonTargetExecutablePreprocessingPasses(passManager);
OpPassManager &modulePassManager = passManager.nest<ModuleOp>();
- addCommonTargetExecutablePreprocessingPasses(modulePassManager);
modulePassManager.addNestedPass<func::FuncOp>(
createRematerializeParallelOpsPass());
// TODO(#13888): This(createExpandF16OpToF32Pass()) pass is being added way
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
index deac2a0..a064023 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target{test-lowering-configuration=true})))' --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmcpu-lower-executable-target{test-lowering-configuration=true})))' --split-input-file %s | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/peel_and_vectorize.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/peel_and_vectorize.mlir
index 6722802..d8bc363 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/peel_and_vectorize.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/peel_and_vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target)))' -split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmcpu-lower-executable-target)))' -split-input-file %s | FileCheck %s
// Test peeling + vectorization using CPUDoubleTilingPeelingExpert.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
index 90285a0..8eeacdb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target)))' --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmcpu-lower-executable-target)))' --split-input-file %s | FileCheck %s
// Check that this dispatch compiles to vectors and that there are no allocas.
// By proxy checks that destination passing style kicked in correctly
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 33dc724..c3b8d2c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -9,7 +9,6 @@
#include <numeric>
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Codegen/Common/UserConfig.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h"
@@ -33,29 +32,12 @@
static constexpr StringLiteral kRocmTarget = "rocm";
namespace mlir {
namespace iree_compiler {
-llvm::cl::opt<std::string> clGPUCodegenTransformDialectFileName(
- "iree-codegen-llvmgpu-use-transform-dialect",
- llvm::cl::desc(
- "MLIR file containing a transform dialect specification to apply"),
- llvm::cl::init(""));
llvm::cl::opt<bool> clGPUEnableTransformDialectJit(
"iree-codegen-llvmgpu-enable-transform-dialect-jit",
llvm::cl::desc("enable the usage of the transform dialect JIT"),
llvm::cl::init(true));
-llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugPayloadTag(
- "iree-codegen-llvmgpu-transform-dialect-debug-payload-tag",
- llvm::cl::desc("tag attribute value for the transform dialect interpreter "
- "payload root operation"),
- llvm::cl::init(""));
-
-llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugTransformTag(
- "iree-codegen-llvmgpu-transform-dialect-debug-transform-tag",
- llvm::cl::desc(
- "tag attribute value for the transform dialect transform op container"),
- llvm::cl::init(""));
-
/// Flag to force using WMMA tensorcore operations.
llvm::cl::opt<bool>
clGPUUseWMMA("iree-codegen-llvmgpu-use-wmma",
@@ -345,7 +327,8 @@
std::move(workgroupTileSizes)); // Workgroup level.
return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, tileSizes, pipeline, workgroupSize,
- /*subgroupSize=*/std::nullopt, softwarePipelineDepth);
+ /*subgroupSize=*/std::nullopt, softwarePipelineDepth,
+ /*softwarePipelineStoreStage=*/1);
};
// Infer the MxN size of the matmul based on operands and indexing maps.
auto lhsShape =
@@ -717,29 +700,17 @@
return std::nullopt;
}
-/// Set configuration for reduction transform dialect based strategy.
+/// Set configuration for transform dialect based strategies.
static LogicalResult setTransformDialectConfig(func::FuncOp entryPoint,
Operation *op,
const TargetInfo &targetInfo) {
- if (!clGPUCodegenTransformDialectFileName.empty() &&
- clGPUEnableTransformDialectJit) {
- return entryPoint.emitError()
- << "option clash in transform dialect lowering config: the filename "
- "cannot be provided when the jit option is set";
- }
-
- if (!clGPUEnableTransformDialectJit &&
- clGPUCodegenTransformDialectFileName.empty()) {
+ if (!clGPUEnableTransformDialectJit) {
return failure();
}
- // Transform script file provided, use it.
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
entryPoint.getContext(),
IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen);
- if (!clGPUCodegenTransformDialectFileName.empty()) {
- return setTranslationInfo(entryPoint, translationInfo);
- }
// TODO: unify the target informations into one structure.
iree_compiler::gpu::GPUModel gpuModel;
@@ -1168,12 +1139,6 @@
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
Operation *computeOp) {
TargetInfo targetInfo = getTargetInfo(entryPointFn);
- if (IREE::Codegen::CompilationInfoAttr compilationInfo =
- getCompilationInfo(computeOp)) {
- // If the op already has a lowering config coming from the IR use this and
- // bypass the heuristic.
- return setUserConfig(entryPointFn, computeOp, compilationInfo);
- }
// First try to see if there is a transform dialect configuration existing.
if (succeeded(
setTransformDialectConfig(entryPointFn, computeOp, targetInfo))) {
@@ -1195,17 +1160,6 @@
}
}
- // If using the transform dialect, call the proper pipeline.
- assert((clGPUCodegenTransformDialectFileName.empty() ||
- !clGPUEnableTransformDialectJit) &&
- "Can't use both transform dialect interpreted and jitted modes");
- if (clGPUCodegenTransformDialectFileName.size() > 0) {
- auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
- entryPointFn.getContext(),
- IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen);
- return setTranslationInfo(entryPointFn, translationInfo);
- }
-
if (auto fftOp = dyn_cast<IREE::LinalgExt::FftOp>(computeOp)) {
return setFftConfig(entryPointFn, fftOp);
}
@@ -1222,6 +1176,24 @@
return setRootDefaultConfig(entryPointFn, computeOp);
}
+// Propogate the configuration to the other ops.
+// TODO(ravishankarm, thomasraoux): This is a very specific use (and
+// fragile). In general, this should not be needed. Things are already tiled
+// and distributed. The rest of the compilation must be structured to either
+// use `TileAndFuse` or they are independent configurations that are
+// determined based on the op.
+static void propagateLoweringConfig(Operation *rootOperation,
+ SmallVector<Operation *> computeOps) {
+ if (IREE::Codegen::LoweringConfigAttr config =
+ getLoweringConfig(rootOperation)) {
+ for (auto op : computeOps) {
+ if (op == rootOperation)
+ continue;
+ setLoweringConfig(op, config);
+ }
+ }
+}
+
namespace mlir {
namespace iree_compiler {
@@ -1233,10 +1205,20 @@
auto exportOp = exportOps.lookup(funcOp.getName());
if (!exportOp)
continue;
- if (getTranslationInfo(exportOp))
- continue;
SmallVector<Operation *> computeOps = getComputeOps(funcOp);
+ if (getTranslationInfo(exportOp)) {
+ // Currently LLVMGPU requires propagation of user lowering configs.
+ for (auto op : computeOps) {
+ if (getLoweringConfig(op)) {
+ propagateLoweringConfig(op, computeOps);
+ break;
+ }
+ }
+ continue;
+ }
+
Operation *rootOperation = nullptr;
+
// Find the root operation. linalg.generic and linalg.fill are not root
// operations if there are other compute operations present.
for (Operation *op : llvm::reverse(computeOps)) {
@@ -1270,20 +1252,7 @@
if (failed(setRootConfig(funcOp, rootOperation)))
continue;
- // Propogate the configuration to the other ops.
- // TODO(ravishankarm, thomasraoux): This is a very specific use (and
- // fragile). In general, this should not be needed. Things are already tiled
- // and distributed. The rest of the compilation must be structured to either
- // use `TileAndFuse` or they are independent configurations that are
- // determined based on the op.
- if (IREE::Codegen::LoweringConfigAttr config =
- getLoweringConfig(rootOperation)) {
- for (auto op : computeOps) {
- if (op == rootOperation)
- continue;
- setLoweringConfig(op, config);
- }
- }
+ propagateLoweringConfig(rootOperation, computeOps);
}
return success();
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index ad59a5a..3437524 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -187,6 +187,9 @@
case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen:
addGPUTransformDialectPasses(executableLoweringPipeline);
break;
+ // no pipeline specified, nothing to do.
+ case IREE::Codegen::DispatchLoweringPassPipeline::None:
+ return;
default:
variantOp.emitOpError("Unsupported pipeline on GPU target.");
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 6cfa3bc..048c2fd 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -579,17 +579,12 @@
}
}
-extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectFileName;
extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugPayloadTag;
extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugTransformTag;
void addGPUTransformDialectPasses(OpPassManager &passManager) {
- // Give control to the transform dialect.
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass(
- clGPUCodegenTransformDialectFileName,
- clGPUCodegenTransformDialectDebugPayloadTag,
- clGPUCodegenTransformDialectDebugTransformTag));
+ mlir::iree_compiler::createTransformDialectInterpreterPass());
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
@@ -599,7 +594,7 @@
}
void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) {
- addCommonTargetExecutablePreprocessingPasses(pm.nest<ModuleOp>());
+ addCommonTargetExecutablePreprocessingPasses(pm);
pm.addPass(createLLVMGPULowerExecutableTargetPass());
OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
//===--------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index 73d24ba..19ceac5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-opt %s --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: iree-opt %s --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%s | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%s | \
// RUN: FileCheck --check-prefix=CHECK %s
hal.executable @_attention_dispatch_0 {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index f7efff8..c6b7399 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" \
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false %s | FileCheck %s
// Transform dialect attributes are tested separately.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
index 7acd071..738c009 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
@@ -1,11 +1,11 @@
-// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))" \
+// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))" \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/transform_dialect_codegen_bufferize_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/transform_dialect_codegen_bufferize_spec.mlir | \
// RUN: FileCheck %s
-// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))" \
+// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))" \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/transform_dialect_codegen_foreach_to_gpu_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/transform_dialect_codegen_foreach_to_gpu_spec.mlir | \
// RUN: FileCheck %s --check-prefix=FOREACH-TO-GPU
#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}>]}>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index c02f09c..afb799b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -7,7 +7,6 @@
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Codegen/Common/UserConfig.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h"
@@ -44,12 +43,6 @@
namespace mlir {
namespace iree_compiler {
-llvm::cl::opt<std::string> clSPIRVTransformDialectFileName(
- "iree-spirv-use-transform-dialect",
- llvm::cl::desc(
- "MLIR file containing a transform dialect specification to apply"),
- llvm::cl::init(""));
-
llvm::cl::opt<bool> clSPIRVEnableTransformDialectJit(
"iree-spirv-enable-transform-dialect-jit",
llvm::cl::desc("Enable the usage of the transform dialect JIT"),
@@ -1617,8 +1610,7 @@
static LogicalResult
setTransformDialectConfig(func::FuncOp entryPoint, Operation *op,
const spirv::TargetEnv &targetEnv) {
- if (!clSPIRVEnableTransformDialectJit &&
- clSPIRVTransformDialectFileName.empty()) {
+ if (!clSPIRVEnableTransformDialectJit) {
return failure();
}
@@ -1626,12 +1618,6 @@
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
context, CodeGenPipeline::TransformDialectCodegen);
- // Prefer a transform script file if provided.
- if (!clSPIRVTransformDialectFileName.empty()) {
- LLVM_DEBUG(llvm::dbgs() << "using user specified transform dialect...\n");
- return setTranslationInfo(entryPoint, translationInfo);
- }
-
spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
// TODO: unify the target information into one structure.
@@ -1673,13 +1659,6 @@
static LogicalResult setSPIRVOpConfig(const spirv::TargetEnv &targetEnv,
func::FuncOp entryPointFn,
Operation *rootOp) {
- if (IREE::Codegen::CompilationInfoAttr compilationInfo =
- getCompilationInfo(rootOp)) {
- // If the op already has a lowering configuration specified from the
- // original source by the user, then use it directly.
- return setUserConfig(entryPointFn, rootOp, compilationInfo);
- }
-
// First try to see if there is a matching transform dialect configuration.
if (succeeded(setTransformDialectConfig(entryPointFn, rootOp, targetEnv))) {
return success();
@@ -1849,6 +1828,8 @@
auto exportOp = exportOps.lookup(funcOp.getName());
if (!exportOp)
continue;
+ if (getTranslationInfo(exportOp))
+ continue;
if (failed(setConfigForKernel(targetEnv, exportOp, funcOp))) {
return failure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index d23fa08..86e82fc 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -263,13 +263,9 @@
spirvPM.addPass(spirv::createSPIRVUpdateVCEPass());
}
-extern llvm::cl::opt<std::string> clSPIRVTransformDialectFileName;
-
void addSPIRVTransformDialectPasses(OpPassManager &passManager) {
- // Give control to the transform dialect.
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass(
- clSPIRVTransformDialectFileName));
+ mlir::iree_compiler::createTransformDialectInterpreterPass());
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
@@ -659,7 +655,7 @@
//===----------------------------------------------------------------------===//
void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath) {
- addCommonTargetExecutablePreprocessingPasses(pm.nest<ModuleOp>());
+ addCommonTargetExecutablePreprocessingPasses(pm);
auto &nestedModulePM = pm.nest<ModuleOp>();
nestedModulePM.addNestedPass<func::FuncOp>(
createSPIRVGeneralizeNamedOpsPass());
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 2e9f5a3..1613b2d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -194,6 +194,9 @@
case CodeGenPipeline::TransformDialectCodegen:
addSPIRVTransformDialectPassPipeline(pipeline);
break;
+ // No pipeline specified, nothing to do.
+ case CodeGenPipeline::None:
+ return;
default:
variantOp.emitOpError("Unsupported pipeline on GPU target.");
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index 24672b7..f6c559f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -57,7 +57,6 @@
"pipeline_reduction_subgroup.mlir",
"pipeline_sub_byte_dequant.mlir",
"set_transform_strategy.mlir",
- "set_transform_strategy_from_file.mlir",
"tile_and_distribute.mlir",
"tile_and_distribute_scatter.mlir",
"tile_and_distribute_sort.mlir",
@@ -82,9 +81,6 @@
],
),
cfg = "//compiler:lit.cfg.py",
- data = [
- "transform_dialect_dummy_spec.mlir",
- ],
tools = [
"//tools:iree-opt",
"@llvm-project//llvm:FileCheck",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 5f53e55..dd267bf 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -53,7 +53,6 @@
"pipeline_reduction_subgroup.mlir"
"pipeline_sub_byte_dequant.mlir"
"set_transform_strategy.mlir"
- "set_transform_strategy_from_file.mlir"
"tile_and_distribute.mlir"
"tile_and_distribute_scatter.mlir"
"tile_and_distribute_sort.mlir"
@@ -74,8 +73,6 @@
TOOLS
FileCheck
iree-opt
- DATA
- transform_dialect_dummy_spec.mlir
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
index 8e262ff..045bc3e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass{test-lowering-configuration=true})))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-spirv-lower-executable-target-pass{test-lowering-configuration=true})))' %s | FileCheck %s
#compilation = #iree_codegen.compilation_info<
lowering_config = <tile_sizes = [[128, 256], [16, 16]]>,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
index 243ec0e..770adde 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass{test-lowering-configuration=true})))' \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-spirv-lower-executable-target-pass{test-lowering-configuration=true})))' \
// RUN: --verify-diagnostics --split-input-file %s
#compilation = #iree_codegen.compilation_info<
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
index fe77ee1..1eea6f6 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-spirv-lower-executable-target-pass)))' %s | FileCheck %s
#compilation = #iree_codegen.compilation_info<
lowering_config = <tile_sizes = [[32, 128, 1, 32]]>,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
index edb222a..2462ce3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-spirv-lower-executable-target-pass)))' %s | FileCheck %s
// Verify pipelining + multi-buffering.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy_from_file.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy_from_file.mlir
deleted file mode 100644
index bf6487a..0000000
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy_from_file.mlir
+++ /dev/null
@@ -1,47 +0,0 @@
-// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))" --iree-spirv-use-transform-dialect=%p/transform_dialect_dummy_spec.mlir | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
-hal.executable private @copy_f32 {
- hal.executable.variant @vulkan_spirv_fb target(<"vulkan", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 32768,
- max_compute_workgroup_invocations = 512,
- max_compute_workgroup_size = [512, 512, 512],
- subgroup_size = 16>>
- }>) {
- hal.executable.export public @copy_f32 ordinal(0) layout(#pipeline_layout) {
- ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
- %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- // CHECK: IR printer:
- func.func @copy_f32() {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
- %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x2xf32>>
- %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x2xf32>>
- %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x2xf32>> -> tensor<2x2xf32>
- %3 = tensor.empty() : tensor<2x2xf32>
- %4 = linalg.generic {
- indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
- ins(%2 : tensor<2x2xf32>) outs(%3 : tensor<2x2xf32>) {
- ^bb0(%arg0: f32, %arg1: f32):
- %5 = math.sqrt %arg0 : f32
- linalg.yield %5 : f32
- } -> tensor<2x2xf32>
- flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : tensor<2x2xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x2xf32>>
- return
- }
- }
- // CHECK-COUNT-2: vector.transfer_read
- // CHECK-COUNT-2: math.sqrt
- // CHECK-COUNT-2: vector.transfer_write
- }
-}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir
deleted file mode 100644
index 1ba39f4..0000000
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-// RUN: iree-opt %s
-
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op):
- print %arg0 : !transform.any_op
-}
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
index 79f9da5..d8c4d71 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
@@ -34,7 +34,7 @@
// ---------------------------------------------------------------------------
// Tensor-level optimization, kernel dispatch and lower to buffers.
// ---------------------------------------------------------------------------
- addCommonTargetExecutablePreprocessingPasses(passManager.nest<ModuleOp>());
+ addCommonTargetExecutablePreprocessingPasses(passManager);
passManager.nest<ModuleOp>().addNestedPass<func::FuncOp>(
createCPUMaterializeEncodingPass());
// TODO: Remove the following pass the plumb support for #hal.descriptor_type
diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt
index 057c48d..2d7691e 100644
--- a/samples/CMakeLists.txt
+++ b/samples/CMakeLists.txt
@@ -19,4 +19,5 @@
add_subdirectory(py_custom_module)
add_subdirectory(simple_embedding)
add_subdirectory(static_library)
+add_subdirectory(transform_dialect)
add_subdirectory(variables_and_state)
diff --git a/samples/transform_dialect/CMakeLists.txt b/samples/transform_dialect/CMakeLists.txt
new file mode 100644
index 0000000..2b807a1
--- /dev/null
+++ b/samples/transform_dialect/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV)
+ return()
+endif()
+
+iree_lit_test_suite(
+ NAME
+ transform_example
+ SRCS
+ "example_module.mlir"
+ DATA
+ "transform_library.mlir"
+ TOOLS
+ FileCheck
+ iree-compile
+)
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
new file mode 100644
index 0000000..2b9275a
--- /dev/null
+++ b/samples/transform_dialect/example_module.mlir
@@ -0,0 +1,134 @@
+// Source IR for the following. Skips dispatch formation to isolate testing to
+// codegen.
+//
+// !A_size = tensor<16x5xf32>
+// !B_size = tensor<5x16xf32>
+// !C_size = tensor<16x16xf32>
+// !O_size = tensor<16xf32>
+//
+// module {
+// func.func @example_module(%A : !A_size, %B : !B_size, %C : !C_size) -> !O_size {
+// %0 = linalg.add ins(%A, %A : !A_size, !A_size)
+// outs(%A : !A_size) -> !A_size
+// %1 = linalg.matmul ins(%0, %B : !A_size, !B_size)
+// outs(%C : !C_size) -> !C_size
+// %empty = tensor.empty() : !O_size
+// %2 = linalg.reduce
+// ins(%1 : !C_size)
+// outs(%empty : !O_size)
+// dimensions = [1]
+// (%in: f32, %out: f32) {
+// %3 = arith.addf %out, %in: f32
+// linalg.yield %3: f32
+// }
+// return %2 : !O_size
+// }
+// }
+
+#target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>
+
+module attributes {hal.device.targets = [#hal.device.target<"vulkan", {executable_targets = [#hal.executable.target<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>}>], legacy_sync}>]} {
+ hal.executable private @example_module_dispatch_0 {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #target_env}>) {
+ hal.executable.export public @example_module_dispatch_0_generic_80_f32 ordinal(0) layout(
+ #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
+ ^bb0(%arg0: !hal.device):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @example_module_dispatch_0_generic_80_f32() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<80xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<80xf32>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [80], strides = [1] : !flow.dispatch.tensor<readonly:tensor<80xf32>> -> tensor<80xf32>
+ %3 = tensor.empty() : tensor<80xf32>
+ %4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%2 : tensor<80xf32>) outs(%3 : tensor<80xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %5 = arith.addf %in, %in : f32
+ linalg.yield %5 : f32
+ } -> tensor<80xf32>
+ flow.dispatch.tensor.store %4, %1, offsets = [0], sizes = [80], strides = [1] : tensor<80xf32> -> !flow.dispatch.tensor<writeonly:tensor<80xf32>>
+ return
+ }
+ }
+ }
+ }
+ hal.executable private @example_module_dispatch_1 {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #target_env}>) {
+ hal.executable.export public @example_module_dispatch_1_matmul_16x16x5_f32 ordinal(0) layout(
+ #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
+ ^bb0(%arg0: !hal.device):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @example_module_dispatch_1_matmul_16x16x5_f32() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x5xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<5x16xf32>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<16x16xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 5], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<16x5xf32>> -> tensor<16x5xf32>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [5, 16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<5x16xf32>> -> tensor<5x16xf32>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<16x16xf32>> -> tensor<16x16xf32>
+ %6 = linalg.matmul ins(%3, %4 : tensor<16x5xf32>, tensor<5x16xf32>) outs(%5 : tensor<16x16xf32>) -> tensor<16x16xf32>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !flow.dispatch.tensor<readwrite:tensor<16x16xf32>>
+ return
+ }
+ }
+ }
+ }
+ hal.executable private @example_module_dispatch_2 {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #target_env}>) {
+ hal.executable.export public @example_module_dispatch_2_generic_16x16_f32 ordinal(0) layout(
+ #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
+ ^bb0(%arg0: !hal.device):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @example_module_dispatch_2_generic_16x16_f32() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<16xf32>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16xf32>> -> tensor<16x16xf32>
+ %3 = tensor.empty() : tensor<16xf32>
+ %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<16x16xf32>) outs(%3 : tensor<16xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %5 = arith.addf %out, %in : f32
+ linalg.yield %5 : f32
+ } -> tensor<16xf32>
+ flow.dispatch.tensor.store %4, %1, offsets = [0], sizes = [16], strides = [1] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
+ return
+ }
+ }
+ }
+ }
+}
+
+/// We test first with threading off so that the printers are legible.
+// RUN: iree-compile %s --iree-hal-target-backends=vulkan \
+// RUN: --iree-codegen-use-transform-dialect-strategy=@transform_main \
+// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
+// RUN: --compile-from=executable-sources \
+// RUN: --compile-to=executable-targets \
+// RUN: --mlir-disable-threading | \
+// RUN: FileCheck %s --check-prefixes=CODEGEN-PRINTER
+
+// CODEGEN-PRINTER: IR printer: Setting matmul strategy to default top-level
+// CODEGEN-PRINTER: translation_info = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @transform_main
+// CODEGEN-PRINTER: IR printer: Setting reduce strategy to base vectorize top-level
+// CODEGEN-PRINTER: translation_info = #iree_codegen.translation_info<SPIRVBaseVectorize>, workgroup_size = [16 : index, 1 : index, 1 : index]
+
+/// Then test with threading to make sure it runs
+// RUN: iree-compile %s --iree-hal-target-backends=vulkan \
+// RUN: --iree-codegen-use-transform-dialect-strategy=@transform_main \
+// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
+// RUN: --compile-from=executable-sources \
+// RUN: --compile-to=executable-targets \
+// RUN: --mlir-disable-threading | \
+// RUN: FileCheck %s --check-prefixes=CODEGEN
+
+// CODEGEN: spirv.func @example_module_dispatch_0_generic_80_f32
+// CODEGEN: spirv.func @example_module_dispatch_1_matmul_16x16x5_f32
+// CODEGEN: spirv.func @example_module_dispatch_2_generic_16x16_f32
diff --git a/samples/transform_dialect/transform_library.mlir b/samples/transform_dialect/transform_library.mlir
new file mode 100644
index 0000000..3bb75ad
--- /dev/null
+++ b/samples/transform_dialect/transform_library.mlir
@@ -0,0 +1,52 @@
+module attributes { transform.with_named_sequence } {
+ // Print and send it down normal IREE codegen.
+ transform.named_sequence @custom_matmul(%matmul: !transform.any_op {transform.consumed}) {
+ %1 = transform.structured.generalize %matmul : (!transform.any_op) -> !transform.any_op
+ transform.print {name = "Setting matmul strategy to default"}
+ transform.yield
+ }
+
+ // Send it down subgroup reduce.
+ transform.named_sequence @use_subgroup_reduce(%reduce: !transform.any_op {transform.readonly}) {
+ %variant_op = transform.get_parent_op %reduce {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op
+ %lowering_config = transform.param.constant #iree_codegen.lowering_config<tile_sizes = [[8, 0], [1, 0], [0, 0, 4]]> -> !transform.any_param
+ transform.annotate %reduce "lowering_config" = %lowering_config : !transform.any_op, !transform.any_param
+ %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %subgroup_reduce = transform.param.constant #iree_codegen.translation_info<SPIRVBaseVectorize> -> !transform.any_param
+ %workgroup_size = transform.param.constant [16 : index, 1 : index, 1 : index] -> !transform.any_param
+ transform.annotate %exports "translation_info" = %subgroup_reduce : !transform.any_op, !transform.any_param
+ transform.annotate %exports "workgroup_size" = %workgroup_size : !transform.any_op, !transform.any_param
+ transform.print {name = "Setting reduce strategy to base vectorize"}
+ transform.yield
+ }
+
+ //===------------------------------------------------------===
+ // Matchers
+ //===------------------------------------------------------===
+ transform.named_sequence @match_matmul(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+ transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
+ transform.yield %matmul : !transform.any_op
+ }
+
+ transform.named_sequence @match_reduce(%reduce: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+ transform.match.operation_name %reduce ["linalg.generic"] : !transform.any_op
+ %matched = transform.match.structured failures(propagate) %reduce : (!transform.any_op) -> (!transform.any_op) {
+ ^bb1(%arg1: !transform.any_op):
+ %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+ %rank = transform.match.structured.rank %arg1 : (!transform.any_op) -> !transform.param<i64>
+ transform.match.param.cmpi eq %rank, %c2 : !transform.param<i64>
+
+ transform.match.structured.dim %arg1[-1] {reduction} : !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %matched : !transform.any_op
+ }
+
+ transform.named_sequence @transform_main(%variant_op: !transform.any_op {transform.consumed}) {
+ transform.foreach_match in %variant_op
+ @match_matmul -> @custom_matmul,
+ @match_reduce -> @use_subgroup_reduce
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
diff --git a/tests/e2e/linalg_transform/linalg_transform.mlir b/tests/e2e/linalg_transform/linalg_transform.mlir
index 7ad4ae2..796ec22 100644
--- a/tests/e2e/linalg_transform/linalg_transform.mlir
+++ b/tests/e2e/linalg_transform/linalg_transform.mlir
@@ -2,7 +2,7 @@
/// Specify the dispatch region formation with the transform dialect.
// R-UN: --iree-flow-dispatch-use-transform-dialect=%p/transform_dialect_dispatch_spec.mlir \
/// Specify the codegen strategy with the transform dialect.
-// R-UN: --iree-codegen-llvmcpu-use-transform-dialect=%p/transform_dialect_codegen_spec.mlir \
+// R-UN: --iree-codegen-use-transform-dialect-strategy=%p/transform_dialect_codegen_spec.mlir \
// R-UN: %s | FileCheck %s
diff --git a/tests/transform_dialect/cpu/BUILD.bazel b/tests/transform_dialect/cpu/BUILD.bazel
index f33a079..2a7663c 100644
--- a/tests/transform_dialect/cpu/BUILD.bazel
+++ b/tests/transform_dialect/cpu/BUILD.bazel
@@ -22,14 +22,15 @@
"eltwise_reduction_eltwise.mlir",
"fold_tensor_slice_into_transfer.mlir",
"matmul.mlir",
+ "matmul_library_call.mlir",
],
cfg = "//tests:lit.cfg.py",
# transform dialect spec files are MLIR files that specify a transformation,
# they need to be included as data.
data = [
"attention_codegen_spec.mlir",
- "matmul_codegen_custom_dispatch_formation_spec.mlir",
"matmul_codegen_default_spec.mlir",
+ "transform_library.mlir",
],
tags = [
"noasan",
diff --git a/tests/transform_dialect/cpu/CMakeLists.txt b/tests/transform_dialect/cpu/CMakeLists.txt
index ee72d6a..f674e91 100644
--- a/tests/transform_dialect/cpu/CMakeLists.txt
+++ b/tests/transform_dialect/cpu/CMakeLists.txt
@@ -20,6 +20,7 @@
"eltwise_reduction_eltwise.mlir"
"fold_tensor_slice_into_transfer.mlir"
"matmul.mlir"
+ "matmul_library_call.mlir"
TOOLS
${IREE_LLD_TARGET}
FileCheck
@@ -29,8 +30,8 @@
iree-run-module
DATA
attention_codegen_spec.mlir
- matmul_codegen_custom_dispatch_formation_spec.mlir
matmul_codegen_default_spec.mlir
+ transform_library.mlir
LABELS
"noasan"
"nomsan"
diff --git a/tests/transform_dialect/cpu/attention.mlir b/tests/transform_dialect/cpu/attention.mlir
index 0a43d4d..00591b1 100644
--- a/tests/transform_dialect/cpu/attention.mlir
+++ b/tests/transform_dialect/cpu/attention.mlir
@@ -9,7 +9,7 @@
}
// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu \
-// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/attention_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/attention_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=attention | \
// RUN: FileCheck %s --check-prefixes=EXEC
diff --git a/tests/transform_dialect/cpu/eltwise_reduction_eltwise.mlir b/tests/transform_dialect/cpu/eltwise_reduction_eltwise.mlir
index 63fc65f..d105660 100644
--- a/tests/transform_dialect/cpu/eltwise_reduction_eltwise.mlir
+++ b/tests/transform_dialect/cpu/eltwise_reduction_eltwise.mlir
@@ -49,7 +49,7 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target)))' \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmcpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmcpu-enable-transform-dialect-jit | \
// RUN: FileCheck %s
diff --git a/tests/transform_dialect/cpu/matmul.mlir b/tests/transform_dialect/cpu/matmul.mlir
index 0c11876..8bc4c7f 100644
--- a/tests/transform_dialect/cpu/matmul.mlir
+++ b/tests/transform_dialect/cpu/matmul.mlir
@@ -15,8 +15,8 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target)))' \
-// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_default_spec.mlir | \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmcpu-lower-executable-target)))' \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/matmul_codegen_default_spec.mlir | \
// RUN: FileCheck %s --check-prefixes=CODEGEN-DEFAULT
// CODEGEN-DEFAULT: hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5
@@ -25,7 +25,7 @@
// CODEGEN-DEFAULT: hal.return %[[C2]], %[[C1]], %[[C1]]
// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu \
-// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_default_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/matmul_codegen_default_spec.mlir | \
// RUN: iree-run-module --module=- --function=matmul_static \
// RUN: --input="3x5xf32=1" \
// RUN: --input="5x3xf32=2" \
diff --git a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
deleted file mode 100644
index 8bb09df..0000000
--- a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
+++ /dev/null
@@ -1,33 +0,0 @@
-// RUN: iree-opt %s
-
-transform.sequence failures(propagate) {
-^bb1(%variant_op: !transform.any_op):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-
- %tiled_generic, %forall =
- transform.structured.tile_using_forall %0 num_threads [2]
- // TODO: IREE needs own workgroup mapping attribute.
- ( mapping = [#gpu.block<x>] )
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall
- : (!transform.any_op) -> ()
-
- // Canonicalization/CSE is needed before bufferization otherwise unnecessary
- // allocs will be created.
- %func_op = transform.structured.match ops{["func.func"]} in %variant_op
- : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func_op {
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.iree.apply_cse %func_op : !transform.any_op
- %variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op)
- %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
- : (!transform.any_op) -> !transform.any_op
- transform.iree.forall_to_workgroup %memref_func : (!transform.any_op) -> ()
-
- // CSE is needed on the workgroup_count region to pass this particular test.
- transform.iree.apply_cse %variant_op_3 : !transform.any_op
-}
diff --git a/tests/transform_dialect/cpu/matmul_library_call.mlir b/tests/transform_dialect/cpu/matmul_library_call.mlir
new file mode 100644
index 0000000..5dd24db
--- /dev/null
+++ b/tests/transform_dialect/cpu/matmul_library_call.mlir
@@ -0,0 +1,35 @@
+
+!A_size = tensor<3x5xf32>
+!B_size = tensor<5x3xf32>
+!C_size = tensor<3x3xf32>
+
+module {
+ func.func @matmul_static(
+ %A : !A_size, %B : !B_size, %C : !C_size) -> !C_size {
+ %0 = linalg.matmul ins(%A, %B : !A_size, !B_size)
+ outs(%C : !C_size) -> !C_size
+ return %0 : !C_size
+ }
+}
+
+// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu \
+// RUN: --iree-codegen-use-transform-dialect-strategy=@custom_matmul \
+// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
+// RUN: --compile-to=executable-targets | \
+// RUN: FileCheck %s --check-prefixes=CODEGEN-DEFAULT
+
+// CODEGEN-DEFAULT: hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5
+// CODEGEN-DEFAULT: %[[C2:.+]] = arith.constant 2 : index
+// CODEGEN-DEFAULT: %[[C1:.+]] = arith.constant 1 : index
+// CODEGEN-DEFAULT: hal.return %[[C2]], %[[C1]], %[[C1]]
+
+// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu \
+// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
+// RUN: --iree-codegen-use-transform-dialect-strategy=@custom_matmul | \
+// RUN: iree-run-module --module=- --function=matmul_static \
+// RUN: --input="3x5xf32=1" \
+// RUN: --input="5x3xf32=2" \
+// RUN: --input="3x3xf32=42" | \
+// RUN: FileCheck %s --check-prefixes=EXEC
+
+// EXEC: 3x3xf32=[52 52 52][52 52 52][52 52 52]
diff --git a/tests/transform_dialect/cpu/transform_library.mlir b/tests/transform_dialect/cpu/transform_library.mlir
new file mode 100644
index 0000000..b390561
--- /dev/null
+++ b/tests/transform_dialect/cpu/transform_library.mlir
@@ -0,0 +1,36 @@
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @custom_matmul(%variant_op: !transform.any_op {transform.consumed}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+
+ %tiled_generic, %forall =
+ transform.structured.tile_using_forall %0 num_threads [2]
+ // TODO: IREE needs own workgroup mapping attribute.
+ ( mapping = [#gpu.block<x>] )
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall
+ : (!transform.any_op) -> ()
+
+ // Canonicalization/CSE is needed before bufferization otherwise unnecessary
+ // allocs will be created.
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
+ %variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op)
+ %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
+ : (!transform.any_op) -> !transform.any_op
+ transform.iree.forall_to_workgroup %memref_func : (!transform.any_op) -> ()
+
+ // CSE is needed on the workgroup_count region to pass this particular test.
+ transform.iree.apply_cse %variant_op_3 : !transform.any_op
+ %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
+ %none_attr = transform.param.constant #iree_codegen.translation_info<None> -> !transform.any_param
+ transform.annotate %exports "translation_info" = %none_attr : !transform.any_op, !transform.any_param
+ transform.yield
+ }
+}
diff --git a/tests/transform_dialect/cuda/double_mma_layout_analysis.mlir b/tests/transform_dialect/cuda/double_mma_layout_analysis.mlir
index a52808a..93b143c 100644
--- a/tests/transform_dialect/cuda/double_mma_layout_analysis.mlir
+++ b/tests/transform_dialect/cuda/double_mma_layout_analysis.mlir
@@ -15,7 +15,7 @@
// RUN: --iree-hal-cuda-llvm-target-arch=sm_80 \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
// RUN: --iree-flow-dispatch-use-transform-dialect=%p/double_mma_layout_analysis_dispatch_spec.mlir \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/double_mma_layout_analysis_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/double_mma_layout_analysis_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=double_matmul --device=cuda \
// RUN: --input="16x16xf16=[[0.0999755859375,0.2249755859375,0.07501220703125,0.0,0.07501220703125,0.2249755859375,0.175048828125,0.07501220703125,0.175048828125,0.07501220703125,0.024993896484375,0.1500244140625,0.1500244140625,0.2249755859375,0.199951171875,0.1500244140625],[0.1500244140625,0.199951171875,0.0999755859375,0.07501220703125,0.1500244140625,0.2249755859375,0.024993896484375,0.0999755859375,0.0999755859375,0.024993896484375,0.2249755859375,0.2249755859375,0.2249755859375,0.0,0.024993896484375,0.04998779296875],[0.07501220703125,0.0,0.125,0.125,0.04998779296875,0.2249755859375,0.024993896484375,0.199951171875,0.199951171875,0.07501220703125,0.1500244140625,0.2249755859375,0.024993896484375,0.175048828125,0.07501220703125,0.125],[0.04998779296875,0.024993896484375,0.0,0.2249755859375,0.07501220703125,0.024993896484375,0.024993896484375,0.0,0.07501220703125,0.1500244140625,0.1500244140625,0.175048828125,0.2249755859375,0.1500244140625,0.07501220703125,0.0999755859375],[0.125,0.0,0.199951171875,0.04998779296875,0.199951171875,0.04998779296875,0.175048828125,0.125,0.0,0.0,0.199951171875,0.024993896484375,0.2249755859375,0.1500244140625,0.024993896484375,0.0],[0.04998779296875,0.2249755859375,0.0999755859375,0.07501220703125,0.2249755859375,0.07501220703125,0.2249755859375,0.07501220703125,0.2249755859375,0.199951171875,0.125,0.07501220703125,0.04998779296875,0.199951171875,0.125,0.1500244140625],[0.1500244140625,0.125,0.175048828125,0.04998779296875,0.125,0.1500244140625,0.1500244140625,0.125,0.0999755859375,0.0,0.199951171875,0.024993896484375,0.175048828125,0.199951171875,0.125,0.0999755859375],[0.0999755859375,0.199951171875,0.0999755859375,0.0999755859375,0.2249755859375,0.0,0.175048828125,0.0999755859375,0.125,0.07501220703125,0.07501220703125,0.175048828125,0.07501220703125,0.0,0.2249755859375,0.2249755859375],[0.07501220703125,0.024993896484375,0.199951171875,0.024993896484375,0.175048828125,0.199951171875,0.0999755859375,0.024993896484375,0.0,0.0999755859375,0.0,0.0999755859375,0.2249755859375,0.175048828125,0.0,0.0],[0.024993896484375,0.0999755859375,0.2249755859375,0.2249755859375,0.125,0.2249755859375,0.04998779296875,0.04998779296875,0.04998779296875,0.024993896484375,0.0999755859375,0.2249755859375,0.024993896484375,0.024993896484375,0.0,0.07501220703125],[0.0,0.1500244140625,0.175048828125,0.1500244140625,0.2249755859375,0.024993896484375,0.1500244140625,0.0999755859375,0.024993896484375,0.0,0.125,0.04998779296875,0.125,0.199951171875,0.024993896484375,0.199951171875],[0.024993896484375,0.04998779296875,0.199951171875,0.0,0.07501220703125,0.199951171875,0.2249755859375,0.04998779296875,0.175048828125,0.0,0.199951171875,0.199951171875,0.1500244140625,0.199951171875,0.125,0.199951171875],[0.1500244140625,0.125,0.04998779296875,0.0999755859375,0.04998779296875,0.175048828125,0.04998779296875,0.0999755859375,0.2249755859375,0.199951171875,0.125,0.1500244140625,0.0999755859375,0.07501220703125,0.07501220703125,0.0999755859375],[0.0,0.04998779296875,0.125,0.024993896484375,0.04998779296875,0.199951171875,0.04998779296875,0.0999755859375,0.199951171875,0.07501220703125,0.1500244140625,0.125,0.199951171875,0.199951171875,0.0,0.125],[0.024993896484375,0.07501220703125,0.0,0.199951171875,0.024993896484375,0.024993896484375,0.024993896484375,0.175048828125,0.04998779296875,0.04998779296875,0.04998779296875,0.07501220703125,0.07501220703125,0.1500244140625,0.175048828125,0.199951171875],[0.0,0.125,0.0,0.07501220703125,0.125,0.125,0.07501220703125,0.1500244140625,0.04998779296875,0.04998779296875,0.125,0.125,0.2249755859375,0.0999755859375,0.07501220703125,0.07501220703125]]" \
// RUN: --input="16x16xf16=[[0.175048828125,0.07501220703125,0.199951171875,0.0,0.175048828125,0.125,0.199951171875,0.04998779296875,0.0999755859375,0.175048828125,0.07501220703125,0.04998779296875,0.125,0.125,0.07501220703125,0.2249755859375],[0.024993896484375,0.199951171875,0.0,0.1500244140625,0.175048828125,0.0999755859375,0.175048828125,0.1500244140625,0.2249755859375,0.07501220703125,0.199951171875,0.0999755859375,0.0999755859375,0.2249755859375,0.0999755859375,0.0999755859375],[0.2249755859375,0.2249755859375,0.125,0.175048828125,0.0,0.07501220703125,0.04998779296875,0.0,0.199951171875,0.1500244140625,0.024993896484375,0.2249755859375,0.024993896484375,0.1500244140625,0.2249755859375,0.199951171875],[0.1500244140625,0.125,0.024993896484375,0.07501220703125,0.125,0.125,0.07501220703125,0.1500244140625,0.04998779296875,0.175048828125,0.125,0.175048828125,0.175048828125,0.07501220703125,0.024993896484375,0.125],[0.2249755859375,0.125,0.2249755859375,0.1500244140625,0.0,0.0,0.1500244140625,0.125,0.024993896484375,0.125,0.0,0.024993896484375,0.175048828125,0.175048828125,0.024993896484375,0.125],[0.2249755859375,0.024993896484375,0.04998779296875,0.0,0.0,0.1500244140625,0.07501220703125,0.2249755859375,0.1500244140625,0.024993896484375,0.0,0.0999755859375,0.125,0.1500244140625,0.2249755859375,0.0],[0.125,0.0999755859375,0.0,0.0999755859375,0.199951171875,0.125,0.175048828125,0.175048828125,0.1500244140625,0.2249755859375,0.04998779296875,0.125,0.1500244140625,0.0,0.0,0.0999755859375],[0.125,0.07501220703125,0.175048828125,0.1500244140625,0.175048828125,0.0,0.04998779296875,0.125,0.125,0.024993896484375,0.0999755859375,0.175048828125,0.024993896484375,0.0,0.024993896484375,0.0],[0.2249755859375,0.024993896484375,0.0999755859375,0.04998779296875,0.125,0.07501220703125,0.0999755859375,0.024993896484375,0.125,0.125,0.125,0.024993896484375,0.125,0.04998779296875,0.0999755859375,0.07501220703125],[0.0999755859375,0.175048828125,0.199951171875,0.0999755859375,0.175048828125,0.07501220703125,0.024993896484375,0.125,0.07501220703125,0.0,0.125,0.07501220703125,0.07501220703125,0.0,0.199951171875,0.175048828125],[0.07501220703125,0.0999755859375,0.175048828125,0.07501220703125,0.125,0.1500244140625,0.0,0.0999755859375,0.2249755859375,0.199951171875,0.04998779296875,0.0,0.0,0.1500244140625,0.199951171875,0.2249755859375],[0.024993896484375,0.2249755859375,0.04998779296875,0.1500244140625,0.2249755859375,0.2249755859375,0.175048828125,0.0999755859375,0.024993896484375,0.199951171875,0.125,0.199951171875,0.175048828125,0.2249755859375,0.175048828125,0.0999755859375],[0.125,0.0999755859375,0.04998779296875,0.125,0.199951171875,0.07501220703125,0.199951171875,0.0,0.024993896484375,0.04998779296875,0.0,0.04998779296875,0.04998779296875,0.199951171875,0.1500244140625,0.0999755859375],[0.199951171875,0.0,0.125,0.04998779296875,0.07501220703125,0.175048828125,0.0999755859375,0.175048828125,0.024993896484375,0.07501220703125,0.0,0.1500244140625,0.07501220703125,0.024993896484375,0.07501220703125,0.175048828125],[0.1500244140625,0.125,0.0999755859375,0.175048828125,0.04998779296875,0.0,0.04998779296875,0.1500244140625,0.024993896484375,0.125,0.125,0.175048828125,0.125,0.0999755859375,0.175048828125,0.1500244140625],[0.07501220703125,0.199951171875,0.024993896484375,0.0999755859375,0.175048828125,0.07501220703125,0.1500244140625,0.04998779296875,0.0,0.024993896484375,0.07501220703125,0.07501220703125,0.1500244140625,0.04998779296875,0.2249755859375,0.1500244140625]]" \
diff --git a/tests/transform_dialect/cuda/eltwise_reduction.mlir b/tests/transform_dialect/cuda/eltwise_reduction.mlir
index a4b01aa..d276879 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction.mlir
@@ -43,8 +43,8 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))'
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/%S_codegen_spec.mlir | \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))'
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/%S_codegen_spec.mlir | \
// RUN: FileCheck %s
// RUN: iree-compile %s --iree-hal-target-backends=cuda | \
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_eltwise.mlir b/tests/transform_dialect/cuda/eltwise_reduction_eltwise.mlir
index e5f00b0..3ad0c96 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_eltwise.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_eltwise.mlir
@@ -55,8 +55,8 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))'
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/%S_codegen_spec.mlir | \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))'
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/%S_codegen_spec.mlir | \
// RUN: FileCheck %s
// RUN: iree-compile %s --iree-hal-target-backends=cuda | \
diff --git a/tests/transform_dialect/cuda/mma_elemwise_layout_analysis.mlir b/tests/transform_dialect/cuda/mma_elemwise_layout_analysis.mlir
index 921492e..e1cbe68 100644
--- a/tests/transform_dialect/cuda/mma_elemwise_layout_analysis.mlir
+++ b/tests/transform_dialect/cuda/mma_elemwise_layout_analysis.mlir
@@ -18,7 +18,7 @@
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-hal-cuda-llvm-target-arch=sm_80 \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/mma_elemwise_layout_analysis_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/mma_elemwise_layout_analysis_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=matmul --device=cuda \
// RUN: --input="16x16xf16=[[0.0999755859375,0.2249755859375,0.07501220703125,0.0,0.07501220703125,0.2249755859375,0.175048828125,0.07501220703125,0.175048828125,0.07501220703125,0.024993896484375,0.1500244140625,0.1500244140625,0.2249755859375,0.199951171875,0.1500244140625],[0.1500244140625,0.199951171875,0.0999755859375,0.07501220703125,0.1500244140625,0.2249755859375,0.024993896484375,0.0999755859375,0.0999755859375,0.024993896484375,0.2249755859375,0.2249755859375,0.2249755859375,0.0,0.024993896484375,0.04998779296875],[0.07501220703125,0.0,0.125,0.125,0.04998779296875,0.2249755859375,0.024993896484375,0.199951171875,0.199951171875,0.07501220703125,0.1500244140625,0.2249755859375,0.024993896484375,0.175048828125,0.07501220703125,0.125],[0.04998779296875,0.024993896484375,0.0,0.2249755859375,0.07501220703125,0.024993896484375,0.024993896484375,0.0,0.07501220703125,0.1500244140625,0.1500244140625,0.175048828125,0.2249755859375,0.1500244140625,0.07501220703125,0.0999755859375],[0.125,0.0,0.199951171875,0.04998779296875,0.199951171875,0.04998779296875,0.175048828125,0.125,0.0,0.0,0.199951171875,0.024993896484375,0.2249755859375,0.1500244140625,0.024993896484375,0.0],[0.04998779296875,0.2249755859375,0.0999755859375,0.07501220703125,0.2249755859375,0.07501220703125,0.2249755859375,0.07501220703125,0.2249755859375,0.199951171875,0.125,0.07501220703125,0.04998779296875,0.199951171875,0.125,0.1500244140625],[0.1500244140625,0.125,0.175048828125,0.04998779296875,0.125,0.1500244140625,0.1500244140625,0.125,0.0999755859375,0.0,0.199951171875,0.024993896484375,0.175048828125,0.199951171875,0.125,0.0999755859375],[0.0999755859375,0.199951171875,0.0999755859375,0.0999755859375,0.2249755859375,0.0,0.175048828125,0.0999755859375,0.125,0.07501220703125,0.07501220703125,0.175048828125,0.07501220703125,0.0,0.2249755859375,0.2249755859375],[0.07501220703125,0.024993896484375,0.199951171875,0.024993896484375,0.175048828125,0.199951171875,0.0999755859375,0.024993896484375,0.0,0.0999755859375,0.0,0.0999755859375,0.2249755859375,0.175048828125,0.0,0.0],[0.024993896484375,0.0999755859375,0.2249755859375,0.2249755859375,0.125,0.2249755859375,0.04998779296875,0.04998779296875,0.04998779296875,0.024993896484375,0.0999755859375,0.2249755859375,0.024993896484375,0.024993896484375,0.0,0.07501220703125],[0.0,0.1500244140625,0.175048828125,0.1500244140625,0.2249755859375,0.024993896484375,0.1500244140625,0.0999755859375,0.024993896484375,0.0,0.125,0.04998779296875,0.125,0.199951171875,0.024993896484375,0.199951171875],[0.024993896484375,0.04998779296875,0.199951171875,0.0,0.07501220703125,0.199951171875,0.2249755859375,0.04998779296875,0.175048828125,0.0,0.199951171875,0.199951171875,0.1500244140625,0.199951171875,0.125,0.199951171875],[0.1500244140625,0.125,0.04998779296875,0.0999755859375,0.04998779296875,0.175048828125,0.04998779296875,0.0999755859375,0.2249755859375,0.199951171875,0.125,0.1500244140625,0.0999755859375,0.07501220703125,0.07501220703125,0.0999755859375],[0.0,0.04998779296875,0.125,0.024993896484375,0.04998779296875,0.199951171875,0.04998779296875,0.0999755859375,0.199951171875,0.07501220703125,0.1500244140625,0.125,0.199951171875,0.199951171875,0.0,0.125],[0.024993896484375,0.07501220703125,0.0,0.199951171875,0.024993896484375,0.024993896484375,0.024993896484375,0.175048828125,0.04998779296875,0.04998779296875,0.04998779296875,0.07501220703125,0.07501220703125,0.1500244140625,0.175048828125,0.199951171875],[0.0,0.125,0.0,0.07501220703125,0.125,0.125,0.07501220703125,0.1500244140625,0.04998779296875,0.04998779296875,0.125,0.125,0.2249755859375,0.0999755859375,0.07501220703125,0.07501220703125]]" \
// RUN: --input="8x16xf16=[[0.175049 0.0999756 0.0249939 0.224976 0.224976 0.199951 0.150024 0.0499878 0.224976 0.0249939 0.224976 0.150024 0.125 0.150024 0.125 0.125][0.0750122 0.175049 0.199951 0.0750122 0.224976 0.150024 0.125 0.175049 0.125 0.125 0.0249939 0.0249939 0.0999756 0.224976 0.0750122 0.0249939][0.199951 0.0750122 0 0.199951 0.125 0.0249939 0.0249939 0.125 0.224976 0 0.0499878 0 0 0.0499878 0.175049 0.0999756][0 0.0499878 0.150024 0.0999756 0.175049 0.224976 0.0750122 0.175049 0.150024 0.0249939 0 0.0999756 0.0999756 0.125 0.150024 0.175049][0.175049 0.125 0.175049 0.0999756 0 0.0249939 0.125 0.175049 0 0.175049 0 0.125 0.199951 0.150024 0.175049 0.0249939][0.125 0.125 0.0999756 0.224976 0.0750122 0.150024 0.125 0.0750122 0 0.175049 0.150024 0.150024 0.125 0 0 0][0.199951 0.0750122 0.175049 0.0999756 0.0499878 0.224976 0.0750122 0.0249939 0.150024 0.0249939 0.0750122 0.224976 0.175049 0 0.0499878 0.0249939][0.0499878 0.224976 0.150024 0.0999756 0 0.199951 0.150024 0.125 0.125 0.125 0.224976 0 0.175049 0.0999756 0.125 0]]" \
diff --git a/tests/transform_dialect/cuda/mma_reduction_layout_analysis.mlir b/tests/transform_dialect/cuda/mma_reduction_layout_analysis.mlir
index 614d031..627d8f2 100644
--- a/tests/transform_dialect/cuda/mma_reduction_layout_analysis.mlir
+++ b/tests/transform_dialect/cuda/mma_reduction_layout_analysis.mlir
@@ -27,7 +27,7 @@
// RUN: --iree-hal-cuda-llvm-target-arch=sm_80 \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
// RUN: --iree-flow-dispatch-use-transform-dialect=%p/mma_reduction_layout_analysis_dispatch_spec.mlir \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/mma_reduction_layout_analysis_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/mma_reduction_layout_analysis_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=matmul_reduction --device=cuda \
// RUN: --input="16x16xf16=[[3.0,2.0,2.5,4.5,1.5,4.0,2.0,2.5,4.0,4.0,1.5,0.5,2.0,3.0,0.5,2.0],[2.5,2.5,0.5,3.5,0.0,2.5,3.5,1.0,0.5,0.0,3.0,4.5,0.5,0.5,0.0,3.5],[4.5,3.0,4.0,2.5,1.0,0.5,0.0,4.5,0.0,2.5,3.5,0.0,2.0,4.5,1.5,4.5],[0.0,2.0,1.5,0.0,2.0,1.5,3.0,2.0,2.0,4.0,4.0,2.5,0.0,3.0,2.0,0.5],[0.5,3.5,3.0,2.5,0.0,2.5,3.0,3.0,4.5,2.0,2.0,1.0,2.0,1.0,3.5,2.0],[0.0,4.5,2.0,4.0,2.5,2.5,1.5,1.5,1.5,3.0,3.0,0.0,2.5,0.5,2.0,2.0],[3.5,4.0,3.5,1.5,2.0,0.5,1.0,2.5,4.0,3.5,0.0,3.0,0.0,1.5,4.5,0.0],[4.5,3.5,1.0,4.5,0.5,0.0,1.5,4.5,1.5,3.5,3.0,2.5,0.0,0.5,0.0,4.0],[2.0,3.0,0.5,2.0,1.5,0.5,2.0,2.5,2.5,4.0,2.0,4.5,4.0,0.0,2.0,3.0],[2.5,4.0,4.0,3.0,2.0,2.0,4.5,0.5,4.5,1.0,2.0,0.0,4.5,1.0,3.0,0.5],[4.0,1.5,3.5,3.0,2.5,4.5,1.0,3.5,3.0,2.5,2.5,2.0,2.0,4.5,1.5,2.5],[3.0,3.0,0.0,2.5,1.0,3.0,0.0,1.5,1.5,2.5,0.5,1.0,3.0,3.5,1.5,1.5],[0.0,4.5,0.5,1.5,0.5,4.0,3.5,4.0,4.0,0.0,0.5,1.0,4.5,1.5,0.0,3.5],[2.5,2.0,2.5,1.5,3.0,0.0,2.0,1.0,2.5,4.0,0.0,4.0,4.0,1.5,3.0,2.5],[3.0,0.0,4.0,4.0,2.0,0.5,1.0,3.5,4.0,2.5,4.0,4.5,0.0,3.0,1.5,2.5],[0.5,0.5,2.5,4.0,1.0,2.5,0.5,4.5,2.0,3.0,1.5,4.5,1.5,4.5,0.5,1.5]]" \
// RUN: --input="16x16xf16=[[3.5,3.0,4.5,3.0,3.0,0.0,2.0,2.5,2.0,0.0,4.5,2.5,0.5,0.0,4.0,3.5],[0.0,0.5,2.0,4.5,0.0,4.0,1.5,3.5,0.5,2.5,3.5,1.5,3.5,4.5,4.0,3.0],[3.0,3.5,2.5,1.5,1.5,1.5,0.5,4.5,0.0,3.5,4.0,0.0,0.0,2.0,0.5,1.0],[1.5,4.0,3.5,3.5,0.0,0.0,0.0,2.0,3.0,1.5,0.0,3.0,0.0,2.5,2.0,3.0],[3.5,4.0,2.5,1.5,3.0,2.0,3.0,4.5,1.5,3.0,2.0,3.5,2.5,4.5,0.5,3.5],[0.0,0.0,0.0,0.5,1.0,2.5,1.5,1.0,2.5,1.5,0.0,1.5,1.5,2.0,4.5,2.5],[4.0,1.5,3.0,2.5,2.5,3.5,2.0,4.0,1.5,2.5,0.5,4.0,1.0,4.5,3.5,0.0],[1.0,2.0,4.0,4.5,4.5,3.5,0.0,1.0,4.5,3.5,2.0,3.0,0.5,4.0,3.5,1.5],[1.0,0.0,2.5,4.5,0.0,2.0,0.0,2.5,3.0,4.0,2.5,0.5,3.5,0.0,3.5,1.0],[0.0,3.5,4.0,0.0,0.0,4.5,1.0,3.5,1.5,3.0,2.0,1.0,0.5,0.5,2.0,0.0],[1.5,0.0,4.5,2.0,4.5,4.5,3.5,3.0,2.5,4.5,0.5,0.5,0.0,4.5,0.0,4.0],[4.5,3.5,4.0,4.0,1.5,4.0,1.0,4.0,2.5,0.5,4.5,3.5,3.5,0.5,4.5,3.0],[0.0,3.0,2.5,1.0,1.5,2.0,1.0,1.5,4.0,2.5,3.5,1.0,3.5,2.5,3.5,4.5],[1.5,4.5,2.0,2.0,2.0,0.5,4.0,2.0,4.0,3.5,4.0,1.0,1.5,2.5,1.0,0.0],[0.0,0.0,1.0,2.5,3.5,2.5,4.0,0.0,2.0,2.0,4.5,0.5,1.0,3.5,3.0,2.5],[2.0,2.0,0.5,2.0,4.5,2.5,3.0,1.5,4.5,2.0,3.5,3.0,1.0,2.0,1.5,2.0]]" |\
diff --git a/tests/transform_dialect/cuda/mma_using_layout_analysis.mlir b/tests/transform_dialect/cuda/mma_using_layout_analysis.mlir
index 9e0fc41..d0ca10c 100644
--- a/tests/transform_dialect/cuda/mma_using_layout_analysis.mlir
+++ b/tests/transform_dialect/cuda/mma_using_layout_analysis.mlir
@@ -10,7 +10,7 @@
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-hal-cuda-llvm-target-arch=sm_80 \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/mma_using_layout_analysis_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/mma_using_layout_analysis_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=matmul --device=cuda \
// RUN: --input="16x16xf16=[[1.0,1.125,1.25,1.375,1.5,1.625,1.75,1.875,2.0,2.125,2.25,2.375,2.5,2.625,2.75,2.875],[3.0,3.125,3.25,3.375,3.5,3.625,3.75,3.875,4.0,4.125,4.25,4.375,4.5,4.625,4.75,4.875],[5.0,5.125,5.25,5.375,5.5,5.625,5.75,5.875,6.0,6.125,6.25,6.375,6.5,6.625,6.75,6.875],[7.0,7.125,7.25,7.375,7.5,7.625,7.75,7.875,8.0,8.125,8.25,8.375,8.5,8.625,8.75,8.875],[9.0,9.125,9.25,9.375,9.5,9.625,9.75,9.875,10.0,10.125,10.25,10.375,10.5,10.625,10.75,10.875],[11.0,11.125,11.25,11.375,11.5,11.625,11.75,11.875,12.0,12.125,12.25,12.375,12.5,12.625,12.75,12.875],[13.0,13.125,13.25,13.375,13.5,13.625,13.75,13.875,14.0,14.125,14.25,14.375,14.5,14.625,14.75,14.875],[15.0,15.125,15.25,15.375,15.5,15.625,15.75,15.875,16.0,16.125,16.25,16.375,16.5,16.625,16.75,16.875],[17.0,17.125,17.25,17.375,17.5,17.625,17.75,17.875,18.0,18.125,18.25,18.375,18.5,18.625,18.75,18.875],[19.0,19.125,19.25,19.375,19.5,19.625,19.75,19.875,20.0,20.125,20.25,20.375,20.5,20.625,20.75,20.875],[21.0,21.125,21.25,21.375,21.5,21.625,21.75,21.875,22.0,22.125,22.25,22.375,22.5,22.625,22.75,22.875],[23.0,23.125,23.25,23.375,23.5,23.625,23.75,23.875,24.0,24.125,24.25,24.375,24.5,24.625,24.75,24.875],[25.0,25.125,25.25,25.375,25.5,25.625,25.75,25.875,26.0,26.125,26.25,26.375,26.5,26.625,26.75,26.875],[27.0,27.125,27.25,27.375,27.5,27.625,27.75,27.875,28.0,28.125,28.25,28.375,28.5,28.625,28.75,28.875],[29.0,29.125,29.25,29.375,29.5,29.625,29.75,29.875,30.0,30.125,30.25,30.375,30.5,30.625,30.75,30.875],[31.0,31.125,31.25,31.375,31.5,31.625,31.75,31.875,32.0,32.125,32.25,32.375,32.5,32.625,32.75,32.875]]" \
// RUN: --input="16x8xf16=[[1.0,1.125,1.25,1.375,1.5,1.625,1.75,1.875],[2.0,2.125,2.25,2.375,2.5,2.625,2.75,2.875],[3.0,3.125,3.25,3.375,3.5,3.625,3.75,3.875],[4.0,4.125,4.25,4.375,4.5,4.625,4.75,4.875],[5.0,5.125,5.25,5.375,5.5,5.625,5.75,5.875],[6.0,6.125,6.25,6.375,6.5,6.625,6.75,6.875],[7.0,7.125,7.25,7.375,7.5,7.625,7.75,7.875],[8.0,8.125,8.25,8.375,8.5,8.625,8.75,8.875],[9.0,9.125,9.25,9.375,9.5,9.625,9.75,9.875],[10.0,10.125,10.25,10.375,10.5,10.625,10.75,10.875],[11.0,11.125,11.25,11.375,11.5,11.625,11.75,11.875],[12.0,12.125,12.25,12.375,12.5,12.625,12.75,12.875],[13.0,13.125,13.25,13.375,13.5,13.625,13.75,13.875],[14.0,14.125,14.25,14.375,14.5,14.625,14.75,14.875],[15.0,15.125,15.25,15.375,15.5,15.625,15.75,15.875],[16.0,16.125,16.25,16.375,16.5,16.625,16.75,16.875]]" |\
diff --git a/tests/transform_dialect/cuda/reduction.mlir b/tests/transform_dialect/cuda/reduction.mlir
index d771955..e814353 100644
--- a/tests/transform_dialect/cuda/reduction.mlir
+++ b/tests/transform_dialect/cuda/reduction.mlir
@@ -23,9 +23,9 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/reduction_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
@@ -33,7 +33,7 @@
/// Constant JIT'ing must be disabled because the transform-dialect debug
/// flags leak to the JIT session, which doesn't know what to do with them.
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/reduction_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=reduce --device=cuda --input="8x64xf32=1" |\
// RUN: FileCheck %s --check-prefix=EXEC
diff --git a/tests/transform_dialect/cuda/reduction_eltwise.mlir b/tests/transform_dialect/cuda/reduction_eltwise.mlir
index 06ab2bc..5e75858 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise.mlir
@@ -34,14 +34,14 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_eltwise_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/reduction_eltwise_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_eltwise_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/reduction_eltwise_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=reduce --device=cuda --input="8x64xf32=1" |\
// RUN: FileCheck %s --check-prefix=EXEC
diff --git a/tests/transform_dialect/cuda/reduction_v2.mlir b/tests/transform_dialect/cuda/reduction_v2.mlir
index 893191c..de96cc1 100644
--- a/tests/transform_dialect/cuda/reduction_v2.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2.mlir
@@ -23,14 +23,14 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/reduction_v2_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/reduction_v2_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=reduce --device=cuda --input="33x1024xf32=1" |\
// RUN: FileCheck %s --check-prefix=EXEC
diff --git a/tests/transform_dialect/cuda/reduction_v2_uneven.mlir b/tests/transform_dialect/cuda/reduction_v2_uneven.mlir
index 57a6fb8..ebc5d8a 100644
--- a/tests/transform_dialect/cuda/reduction_v2_uneven.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2_uneven.mlir
@@ -23,14 +23,14 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/reduction_v2_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/reduction_v2_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=reduce --device=cuda --input="33x34567xf32=1" |\
// RUN: FileCheck %s --check-prefix=EXEC
diff --git a/tests/transform_dialect/cuda/softmax.mlir b/tests/transform_dialect/cuda/softmax.mlir
index 161fa32..299b5cb 100644
--- a/tests/transform_dialect/cuda/softmax.mlir
+++ b/tests/transform_dialect/cuda/softmax.mlir
@@ -5,8 +5,8 @@
// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_codegen_spec.mlir \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/softmax_codegen_spec.mlir \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false | \
// RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
@@ -16,7 +16,7 @@
// RUN: --iree-opt-const-expr-hoisting=false --iree-opt-const-eval=false \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/softmax_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=softmax --device=cuda | \
// RUN: FileCheck %s
diff --git a/tests/transform_dialect/cuda/softmax_partial.mlir b/tests/transform_dialect/cuda/softmax_partial.mlir
index 65c51f7..6f4ca42 100644
--- a/tests/transform_dialect/cuda/softmax_partial.mlir
+++ b/tests/transform_dialect/cuda/softmax_partial.mlir
@@ -4,8 +4,8 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_partial_codegen_spec.mlir \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/softmax_partial_codegen_spec.mlir \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false | \
// RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
@@ -14,7 +14,7 @@
/// Constant JIT'ing must be disabled because the transform-dialect debug
/// flags leak to the JIT session, which doesn't know what to do with them.
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_partial_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/softmax_partial_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=softmax_partial --device=cuda | \
// RUN: FileCheck %s
diff --git a/tests/transform_dialect/cuda/softmax_v2.mlir b/tests/transform_dialect/cuda/softmax_v2.mlir
index aa590e6..2a556d1 100644
--- a/tests/transform_dialect/cuda/softmax_v2.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2.mlir
@@ -4,9 +4,9 @@
// RUN: --iree-flow-fuse-multi-use \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_v2_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/softmax_v2_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
@@ -15,7 +15,7 @@
/// flags leak to the JIT session, which doesn't know what to do with them.
// RUN: --iree-flow-fuse-multi-use \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_v2_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/softmax_v2_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=softmax --device=cuda | \
// RUN: FileCheck %s
diff --git a/tests/transform_dialect/cuda/vecadd2d.mlir b/tests/transform_dialect/cuda/vecadd2d.mlir
index 9279355..7e03154 100644
--- a/tests/transform_dialect/cuda/vecadd2d.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d.mlir
@@ -37,9 +37,9 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/vecadd2d_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK
// RUN: iree-opt %s --iree-hal-target-backends=cuda \
@@ -47,9 +47,9 @@
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-materialize-user-configs, iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec_partial_tile.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/vecadd2d_codegen_spec_partial_tile.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK-PARTIAL-TILE
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
@@ -57,7 +57,7 @@
/// Constant JIT'ing must be disabled because the transform-dialect debug
/// flags leak to the JIT session, which doesn't know what to do with them.
// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec.mlir | \
+// RUN: --iree-codegen-use-transform-dialect-strategy=%p/vecadd2d_codegen_spec.mlir | \
// RUN: iree-run-module --module=- --function=vecadd2d --device=cuda |\
// RUN: FileCheck %s --check-prefix=EXEC