Revert "[Codegen] Re-Enable transform dialect configuration strategy sample (#15787)" (#16097)
This reverts commit 3b534c4299d28fd5fd04801a62955b3b25cba543.
There were breakages at ToM (tip of main) on Windows.
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
index cd588bc..c8f00fb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
@@ -38,43 +38,13 @@
"iree-codegen-transform-dialect-library",
llvm::cl::desc(
"File path to a module containing a library of transform dialect"
- "strategies. Can be suffixed with the name of a transform sequence"
- "within the library to run as preprocessing per executable variant."
- "This is specified as <file-path>::<sequence-name>. If not specified,"
- "this will default to `__kernel_config`."),
+ "strategies"),
llvm::cl::init(""));
namespace {
static const char kTranslationInfoAttrName[] = "translation_info";
-enum StrategyRunResult {
- Success = 0,
- NotFound = 1,
- Failed = 2,
-};
-
-static StrategyRunResult
-runTransformConfigurationStrategy(Operation *payloadRoot,
- StringRef entryPointName,
- ModuleOp &transformLibrary) {
- /// If we have a symbol, verify the existence of the symbol within the
- /// transform library.
- Operation *entryPoint = transform::detail::findTransformEntryPoint(
- payloadRoot, transformLibrary, entryPointName);
- if (!entryPoint) {
- return StrategyRunResult::NotFound;
- }
-
- transform::TransformOptions options;
- if (failed(transform::applyTransformNamedSequence(
- payloadRoot, entryPoint, transformLibrary,
- options.enableExpensiveChecks(true)))) {
- return StrategyRunResult::Failed;
- }
- return StrategyRunResult::Success;
-}
-
struct MaterializeUserConfigsPass
: public MaterializeUserConfigsBase<MaterializeUserConfigsPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -88,98 +58,21 @@
getAllEntryPoints(moduleOp);
MLIRContext *context = moduleOp.getContext();
- // Parse the file path and kernel config strategy from flags. There are
- // three possible usage flows:
- // 1. Specify only a transform dialect library, e.g.
- // --iree-codegen-transform-dialect-library=/path/to/library.mlir
- // This will load the transform library and attempt to use two default
- // strategies, `__kernel_config` before strategy configuration (i.e.
- // now), and `__transform_main` for codegen. If `__kernel_config` is
- // present in the library, it will run it and use the annotations set
- // by it. If not present, `__transform_main` will be broadcasted to all
- // dispatches.
- //
- // 2. Specify a library path with a strategy name instead of
- // `__transform_main`. This is the same as (1) except it uses a
- // different default strategy.
- //
- // 3. Specify a library path suffixed with a kernel config entry point.
- // This will throw an error if the specified kernel config strategy is
- // not found.
- SmallVector<StringRef, 2> parts;
- llvm::SplitString(llvm::StringRef(clCodegenTransformDialectLibraryFileName),
- parts, "::");
- if (parts.size() > 2) {
- variantOp.emitError()
- << "Invalid transform library path and sequence name "
- << clCodegenTransformDialectLibraryFileName;
- return signalPassFailure();
- }
- bool hasTransformConfig = parts.size() == 2;
- bool hasTransformLibrary = parts.size() >= 1;
- bool hasTransformStrategy = !clCodegenTransformDialectStrategyName.empty();
-
- std::string libraryFileName;
- if (hasTransformLibrary) {
- if (parts[0].empty()) {
- variantOp.emitError() << "Cannot specify an empty library path";
- return signalPassFailure();
- }
- libraryFileName = parts[0];
- }
-
- std::string entrySequenceName;
- if (hasTransformConfig) {
- if (parts[1].empty()) {
- variantOp.emitError() << "Cannot specify an empty sequence name";
- return signalPassFailure();
- }
- entrySequenceName = parts[1];
- }
-
LDBG("MaterializeUserConfigsPass on variant: " << variantOp);
std::optional<ModuleOp> transformLibrary = std::nullopt;
- if (hasTransformLibrary) {
+ if (!clCodegenTransformDialectLibraryFileName.empty()) {
auto dialect =
context->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
- auto maybeTransformLibrary =
- dialect->getOrLoadTransformLibraryModule(libraryFileName);
+ auto maybeTransformLibrary = dialect->getOrLoadTransformLibraryModule(
+ clCodegenTransformDialectLibraryFileName);
if (failed(maybeTransformLibrary)) {
- variantOp.emitError()
- << "failed to load transform library module: " << libraryFileName;
+ variantOp.emitError() << "failed to load transform library module: "
+ << clCodegenTransformDialectLibraryFileName;
return signalPassFailure();
}
transformLibrary = *maybeTransformLibrary;
- LDBG("--found transform library @" << libraryFileName);
- }
-
- // Run the user specified transform configuration, or attempt to run
- // `__kernel_config` if no sequence name is specified.
- if (hasTransformConfig) {
- // If we get here, the transform library must necessarily have been
- // loaded.
- assert(transformLibrary && *transformLibrary &&
- "Unexpected unloaded transform library");
- if (runTransformConfigurationStrategy(variantOp, entrySequenceName,
- *transformLibrary) !=
- StrategyRunResult::Success) {
- variantOp.emitError() << "transform kernel config strategy `"
- << entrySequenceName << "` failed to apply";
- return signalPassFailure();
- }
- } else if (transformLibrary && (*transformLibrary)) {
- StrategyRunResult res = runTransformConfigurationStrategy(
- variantOp, "__kernel_config", *transformLibrary);
- if (res == StrategyRunResult::Failed) {
- variantOp.emitError()
- << "default transform __kernel_config strategy failed to apply";
- return signalPassFailure();
- }
- // If `__kernel_config` was found and ran successfully, indicate that
- // there was a config strategy.
- if (res == StrategyRunResult::Success) {
- hasTransformConfig = true;
- }
+ LDBG("--found transform library @"
+ << clCodegenTransformDialectLibraryFileName);
}
IREE::Codegen::DispatchLoweringPassPipeline tdPipeline =
@@ -188,9 +81,10 @@
// Here we always set the pipeline strategy to transform dialect if the
// flag is non-empty to ensure we pick the right lowering pipeline in the
// event a strategy symbol is defined.
- if ((hasTransformLibrary && !hasTransformConfig) || hasTransformStrategy) {
+ if (!clCodegenTransformDialectLibraryFileName.empty() ||
+ !clCodegenTransformDialectStrategyName.empty()) {
StringRef strategyName =
- (!hasTransformStrategy)
+ (clCodegenTransformDialectStrategyName.empty())
? StringRef(
transform::TransformDialect::kTransformEntryPointSymbolName)
: clCodegenTransformDialectStrategyName;
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index c3c1d3a..95fe68e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -259,8 +259,7 @@
/// Create an IREE-specific Transform dialect interpreter pass with all
/// registrations necessary for IREE.
-std::unique_ptr<Pass>
-createTransformDialectInterpreterPass(StringRef transformSequenceName = "");
+std::unique_ptr<Pass> createTransformDialectInterpreterPass();
/// Pass to propagate type to avoid generating load/stores of illegal types.
std::unique_ptr<OperationPass<func::FuncOp>> createTypePropagationPass();
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
index fb87d06..4852485 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
@@ -76,19 +76,9 @@
extern llvm::cl::opt<std::string> clCodegenTransformDialectLibraryFileName;
/// Create a Transform dialect interpreter pass.
-std::unique_ptr<Pass>
-createTransformDialectInterpreterPass(StringRef transformSequenceName) {
- StringRef strategyName = transformSequenceName.empty()
- ? clCodegenTransformDialectStrategyName
- : transformSequenceName;
- StringRef libraryPath = "";
- SmallVector<StringRef, 2> parts;
- llvm::SplitString(llvm::StringRef(clCodegenTransformDialectLibraryFileName),
- parts, "::");
- if (!parts.empty()) {
- libraryPath = parts[0];
- }
- return std::make_unique<TransformDialectInterpreterPass>(libraryPath,
- strategyName);
+std::unique_ptr<Pass> createTransformDialectInterpreterPass() {
+ return std::make_unique<TransformDialectInterpreterPass>(
+ clCodegenTransformDialectLibraryFileName,
+ clCodegenTransformDialectStrategyName);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index 091ca37..aa0cd8a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -203,12 +203,9 @@
break;
}
// Transform-dialect pipelines.
- case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen: {
- SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
- addTransformDialectPasses(
- pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
+ case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen:
+ addTransformDialectPasses(pipeline);
break;
- }
default:
moduleOp.emitOpError("Unsupported pipeline on CPU target.");
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index b5bb3c8..4f876ba 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -560,11 +560,10 @@
addCPUBufferizePasses(nestedModulePM);
}
-void addTransformDialectPasses(OpPassManager &passManager,
- StringRef entryPoint) {
+void addTransformDialectPasses(OpPassManager &passManager) {
// Give control to the transform dialect.
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
+ mlir::iree_compiler::createTransformDialectInterpreterPass());
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
// schedule once applied.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index c964342..74d688d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -145,8 +145,7 @@
bool lowerToVectors = true);
/// Transform dialect-based common.
-void addTransformDialectPasses(OpPassManager &passManager,
- StringRef entryPoint);
+void addTransformDialectPasses(OpPassManager &passManager);
// Populates the passes needed to do tiling, decomposing, and vectorizing the
// convolution ops.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index 4eb31ca..cef1407 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -108,12 +108,9 @@
addGPUPackUnPackPasses(pipeline);
break;
// Transform-dialect pipelines.
- case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen: {
- SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
- addGPUTransformDialectPasses(
- pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
+ case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen:
+ addGPUTransformDialectPasses(pipeline);
break;
- }
// no pipeline specified, nothing to do.
case IREE::Codegen::DispatchLoweringPassPipeline::None:
return;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index eddea59..e529e8a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -597,10 +597,9 @@
extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugPayloadTag;
extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugTransformTag;
-void addGPUTransformDialectPasses(OpPassManager &passManager,
- StringRef entryPoint) {
+void addGPUTransformDialectPasses(OpPassManager &passManager) {
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
+ mlir::iree_compiler::createTransformDialectInterpreterPass());
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index 08f83a9..9860197 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -40,7 +40,7 @@
void addGPUSimpleDistributePassPipeline(OpPassManager &pm);
/// Transform dialect-based path.
-void addGPUTransformDialectPasses(OpPassManager &pm, StringRef entryPoint);
+void addGPUTransformDialectPasses(OpPassManager &pm);
/// Lowering transpose using shared memory.
void addGPUTransposePassPipeline(OpPassManager &pm);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index a6611e8..13f7dcd 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -263,10 +263,9 @@
spirvPM.addPass(spirv::createSPIRVUpdateVCEPass());
}
-void addSPIRVTransformDialectPasses(OpPassManager &passManager,
- StringRef entryPoint) {
+void addSPIRVTransformDialectPasses(OpPassManager &passManager) {
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
+ mlir::iree_compiler::createTransformDialectInterpreterPass());
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
@@ -646,9 +645,8 @@
nestedModulePM.addPass(createCSEPass());
}
-void addSPIRVTransformDialectPassPipeline(OpPassManager &pm,
- StringRef entryPoint) {
- addSPIRVTransformDialectPasses(pm, entryPoint);
+void addSPIRVTransformDialectPassPipeline(OpPassManager &pm) {
+ addSPIRVTransformDialectPasses(pm);
// Run GenericVectorization pass additionally to convert vectors into forms
// needed for SPIR-V.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index 84cd3c7..dc4a22c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -44,8 +44,7 @@
void addSPIRVSubgroupReducePassPipeline(OpPassManager &pm);
/// Pass pipeline to lower IREE HAL executables via transform dialect schedules.
-void addSPIRVTransformDialectPassPipeline(OpPassManager &pm,
- StringRef entryPoint);
+void addSPIRVTransformDialectPassPipeline(OpPassManager &pm);
/// Pass pipeline to lower winograd ops. This pipeline follows the
/// SPIRVBaseVectorize pipeline with the following exception:
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 14ec33f..3943e99 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -95,12 +95,9 @@
case CodeGenPipeline::SPIRVWinogradVectorize:
addSPIRVWinogradVectorizePassPipeline(pipeline);
break;
- case CodeGenPipeline::TransformDialectCodegen: {
- SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
- addSPIRVTransformDialectPassPipeline(
- pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
+ case CodeGenPipeline::TransformDialectCodegen:
+ addSPIRVTransformDialectPassPipeline(pipeline);
break;
- }
// No pipeline specified, nothing to do.
case CodeGenPipeline::None:
return;
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index 95b80af..c5eab46 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -107,29 +107,28 @@
}
/// We test first with threading off so that the printers are legible.
-// RUN: iree-compile %s --iree-hal-target-backends=vulkan \
-// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir::kernel_config \
-// RUN: --compile-from=executable-sources \
-// RUN: --compile-to=executable-targets \
-// RUN: --mlir-disable-threading | \
-// RUN: FileCheck %s --check-prefixes=CODEGEN-PRINTER
+// R-UN: iree-compile %s --iree-hal-target-backends=vulkan \
+// R-UN: --iree-codegen-use-transform-dialect-strategy=transform_main \
+// R-UN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
+// R-UN: --compile-from=executable-sources \
+// R-UN: --compile-to=executable-targets \
+// R-UN: --mlir-disable-threading | \
+// R-UN: FileCheck %s --check-prefixes=CODEGEN-PRINTER
-// CODEGEN-PRINTER: IR printer: Setting matmul strategy to custom_transform_strategy
-// CODEGEN-PRINTER: translation_info = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @custom_transform_strategy>
+// CODEGEN-PRINTER: IR printer: Setting matmul strategy to default top-level
+// CODEGEN-PRINTER: translation_info = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @transform_main
// CODEGEN-PRINTER: IR printer: Setting reduce strategy to base vectorize top-level
// CODEGEN-PRINTER: translation_info = #iree_codegen.translation_info<SPIRVBaseVectorize>, workgroup_size = [16 : index, 1 : index, 1 : index]
/// Then test with threading to make sure it runs
// RUN: iree-compile %s --iree-hal-target-backends=vulkan \
-// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir::kernel_config \
+// RUN: --iree-codegen-use-transform-dialect-strategy=@transform_main \
+// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
// RUN: --compile-from=executable-sources \
// RUN: --compile-to=executable-targets \
// RUN: --mlir-disable-threading | \
// RUN: FileCheck %s --check-prefixes=CODEGEN
-// CODEGEN: Ran custom_transform_strategy
// CODEGEN: spirv.func @example_module_dispatch_0_generic_80_f32
-// CODEGEN: hal.executable private @example_module_dispatch_1
-// CODEGEN: #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @custom_transform_strategy>
-// CODEGEN: spirv.func @example_module_dispatch_1_matmul_16x16x5_f32
+// CODEGEN: spirv.func @example_module_dispatch_1_matmul_16x16x5_f32
// CODEGEN: spirv.func @example_module_dispatch_2_generic_16x16_f32
diff --git a/samples/transform_dialect/transform_library.mlir b/samples/transform_dialect/transform_library.mlir
index d4ca3b4..3bb75ad 100644
--- a/samples/transform_dialect/transform_library.mlir
+++ b/samples/transform_dialect/transform_library.mlir
@@ -1,76 +1,13 @@
module attributes { transform.with_named_sequence } {
- // Example of a custom matmul strategy. The target matmul is annotated with
- // the name of this strategy down below before strategy selection, overriding
- // default IREE codegen.
- transform.named_sequence @custom_transform_strategy(
- %variant_op: !transform.any_op {transform.consumed}) {
- // Step 1. Re-match the matmul
- // ===========================================================================
- %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-
- // Step 2. Tile to grid
- // ===========================================================================
- %grid_reduction, %forall_grid =
- transform.structured.tile_using_forall %matmul tile_sizes [16, 16] ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> ()
-
- // Step 3. Vectorize
- // ===========================================================================
- %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func {
- transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
- transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
- transform.apply_patterns.vector.cast_away_vector_leading_one_dim
- } : !transform.any_op
- %func_1 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> !transform.any_op
-
- // Step 4. Bufferize
- // ===========================================================================
- transform.apply_patterns to %func_1 {
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- } : !transform.any_op
- transform.apply_patterns to %func_1 {
- transform.apply_patterns.tensor.reassociative_reshape_folding
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.iree.apply_cse %func_1 : !transform.any_op
- transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
- transform.apply_patterns to %func_1 {
- transform.apply_patterns.linalg.erase_unnecessary_inputs
- } : !transform.any_op
- %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
- %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
-
- // Step 6. Post-bufferization vector distribution
- // ===========================================================================
- %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
- transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
- transform.iree.map_nested_forall_to_gpu_threads %func_7
- workgroup_dims = [4, 8, 1] : (!transform.any_op) -> ()
-
- // Step 7. Do layout analysis and lower to mma
- // ===========================================================================
- %func_10 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
- %func_11 = transform.iree.layout_analysis_and_distribution %func_10 : (!transform.any_op) -> (!transform.any_op)
- transform.print {name = "Ran custom_transform_strategy"}
+ // Print and send it down normal IREE codegen.
+ transform.named_sequence @custom_matmul(%matmul: !transform.any_op {transform.consumed}) {
+ %1 = transform.structured.generalize %matmul : (!transform.any_op) -> !transform.any_op
+ transform.print {name = "Setting matmul strategy to default"}
transform.yield
}
- // Send it down a custom transform dialect pipeline.
- transform.named_sequence @custom_matmul(%matmul: !transform.any_op {transform.readonly}) {
- %variant_op = transform.get_parent_op %matmul {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op
- %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %subgroup_reduce = transform.param.constant #iree_codegen.translation_info<TransformDialectCodegen
- codegen_spec = @custom_transform_strategy> -> !transform.any_param
- transform.annotate %exports "translation_info" = %subgroup_reduce : !transform.any_op, !transform.any_param
- transform.print {name = "Setting matmul strategy to custom_transform_strategy"}
- transform.yield
- }
-
- // Send it down subgroup reduce with a custom tiling configuration.
- transform.named_sequence @use_base_vectorize(%reduce: !transform.any_op {transform.readonly}) {
+ // Send it down subgroup reduce.
+ transform.named_sequence @use_subgroup_reduce(%reduce: !transform.any_op {transform.readonly}) {
%variant_op = transform.get_parent_op %reduce {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op
%lowering_config = transform.param.constant #iree_codegen.lowering_config<tile_sizes = [[8, 0], [1, 0], [0, 0, 4]]> -> !transform.any_param
transform.annotate %reduce "lowering_config" = %lowering_config : !transform.any_op, !transform.any_param
@@ -105,34 +42,10 @@
transform.yield %matched : !transform.any_op
}
- // An example of a custom transform dialect based kernel config. Note that
- // because of the way `transform.foreach_match` works, the callback cannot
- // manipulate IR beyond the op *given* to the matcher, as foreach_match will
- // attempt to keep walking the IR even after a successful match. The expected
- // flow for a strategy like this is as follows:
- //
- // Author an entry point like this (@kernel_config) that walks the IR and
- // attempts to annotate the dispatch with the codegen strategy to use, i.e.
- // transform.foreach_match in %variant_op
- // @matcher_0 -> @annotator_0,
- // @matcher_1 -> @annotator_1,
- // ...
- //
- // the annotators should attach an #iree_codegen.translation_info attribute
- // to the `hal.executable.export` ops within the variant as well as any
- // relevant op specific tile sizes (and other important attributes like
- // workgroup_size and subgroup_size, if relevant). This will then get handed
- // off to backend specific kernel config, which will let these user configs
- // pass through unperturbed.
- //
- // To couple this with a transform dialect based codegen strategy, the target
- // codegen strategy can be included inline with this library and relevant ops
- // can be annotated with `TransformDialectCodegen` as the lowering pipeline,
- // with a reference to the strategy to use (see an example above).
- transform.named_sequence @kernel_config(%variant_op: !transform.any_op {transform.consumed}) {
+ transform.named_sequence @transform_main(%variant_op: !transform.any_op {transform.consumed}) {
transform.foreach_match in %variant_op
@match_matmul -> @custom_matmul,
- @match_reduce -> @use_base_vectorize
+ @match_reduce -> @use_subgroup_reduce
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}