[Codegen] Add PipelineAttrInterface and PassPipelineAttr (#23590)
This change adds an attribute interface for representing pass pipelines
and a single basic attribute that uses the string based pass interpreter
to populate a pipeline. The intent of this change is NOT to induce a
refactor of all the pass pipelines, instead it's primarily to make
testing structural pipeline changes with partially lowered inputs much
easier. Today if you want to work on a change that affects later stages
of a pass pipeline but will also require changes to earlier steps, it's
hard to stage those changes since there isn't a convenient way to jump
into the middle of a codegen pass pipeline (unlike the rest of the
compiler which offers distinct stages).
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
index ee77447..d0c06f7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
@@ -115,6 +115,7 @@
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Support",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
index 4870a75..d26c561 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
@@ -66,6 +66,7 @@
MLIRLinalgDialect
MLIRMemRefDialect
MLIRParser
+ MLIRPass
MLIRSCFDialect
MLIRSCFTransforms
MLIRSupport
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index 9488946..441e612 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -19,6 +19,50 @@
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/StorageUniquerSupport.h"
+#include "mlir/Pass/PassRegistry.h"
+
+// Custom parse/print directives for TranslationInfoAttr's pipeline field.
+// These must be defined before the generated .cpp.inc is included because
+// the ODS-generated parse/print methods call them.
+namespace mlir::iree_compiler::IREE::Codegen {
+
+/// Parses either a DispatchLoweringPassPipeline enum keyword (e.g.,
+/// `CPUDefault`) or a generic attribute implementing PipelineAttrInterface
+/// (e.g., `#iree_codegen.pass_pipeline<"canonicalize">`).
+static ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result) {
+ StringRef keyword;
+ SMLoc loc = parser.getCurrentLocation();
+ if (succeeded(parser.parseOptionalKeyword(&keyword))) {
+ std::optional<DispatchLoweringPassPipeline> pipeline =
+ symbolizeDispatchLoweringPassPipeline(keyword);
+ if (!pipeline) {
+ parser.emitError(loc, "unknown pipeline keyword: ") << keyword;
+ return failure();
+ }
+ result =
+ DispatchLoweringPassPipelineAttr::get(parser.getContext(), *pipeline);
+ return success();
+ }
+ Attribute attr;
+ if (parser.parseAttribute(attr)) {
+ return failure();
+ }
+ result = attr;
+ return success();
+}
+
+/// Prints DispatchLoweringPassPipelineAttr as a bare keyword and other
+/// attributes (e.g., PipelineAttrInterface impls) via the generic printer.
+static void printPipelineAttr(AsmPrinter &printer, Attribute pipelineAttr) {
+ if (auto enumAttr =
+ dyn_cast<DispatchLoweringPassPipelineAttr>(pipelineAttr)) {
+ printer << stringifyEnum(enumAttr.getValue());
+ return;
+ }
+ printer.printAttribute(pipelineAttr);
+}
+
+} // namespace mlir::iree_compiler::IREE::Codegen
#define GET_ATTRDEF_CLASSES
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp.inc"
@@ -65,6 +109,28 @@
}
//===----------------------------------------------------------------------===//
+// iree_codegen.pass_pipeline
+//===----------------------------------------------------------------------===//
+
+LogicalResult PassPipelineAttr::buildPipeline(OpPassManager &pm) const {
+ if (failed(parsePassPipeline(getPipeline(), pm))) {
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult
+PassPipelineAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ StringRef pipeline) {
+ OpPassManager pm("builtin.module");
+ if (failed(parsePassPipeline(pipeline, pm))) {
+ return emitError() << "invalid pass pipeline specification: '" << pipeline
+ << "'";
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// iree_codegen.translation_info
//===----------------------------------------------------------------------===//
@@ -72,7 +138,7 @@
MLIRContext *context, DispatchLoweringPassPipeline passPipeline,
SymbolRefAttr codegenSpec, ArrayRef<int64_t> workgroupSize,
std::optional<int64_t> subgroupSize, DictionaryAttr configuration) {
- auto pipelineAttr =
+ Attribute pipelineAttr =
DispatchLoweringPassPipelineAttr::get(context, passPipeline);
return get(context, pipelineAttr, codegenSpec, workgroupSize,
subgroupSize.value_or(int64_t()), configuration);
@@ -82,7 +148,7 @@
MLIRContext *context, DispatchLoweringPassPipeline passPipeline,
ArrayRef<int64_t> workgroupSize, std::optional<int64_t> subgroupSize,
DictionaryAttr configuration) {
- auto pipelineAttr =
+ Attribute pipelineAttr =
DispatchLoweringPassPipelineAttr::get(context, passPipeline);
return get(context, pipelineAttr, /*codegenSpec=*/SymbolRefAttr(),
workgroupSize, subgroupSize.value_or(int64_t()), configuration);
@@ -90,28 +156,38 @@
DispatchLoweringPassPipeline
TranslationInfoAttr::getDispatchLoweringPassPipeline() {
- return getPassPipeline().getValue();
+ if (auto enumAttr =
+ dyn_cast<DispatchLoweringPassPipelineAttr>(getPassPipeline())) {
+ return enumAttr.getValue();
+ }
+ return DispatchLoweringPassPipeline::None;
}
LogicalResult TranslationInfoAttr::verify(
- function_ref<InFlightDiagnostic()> emitError,
- IREE::Codegen::DispatchLoweringPassPipelineAttr passPipeline,
+ function_ref<InFlightDiagnostic()> emitError, Attribute passPipeline,
SymbolRefAttr codegenSpec, ArrayRef<int64_t> workgroupSize,
int64_t subgroupSize, DictionaryAttr configuration) {
if (!passPipeline) {
return emitError() << "missing pass pipeline specification";
}
- auto passPipelineValue = passPipeline.getValue();
- if (passPipelineValue > IREE::Codegen::DispatchLoweringPassPipeline::None) {
- return emitError() << "invalid pass pipeline value : "
- << stringifyEnum(passPipeline.getValue());
- }
- auto tdPassPipeline =
- IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen;
- if (codegenSpec && passPipelineValue != tdPassPipeline) {
+ if (auto enumAttr =
+ dyn_cast<DispatchLoweringPassPipelineAttr>(passPipeline)) {
+ DispatchLoweringPassPipeline passPipelineValue = enumAttr.getValue();
+ if (passPipelineValue > IREE::Codegen::DispatchLoweringPassPipeline::None) {
+ return emitError() << "invalid pass pipeline value : "
+ << stringifyEnum(passPipelineValue);
+ }
+ DispatchLoweringPassPipeline tdPassPipeline =
+ IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen;
+ if (codegenSpec && passPipelineValue != tdPassPipeline) {
+ return emitError()
+ << "transform dialect codegen spec requires pass pipeline : "
+ << stringifyEnum(tdPassPipeline);
+ }
+ } else if (!isa<PipelineAttrInterface>(passPipeline)) {
return emitError()
- << "transform dialect codegen spec requires pass pipeline : "
- << stringifyEnum(tdPassPipeline);
+ << "pass pipeline must be a DispatchLoweringPassPipelineAttr or "
+ "implement PipelineAttrInterface";
}
if (workgroupSize.size() > 3) {
return emitError() << "workgroup size cannot have more than 3 entries";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index ce39809..c43a48a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -251,6 +251,32 @@
//===---------------------------------------------------------------------===//
+// iree_codegen.pass_pipeline
+//===---------------------------------------------------------------------===//
+
+def IREECodegen_PassPipelineAttr :
+ AttrDef<IREECodegen_Dialect, "PassPipeline", [
+ DeclareAttrInterfaceMethods<IREECodegen_PipelineAttrInterface, [
+ "buildPipeline"
+ ]>
+ ]> {
+ let mnemonic = "pass_pipeline";
+ let summary = "An attribute carrying a textual pass pipeline string.";
+ let description = [{
+ Specifies a pass pipeline using MLIR's textual pass pipeline syntax.
+ The pipeline string is parsed and populated into an OpPassManager
+ when `buildPipeline` is called.
+ }];
+ let parameters = (ins
+ StringRefParameter<"The textual pass pipeline specification">:$pipeline
+ );
+ let assemblyFormat = [{
+ `<` $pipeline `>`
+ }];
+ let genVerifyDecl = 1;
+}
+
+//===---------------------------------------------------------------------===//
// iree_codegen.translation_info
//===---------------------------------------------------------------------===//
@@ -269,22 +295,15 @@
dispatch region (like `linalg.matmul`/`linalg.*conv*`), this
attribute gets propagated to the entry point function.
- The fields are
- - `passPipeline` : The pass pipeline to use.
-
- }];
-
- let assemblyFormat = [{
- `<` `pipeline` `=` `` $passPipeline
- (`codegen_spec` `=` $codegenSpec^)?
- (`workgroup_size` `=` `[` $workgroupSize^ `]`)?
- (`subgroup_size` `=` $subgroupSize^)?
- (`,` $configuration^)? `>`
+ The `passPipeline` field can be either:
+ - A `DispatchLoweringPassPipelineAttr` (enum keyword like `CPUDefault`).
+ - Any attribute implementing `PipelineAttrInterface` (e.g.,
+ `#iree_codegen.pass_pipeline<"...">`).
}];
let parameters = (ins
- AttrParameter<"IREE::Codegen::DispatchLoweringPassPipelineAttr",
- "Name of the pipeline to be invoked on the translation unit.">:$passPipeline,
+ AttrParameter<"Attribute",
+ "Pass pipeline specification.">:$passPipeline,
OptionalParameter<"SymbolRefAttr",
"The symbol pointing to the transform dialect codegen spec to be used">:$codegenSpec,
OptionalArrayRefParameter<"int64_t", "The workgroup size to use">:$workgroupSize,
@@ -304,9 +323,20 @@
CArg<"DictionaryAttr", "{}">:$configuration)>
];
let extraClassDeclaration = [{
- // Returns the lowering pass pipeline set.
+ // Returns the lowering pass pipeline enum value. Returns None if the
+ // pipeline is not a DispatchLoweringPassPipelineAttr.
DispatchLoweringPassPipeline getDispatchLoweringPassPipeline();
}];
+
+ let assemblyFormat = [{
+ `<` `pipeline` `=` custom<PipelineAttr>($passPipeline)
+ (`codegen_spec` `=` $codegenSpec^)?
+ (`workgroup_size` `=` `[` $workgroupSize^ `]`)?
+ (`subgroup_size` `=` $subgroupSize^)?
+ (`,` $configuration^)?
+ `>`
+ }];
+
let genVerifyDecl = 1;
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h
index 2b6a162..5c4713a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h
@@ -14,6 +14,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
index 7711c2b..82c7471 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
@@ -752,6 +752,26 @@
];
}
+def IREECodegen_PipelineAttrInterface :
+ AttrInterface<"PipelineAttrInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
+ let description = [{
+ Attribute interface for building a pass pipeline. Implementations populate
+ the provided OpPassManager with the desired pass pipeline.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populates the given pass manager with a pass pipeline.
+ }],
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"buildPipeline",
+ /*args=*/(ins "::mlir::OpPassManager &":$pm)
+ >
+ ];
+}
+
def IREECodegen_TargetInfoAttrInterface :
AttrInterface<"TargetInfoAttrInterface"> {
let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/lowering_config_attr.mlir b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/lowering_config_attr.mlir
index 439bbf2..7658a73 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/lowering_config_attr.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/lowering_config_attr.mlir
@@ -125,3 +125,37 @@
}
}
}
+
+// -----
+
+module {
+ /// Pass pipeline attribute round-trips correctly.
+ func.func @test_pass_pipeline() attributes {
+ translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize">>} {
+ return
+ }
+}
+// CHECK: #translation = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize">>
+
+// -----
+
+module {
+ /// Pass pipeline attribute with workgroup size and subgroup size round-trips.
+ func.func @test_pass_pipeline_with_config() attributes {
+ translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize"> workgroup_size = [64, 1, 1] subgroup_size = 32>} {
+ return
+ }
+}
+// CHECK: #translation = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize"> workgroup_size = [64, 1, 1] subgroup_size = 32>
+
+// -----
+
+module {
+ /// Invalid pass pipeline string should be caught at verify time.
+ func.func @invalid_pass_pipeline() attributes {
+ // expected-error @+1 {{invalid pass pipeline specification: 'not_a_real_pass'}}
+ translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"not_a_real_pass">>
+ } {
+ return
+ }
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index a9c872e..93edf3f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -3976,9 +3976,10 @@
}
}
- auto tInfo = getTranslationInfo(entryPointFn);
- auto pipeline = tInfo.getPassPipeline().getValue();
- auto pipelineConfig = tInfo.getConfiguration();
+ IREE::Codegen::TranslationInfoAttr tInfo = getTranslationInfo(entryPointFn);
+ DispatchLoweringPassPipeline pipeline =
+ tInfo.getDispatchLoweringPassPipeline();
+ DictionaryAttr pipelineConfig = tInfo.getConfiguration();
if (isOptEnabled(entryPointFn, getEnableLoopPeelingStr())) {
// See #16406
LDBG() << "unpack fusion does not work with peeling, falling back to "
@@ -4167,7 +4168,8 @@
// The transform dialect codegen has different logics and codegen flow.
// Ignore the tile sizes adjustment.
- auto pipeline = getTranslationInfo(entryPointFn).getPassPipeline().getValue();
+ DispatchLoweringPassPipeline pipeline =
+ getTranslationInfo(entryPointFn).getDispatchLoweringPassPipeline();
if (pipeline != DispatchLoweringPassPipeline::TransformDialectCodegen) {
if (failed(adjustTileSizesForRootUnPackOp(entryPointFn, rootOperation))) {
return failure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index 1670e6a..5864295 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -120,7 +120,6 @@
return;
}
- auto pipeline = translationInfo.getDispatchLoweringPassPipeline();
LLVMCPUPipelineOptions pipelineOpts;
pipelineOpts.cpuOpts = cpuOptions.getValue();
@@ -170,45 +169,57 @@
LoweringConfigAttrInterface loweringConfig = getRootLoweringConfig(funcOp);
OpPassManager passManager(func::FuncOp::getOperationName());
- switch (pipeline) {
- // No pipeline specified, nothing to do.
- case IREE::Codegen::DispatchLoweringPassPipeline::None:
- return;
- case IREE::Codegen::DispatchLoweringPassPipeline::CPUDefault: {
- addCPUDefaultPassPipeline(passManager, pipelineOpts);
- break;
- }
- case IREE::Codegen::DispatchLoweringPassPipeline::
- CPUBufferOpsTileAndVectorize: {
- addCPUBufferOpsTileAndVectorizePipeline(passManager, pipelineOpts);
- break;
- }
- case IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert: {
- assert(loweringConfig && "expected a valid lowering config");
- addMultiTilingExpertPassPipeline(passManager, loweringConfig, pipelineOpts);
- break;
- }
- case IREE::Codegen::DispatchLoweringPassPipeline::
- CPUConvTileAndDecomposeExpert: {
- addConvTileAndDecomposeExpertPassPipeline(passManager, pipelineOpts);
- break;
- }
- case IREE::Codegen::DispatchLoweringPassPipeline::Mmt4dTilingExpert: {
- addMmt4dTilingExpertPassPipeline(passManager, pipelineOpts);
- break;
- }
- case IREE::Codegen::DispatchLoweringPassPipeline::CPUDataTiling: {
- addCPUDataTilingPipeline(passManager, pipelineOpts);
- break;
- }
- case IREE::Codegen::DispatchLoweringPassPipeline::
- CPULinalgExtTileAndVectorize: {
- addCPULinalgExtTileAndVectorizePipeline(passManager, pipelineOpts);
- break;
- }
- default:
- funcOp.emitOpError("Unsupported pipeline on CPU target.");
- return signalPassFailure();
+
+ // Check for a custom pipeline via PipelineAttrInterface.
+ Attribute pipelineAttr = translationInfo.getPassPipeline();
+ if (auto customPipeline =
+ dyn_cast<IREE::Codegen::PipelineAttrInterface>(pipelineAttr)) {
+ if (failed(customPipeline.buildPipeline(passManager))) {
+ funcOp.emitOpError("failed to build custom pass pipeline");
+ return signalPassFailure();
+ }
+ } else {
+ switch (translationInfo.getDispatchLoweringPassPipeline()) {
+ // No pipeline specified, nothing to do.
+ case IREE::Codegen::DispatchLoweringPassPipeline::None:
+ return;
+ case IREE::Codegen::DispatchLoweringPassPipeline::CPUDefault: {
+ addCPUDefaultPassPipeline(passManager, pipelineOpts);
+ break;
+ }
+ case IREE::Codegen::DispatchLoweringPassPipeline::
+ CPUBufferOpsTileAndVectorize: {
+ addCPUBufferOpsTileAndVectorizePipeline(passManager, pipelineOpts);
+ break;
+ }
+ case IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert: {
+ assert(loweringConfig && "expected a valid lowering config");
+ addMultiTilingExpertPassPipeline(passManager, loweringConfig,
+ pipelineOpts);
+ break;
+ }
+ case IREE::Codegen::DispatchLoweringPassPipeline::
+ CPUConvTileAndDecomposeExpert: {
+ addConvTileAndDecomposeExpertPassPipeline(passManager, pipelineOpts);
+ break;
+ }
+ case IREE::Codegen::DispatchLoweringPassPipeline::Mmt4dTilingExpert: {
+ addMmt4dTilingExpertPassPipeline(passManager, pipelineOpts);
+ break;
+ }
+ case IREE::Codegen::DispatchLoweringPassPipeline::CPUDataTiling: {
+ addCPUDataTilingPipeline(passManager, pipelineOpts);
+ break;
+ }
+ case IREE::Codegen::DispatchLoweringPassPipeline::
+ CPULinalgExtTileAndVectorize: {
+ addCPULinalgExtTileAndVectorizePipeline(passManager, pipelineOpts);
+ break;
+ }
+ default:
+ funcOp.emitOpError("Unsupported pipeline on CPU target.");
+ return signalPassFailure();
+ }
}
if (failed(runPipeline(passManager, funcOp))) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
index 5948935..ccf6397 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
@@ -258,11 +258,18 @@
return signalPassFailure();
}
- auto translationInfo = getTranslationInfo(funcOp);
+ IREE::Codegen::TranslationInfoAttr translationInfo =
+ getTranslationInfo(funcOp);
if (!translationInfo) {
continue;
}
+ // Custom pipelines via PipelineAttrInterface skip enum-based verification.
+ if (isa<IREE::Codegen::PipelineAttrInterface>(
+ translationInfo.getPassPipeline())) {
+ continue;
+ }
+
// Verify the configuration.
LogicalResult verificationStatus = success();
switch (translationInfo.getDispatchLoweringPassPipeline()) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
index 17995c2..8ad5b50 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
@@ -28,6 +28,7 @@
"check_ir_before_llvm_conversion.mlir",
"check_ir_before_llvm_conversion_not_fail_unbound.mlir",
"convert_to_llvm.mlir",
+ "custom_pass_pipeline.mlir",
"emit_vectorization_remarks.mlir",
"expand_f16_op_to_f32.mlir",
"hal_executable_constants.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index fe8cae4..a7fdb8f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -23,6 +23,7 @@
"check_ir_before_llvm_conversion.mlir"
"check_ir_before_llvm_conversion_not_fail_unbound.mlir"
"convert_to_llvm.mlir"
+ "custom_pass_pipeline.mlir"
"emit_vectorization_remarks.mlir"
"expand_f16_op_to_f32.mlir"
"hal_executable_constants.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/custom_pass_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/custom_pass_pipeline.mlir
new file mode 100644
index 0000000..cf1cbcf
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/custom_pass_pipeline.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmcpu-lower-executable-target))' %s | FileCheck %s
+
+// Verify that a custom pass pipeline specified via #iree_codegen.pass_pipeline
+// attribute is executed by the LLVMCPU lower executable target pass.
+
+#executable_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+ cpu_features = "",
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 16 : index,
+ target_triple = "x86_64-none-elf"
+}>
+
+// The arith.addi with zero should be folded away by canonicalize.
+func.func @test_custom_pipeline(%arg0: index) -> index attributes {
+ hal.executable.target = #executable_target,
+ translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize">>
+} {
+ %c0 = arith.constant 0 : index
+ %0 = arith.addi %arg0, %c0 : index
+ return %0 : index
+}
+// CHECK-LABEL: func.func @test_custom_pipeline
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NEXT: return %[[ARG0]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index f09091d..2c7d003 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -87,34 +87,44 @@
IREE::GPU::GPUPipelineOptions pipelineOptions =
IREE::GPU::getPipelineOptions(funcOp, translationInfo);
- switch (translationInfo.getDispatchLoweringPassPipeline()) {
- case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDefault:
- addGPUDefaultPassPipeline(pipeline, pipelineOptions);
- break;
- case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering:
- addGPUBaseLoweringPassPipeline(pipeline);
- break;
- case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute:
- addGPUSimpleDistributePassPipeline(pipeline);
- break;
- case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorize:
- addGPUVectorizationPassPipeline(pipeline);
- break;
- case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWinogradVectorize:
- addGPUWinogradVectorizePassPipeline(pipeline);
- break;
- case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorDistribute:
- addGPUVectorDistributePassPipeline(pipeline, pipelineOptions, forROCDL);
- break;
- case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse:
- addGPUTileAndFusePassPipeline(pipeline, pipelineOptions, forROCDL);
- break;
- // no pipeline specified, nothing to do.
- case IREE::Codegen::DispatchLoweringPassPipeline::None:
- return;
- default:
- funcOp.emitOpError("unsupported pipeline on GPU target.");
- return signalPassFailure();
+ // Check for a custom pipeline via PipelineAttrInterface.
+ Attribute pipelineAttr = translationInfo.getPassPipeline();
+ if (auto customPipeline =
+ dyn_cast<IREE::Codegen::PipelineAttrInterface>(pipelineAttr)) {
+ if (failed(customPipeline.buildPipeline(pipeline))) {
+ funcOp.emitOpError("failed to build custom pass pipeline");
+ return signalPassFailure();
+ }
+ } else {
+ switch (translationInfo.getDispatchLoweringPassPipeline()) {
+ case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDefault:
+ addGPUDefaultPassPipeline(pipeline, pipelineOptions);
+ break;
+ case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering:
+ addGPUBaseLoweringPassPipeline(pipeline);
+ break;
+ case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute:
+ addGPUSimpleDistributePassPipeline(pipeline);
+ break;
+ case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorize:
+ addGPUVectorizationPassPipeline(pipeline);
+ break;
+ case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWinogradVectorize:
+ addGPUWinogradVectorizePassPipeline(pipeline);
+ break;
+ case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorDistribute:
+ addGPUVectorDistributePassPipeline(pipeline, pipelineOptions, forROCDL);
+ break;
+ case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse:
+ addGPUTileAndFusePassPipeline(pipeline, pipelineOptions, forROCDL);
+ break;
+ // No pipeline specified, nothing to do.
+ case IREE::Codegen::DispatchLoweringPassPipeline::None:
+ return;
+ default:
+ funcOp.emitOpError("unsupported pipeline on GPU target.");
+ return signalPassFailure();
+ }
}
if (failed(runPipeline(pipeline, funcOp))) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp
index 804f4d8..1a6d6d9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "mlir/Pass/Pass.h"
@@ -88,6 +89,12 @@
return;
}
+ // Custom pipelines via PipelineAttrInterface skip enum-based verification.
+ if (isa<IREE::Codegen::PipelineAttrInterface>(
+ translationInfo.getPassPipeline())) {
+ continue;
+ }
+
// Verify the properties of each entry point based on the target pipeline.
if (failed(verifyEntryPoint(funcOp, translationInfo))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 571162b..a4ecee8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -37,6 +37,7 @@
"convert_to_rocdl_gfx950.mlir",
"create_async_groups.mlir",
"create_tile_sizes.mlir",
+ "custom_pass_pipeline.mlir",
"distribute_to_thread.mlir",
"elementwise_pipeline.mlir",
"extract_address_computation_gpu.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index d91bc80..91eb73e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -32,6 +32,7 @@
"convert_to_rocdl_gfx950.mlir"
"create_async_groups.mlir"
"create_tile_sizes.mlir"
+ "custom_pass_pipeline.mlir"
"distribute_to_thread.mlir"
"elementwise_pipeline.mlir"
"extract_address_computation_gpu.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/custom_pass_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/custom_pass_pipeline.mlir
new file mode 100644
index 0000000..385acee
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/custom_pass_pipeline.mlir
@@ -0,0 +1,19 @@
+// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmgpu-lower-executable-target))' %s | FileCheck %s
+
+// Verify that a custom pass pipeline specified via #iree_codegen.pass_pipeline
+// attribute is executed by the LLVMGPU lower executable target pass.
+
+#executable_target = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+
+// The arith.addi with zero should be folded away by canonicalize.
+func.func @test_custom_pipeline(%arg0: index) -> index attributes {
+ hal.executable.target = #executable_target,
+ translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize">>
+} {
+ %c0 = arith.constant 0 : index
+ %0 = arith.addi %arg0, %c0 : index
+ return %0 : index
+}
+// CHECK-LABEL: func.func @test_custom_pipeline
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NEXT: return %[[ARG0]]
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 969cc28..4c040c4 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -13,7 +13,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -80,63 +80,70 @@
}
OpPassManager &pipeline = maybePipeline.value();
- switch (translationInfo.getDispatchLoweringPassPipeline()) {
- case CodeGenPipeline::SPIRVBaseLowering:
- addSPIRVBaseLoweringPassPipeline(pipeline);
- break;
- case CodeGenPipeline::SPIRVBaseDistribute:
- addSPIRVBaseDistributePassPipeline(pipeline);
- break;
- case CodeGenPipeline::SPIRVBaseVectorize:
- addSPIRVBaseVectorizePassPipeline(pipeline);
- break;
- case CodeGenPipeline::SPIRVSubgroupReduce:
- addSPIRVSubgroupReducePassPipeline(pipeline);
- break;
- case CodeGenPipeline::SPIRVCooperativeMatrixVectorize: {
- FailureOr<int64_t> maybeDepth =
- getSoftwarePipelineDepth(translationInfo.getConfiguration());
- FailureOr<int64_t> maybeStage =
- getSoftwarePipelineStoreStage(translationInfo.getConfiguration());
- if (failed(maybeDepth) || failed(maybeStage)) {
- funcOp.emitOpError("invalid cooperative matrix pipeline without "
- "software pipelining configuration.");
+ // Check for a custom pipeline via PipelineAttrInterface.
+ Attribute pipelineAttr = translationInfo.getPassPipeline();
+ if (auto customPipeline =
+ dyn_cast<IREE::Codegen::PipelineAttrInterface>(pipelineAttr)) {
+ if (failed(customPipeline.buildPipeline(pipeline))) {
+ funcOp.emitOpError("failed to build custom pass pipeline");
return signalPassFailure();
}
- addSPIRVCooperativeMatrixVectorizePassPipeline(pipeline, *maybeDepth,
- *maybeStage);
- break;
- }
- case CodeGenPipeline::SPIRVMatmulPromoteVectorize: {
- FailureOr<int64_t> maybeDepth =
- getSoftwarePipelineDepth(translationInfo.getConfiguration());
- FailureOr<int64_t> maybeStage =
- getSoftwarePipelineStoreStage(translationInfo.getConfiguration());
- if (failed(maybeDepth) || failed(maybeStage)) {
- funcOp.emitOpError("invalid matmul pipeline without software "
- "pipelining configuration.");
+ } else {
+ switch (translationInfo.getDispatchLoweringPassPipeline()) {
+ case CodeGenPipeline::SPIRVBaseLowering:
+ addSPIRVBaseLoweringPassPipeline(pipeline);
+ break;
+ case CodeGenPipeline::SPIRVBaseDistribute:
+ addSPIRVBaseDistributePassPipeline(pipeline);
+ break;
+ case CodeGenPipeline::SPIRVBaseVectorize:
+ addSPIRVBaseVectorizePassPipeline(pipeline);
+ break;
+ case CodeGenPipeline::SPIRVSubgroupReduce:
+ addSPIRVSubgroupReducePassPipeline(pipeline);
+ break;
+ case CodeGenPipeline::SPIRVCooperativeMatrixVectorize: {
+ FailureOr<int64_t> maybeDepth =
+ getSoftwarePipelineDepth(translationInfo.getConfiguration());
+ FailureOr<int64_t> maybeStage =
+ getSoftwarePipelineStoreStage(translationInfo.getConfiguration());
+ if (failed(maybeDepth) || failed(maybeStage)) {
+ funcOp.emitOpError("invalid cooperative matrix pipeline without "
+ "software pipelining configuration.");
+ return signalPassFailure();
+ }
+ addSPIRVCooperativeMatrixVectorizePassPipeline(pipeline, *maybeDepth,
+ *maybeStage);
+ break;
+ }
+ case CodeGenPipeline::SPIRVMatmulPromoteVectorize: {
+ FailureOr<int64_t> maybeDepth =
+ getSoftwarePipelineDepth(translationInfo.getConfiguration());
+ FailureOr<int64_t> maybeStage =
+ getSoftwarePipelineStoreStage(translationInfo.getConfiguration());
+ if (failed(maybeDepth) || failed(maybeStage)) {
+ funcOp.emitOpError("invalid matmul pipeline without software "
+ "pipelining configuration.");
+ return signalPassFailure();
+ }
+ addSPIRVMatmulPromoteVectorizePassPipeline(pipeline, *maybeDepth,
+ *maybeStage);
+ break;
+ }
+ case CodeGenPipeline::SPIRVWinogradVectorize:
+ addSPIRVWinogradVectorizePassPipeline(pipeline);
+ break;
+ // No pipeline specified, nothing to do.
+ case CodeGenPipeline::None:
+ return;
+ default:
+ funcOp.emitOpError("unsupported pipeline on GPU target.");
return signalPassFailure();
}
- addSPIRVMatmulPromoteVectorizePassPipeline(pipeline, *maybeDepth,
- *maybeStage);
- break;
- }
- case CodeGenPipeline::SPIRVWinogradVectorize:
- addSPIRVWinogradVectorizePassPipeline(pipeline);
- break;
- // No pipeline specified, nothing to do.
- case CodeGenPipeline::None:
- return;
- default:
- funcOp.emitOpError("unsupported pipeline on GPU target.");
- return signalPassFailure();
}
- LLVM_DEBUG({
- llvm::dbgs() << "Using SPIR-V lowering pass pipeline:\n";
- pipeline.printAsTextualPipeline(llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "Using SPIR-V lowering pass pipeline: ";
+ LLVM_DEBUG(pipeline.dump());
if (failed(runPipeline(pipeline, funcOp))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
index 9028aa7..330ac7f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
@@ -95,11 +96,18 @@
return signalPassFailure();
}
- auto translationInfo = getTranslationInfo(funcOp);
+ IREE::Codegen::TranslationInfoAttr translationInfo =
+ getTranslationInfo(funcOp);
if (!translationInfo) {
continue;
}
+ // Custom pipelines via PipelineAttrInterface skip enum-based verification.
+ if (isa<IREE::Codegen::PipelineAttrInterface>(
+ translationInfo.getPassPipeline())) {
+ continue;
+ }
+
// Verify the properties of each entry point based on the target pipeline.
if (failed(verifyTranslationInfo(funcOp, translationInfo))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index 47d4524..03374b1 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -41,6 +41,7 @@
"config_user.mlir",
"convert_gpu_target.mlir",
"convert_to_spirv.mlir",
+ "custom_pass_pipeline.mlir",
"emulate_i64.mlir",
"erase_storage_buffer_static_shape.mlir",
"illegal_configuration.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 6ce2ecc..bb7a0ef 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -36,6 +36,7 @@
"config_user.mlir"
"convert_gpu_target.mlir"
"convert_to_spirv.mlir"
+ "custom_pass_pipeline.mlir"
"emulate_i64.mlir"
"erase_storage_buffer_static_shape.mlir"
"illegal_configuration.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/custom_pass_pipeline.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/custom_pass_pipeline.mlir
new file mode 100644
index 0000000..b26f43b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/custom_pass_pipeline.mlir
@@ -0,0 +1,19 @@
+// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-spirv-lower-executable-target-pass))' %s | FileCheck %s
+
+// Verify that a custom pass pipeline specified via #iree_codegen.pass_pipeline
+// attribute is executed by the SPIRV lower executable target pass.
+
+#executable_target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb">
+
+// The arith.addi with zero should be folded away by canonicalize.
+func.func @test_custom_pipeline(%arg0: index) -> index attributes {
+ hal.executable.target = #executable_target,
+ translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize">>
+} {
+ %c0 = arith.constant 0 : index
+ %0 = arith.addi %arg0, %c0 : index
+ return %0 : index
+}
+// CHECK-LABEL: func.func @test_custom_pipeline
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NEXT: return %[[ARG0]]
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp
index 84d4091..aa8299c 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp
@@ -12,7 +12,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -30,10 +30,12 @@
namespace {
/// Lowers an hal.executable.variant operation to scalar/native-vector code.
-class VMVXLowerExecutableTargetPass
+class VMVXLowerExecutableTargetPass final
: public impl::VMVXLowerExecutableTargetPassBase<
VMVXLowerExecutableTargetPass> {
public:
+ using Base::Base;
+
void getDependentDialects(DialectRegistry ®istry) const override {
// clang-format off
registry.insert<IREE::HAL::HALDialect,
@@ -53,7 +55,8 @@
void VMVXLowerExecutableTargetPass::runOnOperation() {
mlir::FunctionOpInterface funcOp = getOperation();
- auto translationInfo = getTranslationInfo(funcOp);
+ IREE::Codegen::TranslationInfoAttr translationInfo =
+ getTranslationInfo(funcOp);
if (!translationInfo) {
return;
}
@@ -67,24 +70,32 @@
}
OpPassManager &pipeline = maybePipeline.value();
- auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
- bool enableUKernels = target && hasUkernel(target.getConfiguration());
- switch (translationInfo.getDispatchLoweringPassPipeline()) {
- // No pipeline specified, nothing to do.
- case IREE::Codegen::DispatchLoweringPassPipeline::None:
- return;
- case IREE::Codegen::DispatchLoweringPassPipeline::VMVXDefault:
- addVMVXDefaultPassPipeline(pipeline, enableUKernels);
- break;
- default:
- funcOp.emitOpError("Unsupported pipeline on VMVX target.");
- return signalPassFailure();
+ // Check for a custom pipeline via PipelineAttrInterface.
+ Attribute pipelineAttr = translationInfo.getPassPipeline();
+ if (auto customPipeline =
+ dyn_cast<IREE::Codegen::PipelineAttrInterface>(pipelineAttr)) {
+ if (failed(customPipeline.buildPipeline(pipeline))) {
+ funcOp.emitOpError("failed to build custom pass pipeline");
+ return signalPassFailure();
+ }
+ } else {
+ auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
+ bool enableUKernels = target && hasUkernel(target.getConfiguration());
+ switch (translationInfo.getDispatchLoweringPassPipeline()) {
+ // No pipeline specified, nothing to do.
+ case IREE::Codegen::DispatchLoweringPassPipeline::None:
+ return;
+ case IREE::Codegen::DispatchLoweringPassPipeline::VMVXDefault:
+ addVMVXDefaultPassPipeline(pipeline, enableUKernels);
+ break;
+ default:
+ funcOp.emitOpError("Unsupported pipeline on VMVX target.");
+ return signalPassFailure();
+ }
}
- LLVM_DEBUG({
- llvm::dbgs() << "Using Pass pipeline : ";
- pipeline.dump();
- });
+ LDBG() << "Using pass pipeline: ";
+ LLVM_DEBUG(pipeline.dump());
if (failed(runPipeline(pipeline, funcOp))) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/VMVX/test/BUILD.bazel
index 9990580..a3bfda3 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/BUILD.bazel
@@ -20,6 +20,7 @@
# keep sorted
[
"assign_constant_ordinals.mlir",
+ "custom_pass_pipeline.mlir",
"link_executables.mlir",
"lower_linalg_microkernels.mlir",
"pipeline.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/VMVX/test/CMakeLists.txt
index 0d8c673..49961ce 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"assign_constant_ordinals.mlir"
+ "custom_pass_pipeline.mlir"
"link_executables.mlir"
"lower_linalg_microkernels.mlir"
"pipeline.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/custom_pass_pipeline.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/custom_pass_pipeline.mlir
new file mode 100644
index 0000000..a4e6579
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/custom_pass_pipeline.mlir
@@ -0,0 +1,19 @@
+// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-vmvx-lower-executable-target))' %s | FileCheck %s
+
+// Verify that a custom pass pipeline specified via #iree_codegen.pass_pipeline
+// attribute is executed by the VMVX lower executable target pass.
+
+#executable_target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
+
+// The arith.addi with zero should be folded away by canonicalize.
+func.func @test_custom_pipeline(%arg0: index) -> index attributes {
+ hal.executable.target = #executable_target,
+ translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize">>
+} {
+ %c0 = arith.constant 0 : index
+ %0 = arith.addi %arg0, %c0 : index
+ return %0 : index
+}
+// CHECK-LABEL: func.func @test_custom_pipeline
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NEXT: return %[[ARG0]]