[Codegen] Re-Enable transform dialect configuration strategy sample (#15787)
This re-adds a mechanism for setting strategy configurations using the
transform dialect through
--iree-codegen-transform-dialect-library=<path>::<sequence-name> as well as re-enables
the associated sample.
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
index c8f00fb..cd588bc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
@@ -38,13 +38,43 @@
"iree-codegen-transform-dialect-library",
llvm::cl::desc(
"File path to a module containing a library of transform dialect"
- "strategies"),
+ " strategies. Can be suffixed with the name of a transform sequence "
+ "within the library to run as preprocessing per executable variant. "
+ "This is specified as <file-path>::<sequence-name>. If not specified, "
+ "this will default to `__kernel_config`."),
llvm::cl::init(""));
namespace {
static const char kTranslationInfoAttrName[] = "translation_info";
+enum StrategyRunResult {
+ Success = 0,
+ NotFound = 1,
+ Failed = 2,
+};
+
+static StrategyRunResult
+runTransformConfigurationStrategy(Operation *payloadRoot,
+ StringRef entryPointName,
+ ModuleOp &transformLibrary) {
+ /// If we have a symbol, verify the existence of the symbol within the
+ /// transform library.
+ Operation *entryPoint = transform::detail::findTransformEntryPoint(
+ payloadRoot, transformLibrary, entryPointName);
+ if (!entryPoint) {
+ return StrategyRunResult::NotFound;
+ }
+
+ transform::TransformOptions options;
+ if (failed(transform::applyTransformNamedSequence(
+ payloadRoot, entryPoint, transformLibrary,
+ options.enableExpensiveChecks(true)))) {
+ return StrategyRunResult::Failed;
+ }
+ return StrategyRunResult::Success;
+}
+
struct MaterializeUserConfigsPass
: public MaterializeUserConfigsBase<MaterializeUserConfigsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@@ -58,21 +88,98 @@
getAllEntryPoints(moduleOp);
MLIRContext *context = moduleOp.getContext();
+ // Parse the file path and kernel config strategy from flags. There are
+ // three possible usage flows:
+ // 1. Specify only a transform dialect library, e.g.
+ // --iree-codegen-transform-dialect-library=/path/to/library.mlir
+ // This will load the transform library and attempt to use two default
+ // strategies, `__kernel_config` before strategy configuration (i.e.
+ // now), and `__transform_main` for codegen. If `__kernel_config` is
+ // present in the library, it will run it and use the annotations set
+ // by it. If not present, `__transform_main` will be broadcasted to all
+ // dispatches.
+ //
+ // 2. Specify a library path with a strategy name instead of
+ // `__transform_main`. This is the same as (1) except it uses a
+ // different default strategy.
+ //
+ // 3. Specify a library path suffixed with a kernel config entry point.
+ // This will throw an error if the specified kernel config strategy is
+ // not found.
+ SmallVector<StringRef, 2> parts;
+ llvm::SplitString(llvm::StringRef(clCodegenTransformDialectLibraryFileName),
+ parts, "::");
+ if (parts.size() > 2) {
+ variantOp.emitError()
+ << "Invalid transform library path and sequence name "
+ << clCodegenTransformDialectLibraryFileName;
+ return signalPassFailure();
+ }
+ bool hasTransformConfig = parts.size() == 2;
+ bool hasTransformLibrary = parts.size() >= 1;
+ bool hasTransformStrategy = !clCodegenTransformDialectStrategyName.empty();
+
+ std::string libraryFileName;
+ if (hasTransformLibrary) {
+ if (parts[0].empty()) {
+ variantOp.emitError() << "Cannot specify an empty library path";
+ return signalPassFailure();
+ }
+ libraryFileName = parts[0];
+ }
+
+ std::string entrySequenceName;
+ if (hasTransformConfig) {
+ if (parts[1].empty()) {
+ variantOp.emitError() << "Cannot specify an empty sequence name";
+ return signalPassFailure();
+ }
+ entrySequenceName = parts[1];
+ }
+
LDBG("MaterializeUserConfigsPass on variant: " << variantOp);
std::optional<ModuleOp> transformLibrary = std::nullopt;
- if (!clCodegenTransformDialectLibraryFileName.empty()) {
+ if (hasTransformLibrary) {
auto dialect =
context->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
- auto maybeTransformLibrary = dialect->getOrLoadTransformLibraryModule(
- clCodegenTransformDialectLibraryFileName);
+ auto maybeTransformLibrary =
+ dialect->getOrLoadTransformLibraryModule(libraryFileName);
if (failed(maybeTransformLibrary)) {
- variantOp.emitError() << "failed to load transform library module: "
- << clCodegenTransformDialectLibraryFileName;
+ variantOp.emitError()
+ << "failed to load transform library module: " << libraryFileName;
return signalPassFailure();
}
transformLibrary = *maybeTransformLibrary;
- LDBG("--found transform library @"
- << clCodegenTransformDialectLibraryFileName);
+ LDBG("--found transform library @" << libraryFileName);
+ }
+
+ // Run the user specified transform configuration, or attempt to run
+ // `__kernel_config` if no sequence name is specified.
+ if (hasTransformConfig) {
+ // If we get here, the transform library must necessarily have been
+ // loaded.
+ assert(transformLibrary && *transformLibrary &&
+ "Unexpected unloaded transform library");
+ if (runTransformConfigurationStrategy(variantOp, entrySequenceName,
+ *transformLibrary) !=
+ StrategyRunResult::Success) {
+ variantOp.emitError() << "transform kernel config strategy `"
+ << entrySequenceName << "` failed to apply";
+ return signalPassFailure();
+ }
+ } else if (transformLibrary && (*transformLibrary)) {
+ StrategyRunResult res = runTransformConfigurationStrategy(
+ variantOp, "__kernel_config", *transformLibrary);
+ if (res == StrategyRunResult::Failed) {
+ variantOp.emitError()
+ << "default transform __kernel_config strategy failed to apply";
+ return signalPassFailure();
+ }
+ // If `__kernel_config` was found and ran successfully, indicate that
+ // there was a config strategy.
+ if (res == StrategyRunResult::Success) {
+ hasTransformConfig = true;
+ }
}
IREE::Codegen::DispatchLoweringPassPipeline tdPipeline =
@@ -81,10 +188,9 @@
// Here we always set the pipeline strategy to transform dialect if the
// flag is non-empty to ensure we pick the right lowering pipeline in the
// event a strategy symbol is defined.
- if (!clCodegenTransformDialectLibraryFileName.empty() ||
- !clCodegenTransformDialectStrategyName.empty()) {
+ if ((hasTransformLibrary && !hasTransformConfig) || hasTransformStrategy) {
StringRef strategyName =
- (clCodegenTransformDialectStrategyName.empty())
+ (!hasTransformStrategy)
? StringRef(
transform::TransformDialect::kTransformEntryPointSymbolName)
: clCodegenTransformDialectStrategyName;
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index 95fe68e..c3c1d3a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -259,7 +259,8 @@
/// Create an IREE-specific Transform dialect interpreter pass with all
/// registrations necessary for IREE.
-std::unique_ptr<Pass> createTransformDialectInterpreterPass();
+std::unique_ptr<Pass>
+createTransformDialectInterpreterPass(StringRef transformSequenceName = "");
/// Pass to propagate type to avoid generating load/stores of illegal types.
std::unique_ptr<OperationPass<func::FuncOp>> createTypePropagationPass();
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
index 4852485..fb87d06 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
@@ -76,9 +76,19 @@
extern llvm::cl::opt<std::string> clCodegenTransformDialectLibraryFileName;
/// Create a Transform dialect interpreter pass.
-std::unique_ptr<Pass> createTransformDialectInterpreterPass() {
- return std::make_unique<TransformDialectInterpreterPass>(
- clCodegenTransformDialectLibraryFileName,
- clCodegenTransformDialectStrategyName);
+std::unique_ptr<Pass>
+createTransformDialectInterpreterPass(StringRef transformSequenceName) {
+ StringRef strategyName = transformSequenceName.empty()
+ ? clCodegenTransformDialectStrategyName
+ : transformSequenceName;
+ StringRef libraryPath = "";
+ SmallVector<StringRef, 2> parts;
+ llvm::SplitString(llvm::StringRef(clCodegenTransformDialectLibraryFileName),
+ parts, "::");
+ if (!parts.empty()) {
+ libraryPath = parts[0];
+ }
+ return std::make_unique<TransformDialectInterpreterPass>(libraryPath,
+ strategyName);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index aa0cd8a..091ca37 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -203,9 +203,12 @@
break;
}
// Transform-dialect pipelines.
- case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen:
- addTransformDialectPasses(pipeline);
+ case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen: {
+ SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
+ addTransformDialectPasses(
+ pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
break;
+ }
default:
moduleOp.emitOpError("Unsupported pipeline on CPU target.");
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 4f876ba..b5bb3c8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -560,10 +560,11 @@
addCPUBufferizePasses(nestedModulePM);
}
-void addTransformDialectPasses(OpPassManager &passManager) {
+void addTransformDialectPasses(OpPassManager &passManager,
+ StringRef entryPoint) {
// Give control to the transform dialect.
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass());
+ mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
// schedule once applied.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index 74d688d..c964342 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -145,7 +145,8 @@
bool lowerToVectors = true);
/// Transform dialect-based common.
-void addTransformDialectPasses(OpPassManager &passManager);
+void addTransformDialectPasses(OpPassManager &passManager,
+ StringRef entryPoint);
// Populates the passes needed to do tiling, decomposing, and vectorizing the
// convolution ops.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index cef1407..4eb31ca 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -108,9 +108,12 @@
addGPUPackUnPackPasses(pipeline);
break;
// Transform-dialect pipelines.
- case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen:
- addGPUTransformDialectPasses(pipeline);
+ case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen: {
+ SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
+ addGPUTransformDialectPasses(
+ pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
break;
+ }
// no pipeline specified, nothing to do.
case IREE::Codegen::DispatchLoweringPassPipeline::None:
return;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index e529e8a..eddea59 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -597,9 +597,10 @@
extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugPayloadTag;
extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugTransformTag;
-void addGPUTransformDialectPasses(OpPassManager &passManager) {
+void addGPUTransformDialectPasses(OpPassManager &passManager,
+ StringRef entryPoint) {
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass());
+ mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index 9860197..08f83a9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -40,7 +40,7 @@
void addGPUSimpleDistributePassPipeline(OpPassManager &pm);
/// Transform dialect-based path.
-void addGPUTransformDialectPasses(OpPassManager &pm);
+void addGPUTransformDialectPasses(OpPassManager &pm, StringRef entryPoint);
/// Lowering transpose using shared memory.
void addGPUTransposePassPipeline(OpPassManager &pm);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 13f7dcd..a6611e8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -263,9 +263,10 @@
spirvPM.addPass(spirv::createSPIRVUpdateVCEPass());
}
-void addSPIRVTransformDialectPasses(OpPassManager &passManager) {
+void addSPIRVTransformDialectPasses(OpPassManager &passManager,
+ StringRef entryPoint) {
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass());
+ mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
@@ -645,8 +646,9 @@
nestedModulePM.addPass(createCSEPass());
}
-void addSPIRVTransformDialectPassPipeline(OpPassManager &pm) {
- addSPIRVTransformDialectPasses(pm);
+void addSPIRVTransformDialectPassPipeline(OpPassManager &pm,
+ StringRef entryPoint) {
+ addSPIRVTransformDialectPasses(pm, entryPoint);
// Run GenericVectorization pass additionally to convert vectors into forms
// needed for SPIR-V.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index dc4a22c..84cd3c7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -44,7 +44,8 @@
void addSPIRVSubgroupReducePassPipeline(OpPassManager &pm);
/// Pass pipeline to lower IREE HAL executables via transform dialect schedules.
-void addSPIRVTransformDialectPassPipeline(OpPassManager &pm);
+void addSPIRVTransformDialectPassPipeline(OpPassManager &pm,
+ StringRef entryPoint);
/// Pass pipeline to lower winograd ops. This pipeline follows the
/// SPIRVBaseVectorize pipeline with the following exception:
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 3943e99..14ec33f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -95,9 +95,12 @@
case CodeGenPipeline::SPIRVWinogradVectorize:
addSPIRVWinogradVectorizePassPipeline(pipeline);
break;
- case CodeGenPipeline::TransformDialectCodegen:
- addSPIRVTransformDialectPassPipeline(pipeline);
+ case CodeGenPipeline::TransformDialectCodegen: {
+ SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
+ addSPIRVTransformDialectPassPipeline(
+ pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
break;
+ }
// No pipeline specified, nothing to do.
case CodeGenPipeline::None:
return;
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index c5eab46..95b80af 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -107,28 +107,29 @@
}
/// We test first with threading off so that the printers are legible.
-// R-UN: iree-compile %s --iree-hal-target-backends=vulkan \
-// R-UN: --iree-codegen-use-transform-dialect-strategy=transform_main \
-// R-UN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
-// R-UN: --compile-from=executable-sources \
-// R-UN: --compile-to=executable-targets \
-// R-UN: --mlir-disable-threading | \
-// R-UN: FileCheck %s --check-prefixes=CODEGEN-PRINTER
+// RUN: iree-compile %s --iree-hal-target-backends=vulkan \
+// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir::kernel_config \
+// RUN: --compile-from=executable-sources \
+// RUN: --compile-to=executable-targets \
+// RUN: --mlir-disable-threading | \
+// RUN: FileCheck %s --check-prefixes=CODEGEN-PRINTER
-// CODEGEN-PRINTER: IR printer: Setting matmul strategy to default top-level
-// CODEGEN-PRINTER: translation_info = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @transform_main
+// CODEGEN-PRINTER: IR printer: Setting matmul strategy to custom_transform_strategy
+// CODEGEN-PRINTER: translation_info = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @custom_transform_strategy>
// CODEGEN-PRINTER: IR printer: Setting reduce strategy to base vectorize top-level
// CODEGEN-PRINTER: translation_info = #iree_codegen.translation_info<SPIRVBaseVectorize>, workgroup_size = [16 : index, 1 : index, 1 : index]
/// Then test with threading to make sure it runs
// RUN: iree-compile %s --iree-hal-target-backends=vulkan \
-// RUN: --iree-codegen-use-transform-dialect-strategy=@transform_main \
-// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
+// RUN: --iree-codegen-transform-dialect-library=%p/transform_library.mlir::kernel_config \
// RUN: --compile-from=executable-sources \
// RUN: --compile-to=executable-targets \
// RUN: --mlir-disable-threading | \
// RUN: FileCheck %s --check-prefixes=CODEGEN
+// CODEGEN: Ran custom_transform_strategy
// CODEGEN: spirv.func @example_module_dispatch_0_generic_80_f32
-// CODEGEN: spirv.func @example_module_dispatch_1_matmul_16x16x5_f32
+// CODEGEN: hal.executable private @example_module_dispatch_1
+// CODEGEN: #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @custom_transform_strategy>
+// CODEGEN: spirv.func @example_module_dispatch_1_matmul_16x16x5_f32
// CODEGEN: spirv.func @example_module_dispatch_2_generic_16x16_f32
diff --git a/samples/transform_dialect/transform_library.mlir b/samples/transform_dialect/transform_library.mlir
index 3bb75ad..d4ca3b4 100644
--- a/samples/transform_dialect/transform_library.mlir
+++ b/samples/transform_dialect/transform_library.mlir
@@ -1,13 +1,76 @@
module attributes { transform.with_named_sequence } {
- // Print and send it down normal IREE codegen.
- transform.named_sequence @custom_matmul(%matmul: !transform.any_op {transform.consumed}) {
- %1 = transform.structured.generalize %matmul : (!transform.any_op) -> !transform.any_op
- transform.print {name = "Setting matmul strategy to default"}
+ // Example of a custom matmul strategy. The target matmul is annotated with
+ // the name of this strategy down below before strategy selection, overriding
+ // default IREE codegen.
+ transform.named_sequence @custom_transform_strategy(
+ %variant_op: !transform.any_op {transform.consumed}) {
+ // Step 1. Re-match the matmul
+ // ===========================================================================
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+
+ // Step 2. Tile to grid
+ // ===========================================================================
+ %grid_reduction, %forall_grid =
+ transform.structured.tile_using_forall %matmul tile_sizes [16, 16] ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> ()
+
+ // Step 3. Vectorize
+ // ===========================================================================
+ %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
+ transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
+ transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+ } : !transform.any_op
+ %func_1 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> !transform.any_op
+
+ // Step 4. Bufferize
+ // ===========================================================================
+ transform.apply_patterns to %func_1 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
+ transform.apply_patterns to %func_1 {
+ transform.apply_patterns.tensor.reassociative_reshape_folding
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+ transform.iree.apply_cse %func_1 : !transform.any_op
+ transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
+ transform.apply_patterns to %func_1 {
+ transform.apply_patterns.linalg.erase_unnecessary_inputs
+ } : !transform.any_op
+ %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
+ %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
+
+ // Step 5. Post-bufferization vector distribution
+ // ===========================================================================
+ %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
+ transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
+ transform.iree.map_nested_forall_to_gpu_threads %func_7
+ workgroup_dims = [4, 8, 1] : (!transform.any_op) -> ()
+
+ // Step 6. Do layout analysis and lower to mma
+ // ===========================================================================
+ %func_10 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
+ %func_11 = transform.iree.layout_analysis_and_distribution %func_10 : (!transform.any_op) -> (!transform.any_op)
+ transform.print {name = "Ran custom_transform_strategy"}
transform.yield
}
- // Send it down subgroup reduce.
- transform.named_sequence @use_subgroup_reduce(%reduce: !transform.any_op {transform.readonly}) {
+ // Send it down a custom transform dialect pipeline.
+ transform.named_sequence @custom_matmul(%matmul: !transform.any_op {transform.readonly}) {
+ %variant_op = transform.get_parent_op %matmul {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op
+ %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %subgroup_reduce = transform.param.constant #iree_codegen.translation_info<TransformDialectCodegen
+ codegen_spec = @custom_transform_strategy> -> !transform.any_param
+ transform.annotate %exports "translation_info" = %subgroup_reduce : !transform.any_op, !transform.any_param
+ transform.print {name = "Setting matmul strategy to custom_transform_strategy"}
+ transform.yield
+ }
+
+ // Send it down subgroup reduce with a custom tiling configuration.
+ transform.named_sequence @use_base_vectorize(%reduce: !transform.any_op {transform.readonly}) {
%variant_op = transform.get_parent_op %reduce {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op
%lowering_config = transform.param.constant #iree_codegen.lowering_config<tile_sizes = [[8, 0], [1, 0], [0, 0, 4]]> -> !transform.any_param
transform.annotate %reduce "lowering_config" = %lowering_config : !transform.any_op, !transform.any_param
@@ -42,10 +105,34 @@
transform.yield %matched : !transform.any_op
}
- transform.named_sequence @transform_main(%variant_op: !transform.any_op {transform.consumed}) {
+ // An example of a custom transform dialect based kernel config. Note that
+ // because of the way `transform.foreach_match` works, the callback cannot
+ // manipulate IR beyond the op *given* to the matcher, as foreach_match will
+ // attempt to keep walking the IR even after a successful match. The expected
+ // flow for a strategy like this is as follows:
+ //
+ // Author an entry point like this (@kernel_config) that walks the IR and
+ // attempts to annotate the dispatch with the codegen strategy to use, i.e.
+ // transform.foreach_match in %variant_op
+ // @matcher_0 -> @annotator_0,
+ // @matcher_1 -> @annotator_1,
+ // ...
+ //
+ // the annotators should attach an #iree_codegen.translation_info attribute
+ // to the `hal.executable.export` ops within the variant as well as any
+ // relevant op specific tile sizes (and other important attributes like
+ // workgroup_size and subgroup_size, if relevant). This will then get handed
+ // off to backend specific kernel config, which will let these user configs
+ // pass through unperturbed.
+ //
+ // To couple this with a transform dialect based codegen strategy, the target
+ // codegen strategy can be included inline with this library and relevant ops
+ // can be annotated with `TransformDialectCodegen` as the lowering pipeline,
+ // with a reference to the strategy to use (see an example above).
+ transform.named_sequence @kernel_config(%variant_op: !transform.any_op {transform.consumed}) {
transform.foreach_match in %variant_op
@match_matmul -> @custom_matmul,
- @match_reduce -> @use_subgroup_reduce
+ @match_reduce -> @use_base_vectorize
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}