convert-linalg-matmul-to-mmt4d improvements (#8192)
- Take target info, not M0/K0/N0 values.
- Share the target-info code with VectorContractCustomKernels.
- Use better pass options (an enum and a list instead of a bag of bools).
- Add an enable_generic_slow bool option controlling whether to do mmt4d
  even in cases where we don't have a fast kernel (for tests) or not
  (for real users, the default).
Also, trim the e2e matmul tests a bit:
- Do not test mmt4d on vmvx for now (it is very slow to run, has not
  proven to find more bugs, and is not a current focus).
- For cpu-feature-specific variants, only generate tests for the data
  types relevant to the variant (aarch64:+dotprod -> i8 only).
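
For reference, the new pass-option syntax replaces the old explicit tile sizes.
An illustrative before/after, mirroring the lit and e2e test updates in this
patch (other targets follow the same arch=/features= pattern):

  Old: --iree-flow-convert-linalg-matmul-to-mmt4d='M0=8 K0=4 N0=8'
  New: --iree-flow-convert-linalg-matmul-to-mmt4d='arch=aarch64 features=+dotprod'
       --iree-flow-convert-linalg-matmul-to-mmt4d='enable_generic_slow'  (tests with no fast kernel)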
diff --git a/build_tools/bazel/iree_trace_runner_test.bzl b/build_tools/bazel/iree_trace_runner_test.bzl
index c0e6b79..a6fb7d0 100644
--- a/build_tools/bazel/iree_trace_runner_test.bzl
+++ b/build_tools/bazel/iree_trace_runner_test.bzl
@@ -207,7 +207,11 @@
opt_tool: Defaulting to iree-opt. Tool used to preprocess the source files
if opt_flags is specified.
opt_flags: If specified, source files are preprocessed with opt_tool with
- these flags.
+ these flags. The special string "#pass_options_variant#" is replaced
+ with the empty string. That may in the future be changed to some
+ automatically determined pass options for each entry in
+ target_cpu_features_variants, as is currently done in the CMake
+ build.
trace_runner: trace-runner program to run.
timeout: timeout for the generated tests.
target_cpu_features_variants: list of target cpu features variants. Currently unimplemented, so each
@@ -221,6 +225,7 @@
fail("Entry %s in target_cpu_features_variants: unimplemented" % target_cpu_features)
tests = []
+ processed_opt_flags = [flag.replace("#pass_options_variant#", "") for flag in opt_flags]
for backend, driver in target_backends_and_drivers:
suite_entry_name = "_".join([name, backend, driver])
iree_single_backend_generated_trace_runner_test(
@@ -233,7 +238,7 @@
compiler_flags = compiler_flags,
runner_args = runner_args,
opt_tool = opt_tool,
- opt_flags = opt_flags,
+ opt_flags = processed_opt_flags,
tags = tags,
timeout = timeout,
**kwargs
diff --git a/build_tools/cmake/iree_bytecode_module.cmake b/build_tools/cmake/iree_bytecode_module.cmake
index f918702..125670c 100644
--- a/build_tools/cmake/iree_bytecode_module.cmake
+++ b/build_tools/cmake/iree_bytecode_module.cmake
@@ -91,6 +91,7 @@
DEPENDS
${_OPT_TOOL_EXECUTABLE}
${_RULE_SRC}
+ VERBATIM
)
else()
# OPT_FLAGS was not specified, so we are not using the OPT_TOOL.
@@ -123,6 +124,7 @@
${_TRANSLATE_TOOL_EXECUTABLE}
${_EMBEDDED_LINKER_TOOL_EXECUTABLE}
${_TRANSLATE_SRC}
+ VERBATIM
)
if(_RULE_TESTONLY)
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index 97cf9fd..9492f42 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -294,25 +294,32 @@
# Helper function parsing a string occurring as an entry in TARGET_CPU_FEATURES_VARIANTS.
#
# This function has 3 output-params: variables that it sets with PARENT_SCOPE:
-# _ENABLED, _TARGET_CPU_FEATURES, _TARGET_CPU_FEATURES_SUFFIX.
+# _ENABLED, _TARGET_CPU_FEATURES, _TARGET_CPU_FEATURES_SUFFIX, _TARGET_PASS_OPTIONS.
#
# "default" is handled specially. _ENABLED is always set to "TRUE" and
-# _TARGET_CPU_FEATURES and _TARGET_CPU_FEATURES_SUFFIX are both set to the
-# empty string.
+# _TARGET_CPU_FEATURES, _TARGET_CPU_FEATURES_SUFFIX and _TARGET_PASS_OPTIONS are set to
+# the empty string.
#
# Other values are parsed as "arch:features", the parsed arch is matched with
# `CMAKE_SYSTEM_PROCESSOR`, `_ENABLED` is set to "TRUE" if and only if they
-# match, and `_TARGET_CPU_FEATURES_SUFFIX` is set to a string based on the
-# features that is appropriate to include in a CMake target or test name. More
-# than one target cpu feature is currently unsupported.
-# aarch64:+dotprod -> _TARGET_CPU_FEATURES="+dotprod", _TARGET_CPU_FEATURES_SUFFIX="_dotprod"
-# default -> _TARGET_CPU_FEATURES="", _TARGET_CPU_FEATURES_SUFFIX="", ENABLED="TRUE"
+# match, `_TARGET_CPU_FEATURES_SUFFIX` is set to a string based on the
+# features that is appropriate to include in a CMake target or test name, and
+# `_TARGET_PASS_OPTIONS` is formatted to be passed as options to certain passes that
+# expect "arch=<arch> features=<+feature1,...>".
+# More than one target cpu feature is currently unsupported.
+#
+# aarch64:+dotprod -> _ENABLED="TRUE" if the target architecture is aarch64,
+# _TARGET_CPU_FEATURES="+dotprod",
+# _TARGET_CPU_FEATURES_SUFFIX="_dotprod",
+# _TARGET_PASS_OPTIONS="arch=aarch64 features=+dotprod"
+# default -> _ENABLED="TRUE" unconditionally, other output strings are "".
function(process_target_cpu_features _INPUT_TARGET_CPU_FEATURES _ENABLED
- _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX)
+ _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX _TARGET_PASS_OPTIONS)
+ set(_TARGET_CPU_FEATURES "" PARENT_SCOPE)
+ set(_TARGET_CPU_FEATURES_SUFFIX "" PARENT_SCOPE)
+ set(_TARGET_PASS_OPTIONS "" PARENT_SCOPE)
if ("${_INPUT_TARGET_CPU_FEATURES}" STREQUAL "default")
set(_ENABLED "TRUE" PARENT_SCOPE)
- set(_TARGET_CPU_FEATURES "" PARENT_SCOPE)
- set(_TARGET_CPU_FEATURES_SUFFIX "" PARENT_SCOPE)
return()
endif()
string(REGEX MATCHALL "[^:]+" _COMPONENTS "${_INPUT_TARGET_CPU_FEATURES}")
@@ -349,8 +356,11 @@
TARGET_CPU_FEATURES should match [a-zA-Z0-9]+ after the initial +. \
Got: ${_TARGET_CPU_FEATURES}.")
endif()
+ # Generate the target cpu features suffix string with underscores ('_')
+ # separating the features.
string(REPLACE "+" "_" _TARGET_CPU_FEATURES_SUFFIX_LOCAL "${_TARGET_CPU_FEATURES}")
set(_TARGET_CPU_FEATURES_SUFFIX "${_TARGET_CPU_FEATURES_SUFFIX_LOCAL}" PARENT_SCOPE)
+ set(_TARGET_PASS_OPTIONS "arch=${_FILTER_ARCH} features=${_TARGET_CPU_FEATURES}" PARENT_SCOPE)
else()
set(_ENABLED "FALSE" PARENT_SCOPE)
endif()
@@ -425,7 +435,8 @@
set(_TARGET_CPU_FEATURES_VARIANTS "default")
endif()
foreach(_TARGET_CPU_FEATURES_LIST_ELEM IN LISTS _TARGET_CPU_FEATURES_VARIANTS)
- process_target_cpu_features("${_TARGET_CPU_FEATURES_LIST_ELEM}" _ENABLED _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX)
+ process_target_cpu_features("${_TARGET_CPU_FEATURES_LIST_ELEM}" _ENABLED _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX _TARGET_PASS_OPTIONS)
+ string(REPLACE "#pass_options_variant#" "${_TARGET_PASS_OPTIONS}" _PROCESSED_OPT_FLAGS "${_RULE_OPT_FLAGS}")
if (NOT _ENABLED)
# The current entry is disabled on the target CPU architecture.
continue()
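
A worked example of the replacement above, using values from this patch's test
changes: given OPT_FLAGS of
  "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#"
the "aarch64:+dotprod" variant produces _TARGET_PASS_OPTIONS of
"arch=aarch64 features=+dotprod", so _PROCESSED_OPT_FLAGS becomes
  "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow arch=aarch64 features=+dotprod"
while the "default" variant replaces the marker with the empty string.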
diff --git a/build_tools/cmake/iree_trace_runner_test.cmake b/build_tools/cmake/iree_trace_runner_test.cmake
index b32fcdf..30f3e2f 100644
--- a/build_tools/cmake/iree_trace_runner_test.cmake
+++ b/build_tools/cmake/iree_trace_runner_test.cmake
@@ -342,7 +342,8 @@
set(_TARGET_CPU_FEATURES_VARIANTS "default")
endif()
foreach(_TARGET_CPU_FEATURES_LIST_ELEM IN LISTS _TARGET_CPU_FEATURES_VARIANTS)
- process_target_cpu_features("${_TARGET_CPU_FEATURES_LIST_ELEM}" _ENABLED _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX)
+ process_target_cpu_features("${_TARGET_CPU_FEATURES_LIST_ELEM}" _ENABLED _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX _TARGET_PASS_OPTIONS)
+ string(REPLACE "#pass_options_variant#" "${_TARGET_PASS_OPTIONS}" _PROCESSED_OPT_FLAGS "${_RULE_OPT_FLAGS}")
if (NOT _ENABLED)
# The current entry is disabled on the target CPU architecture.
continue()
@@ -369,7 +370,7 @@
OPT_TOOL
${_RULE_OPT_TOOL}
OPT_FLAGS
- ${_RULE_OPT_FLAGS}
+ ${_PROCESSED_OPT_FLAGS}
TARGET_CPU_FEATURES
${_TARGET_CPU_FEATURES}
)
diff --git a/iree/compiler/Codegen/BUILD b/iree/compiler/Codegen/BUILD
index f0ce273..ca89cc4 100644
--- a/iree/compiler/Codegen/BUILD
+++ b/iree/compiler/Codegen/BUILD
@@ -36,6 +36,7 @@
":PassesIncGen",
"//iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Utils",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
diff --git a/iree/compiler/Codegen/CMakeLists.txt b/iree/compiler/Codegen/CMakeLists.txt
index ebd9500..7c91cf2 100644
--- a/iree/compiler/Codegen/CMakeLists.txt
+++ b/iree/compiler/Codegen/CMakeLists.txt
@@ -33,6 +33,7 @@
MLIRTransforms
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Dialect::HAL::IR
+ iree::compiler::Utils
PUBLIC
)
diff --git a/iree/compiler/Codegen/LLVMCPU/BUILD b/iree/compiler/Codegen/LLVMCPU/BUILD
index 6af3522..8330a2a 100644
--- a/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -37,7 +37,9 @@
"//iree/compiler/Dialect/Flow/IR:PartitionableLoopsInterface",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
+ "//iree/compiler/Dialect/HAL/Utils",
"//iree/compiler/Dialect/Util/IR",
+ "//iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 5ae1f64..d6c95ba 100644
--- a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -72,7 +72,9 @@
iree::compiler::Dialect::Flow::IR::PartitionableLoopsInterface
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
+ iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Util::IR
+ iree::compiler::Utils
PUBLIC
)
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
index ae0c662..55e5b30 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/Utils/InferCustomKernelsTargetInfoFromParent.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -30,7 +31,7 @@
// A flag to switch between inline asm and intrinsics while we develop these two
// parallel paths.
-static llvm::cl::opt<bool> clUseMmt4dUseIntrinsics(
+static llvm::cl::opt<bool> clMmt4dUseIntrinsics(
"iree-codegen-mmt4d-use-intrinsics",
llvm::cl::desc("Whether to use instrinsics when lowering vector contracts "
"generated from mmt4d matmuls (as opposed to inline asm). "
@@ -357,7 +358,9 @@
// just before the generic vector ops lowerings.
CustomKernelsTargetInfo info;
if (succeeded(InferCustomKernelsTargetInfoFromParent(funcOp, info))) {
- info.intrinsics = clUseMmt4dUseIntrinsics;
+ if (clMmt4dUseIntrinsics) {
+ info.add(CustomKernelTargetFeature::Intrinsics);
+ }
RewritePatternSet patterns(context);
populateVectorContractCustomKernelsPatterns(info, patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
index 0881503..6e5738c 100644
--- a/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Utils/CustomKernelsTargetInfo.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
@@ -23,53 +24,6 @@
namespace mlir {
namespace iree_compiler {
-LogicalResult InferCustomKernelsTargetInfoFromParent(
- FuncOp entryPointFn, CustomKernelsTargetInfo &target_info) {
- // Set the out-value to defaults early so that early returns produce
- // consistent results and so that we can write simpler code below
- // (for loop OR-ing booleans, assuming initial 'false' value).
- target_info = CustomKernelsTargetInfo();
-
- // Try to find the parent ExecutableVariantOp and its relevant attributes.
- auto variantOp =
- entryPointFn->getParentOfType<IREE::HAL::ExecutableVariantOp>();
- if (!variantOp) {
- return failure();
- }
- IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.target();
- if (!targetAttr) {
- return failure();
- }
- auto config = targetAttr.getConfiguration();
- if (!config) {
- return failure();
- }
- auto tripleAttr = config.getAs<StringAttr>("target_triple");
- if (!tripleAttr) {
- return failure();
- }
- auto cpuFeaturesAttr = config.getAs<StringAttr>("cpu_features");
- if (!cpuFeaturesAttr) {
- return failure();
- }
-
- // Set the out-value target_info fields.
- llvm::Triple triple(tripleAttr.getValue());
- llvm::SmallVector<llvm::StringRef> cpuFeatures;
- cpuFeaturesAttr.getValue().split(cpuFeatures, ',');
- switch (triple.getArch()) {
- case llvm::Triple::ArchType::aarch64:
- target_info.aarch64 = true;
- for (auto f : cpuFeatures) {
- target_info.dotprod |= (f == "+dotprod");
- }
- break;
- default:
- break;
- }
- return success();
-}
-
namespace {
// Returns true if `contractionOp` is of the form
@@ -439,7 +393,7 @@
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect, LLVM::LLVMDialect>();
- if (target_info.intrinsics) {
+ if (target_info.has(CustomKernelTargetFeature::Intrinsics)) {
registry.insert<arm_neon::ArmNeonDialect>();
}
}
@@ -447,9 +401,12 @@
if (failed(Pass::initializeOptions(options))) {
return failure();
}
- target_info.aarch64 = aarch64;
- target_info.dotprod = dotprod;
- target_info.intrinsics = intrinsics;
+ if (failed(ParseCustomKernelsTargetInfo(arch, features, target_info))) {
+ return failure();
+ }
+ if (intrinsics) {
+ target_info.add(CustomKernelTargetFeature::Intrinsics);
+ }
return success();
}
void runOnOperation() override {
@@ -471,8 +428,8 @@
void populateVectorContractCustomKernelsPatterns(
const CustomKernelsTargetInfo &target_info, RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- if (target_info.aarch64 && target_info.dotprod) {
- if (target_info.intrinsics) {
+ if (target_info.has(CustomKernelTargetFeature::Aarch64Dotprod)) {
+ if (target_info.has(CustomKernelTargetFeature::Intrinsics)) {
patterns.insert<MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics>(context);
} else {
patterns.insert<MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm>(context);
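
(The 8x4x8 in these kernel names is the M0xK0xN0 tile shape; it matches the
Mmt4DTileParams(8, 4, 8, ...) that ConvertLinalgMatmulToMmt4D.cpp below selects
for i8 x i8 -> i32 on aarch64 with +dotprod.)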
diff --git a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_asm.mlir b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_asm.mlir
index db29930..2689781 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_asm.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_asm.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -iree-llvmcpu-vector-contract-custom-kernels='aarch64 dotprod' %s | FileCheck %s
+// RUN: iree-opt -iree-llvmcpu-vector-contract-custom-kernels='arch=aarch64 features=+dotprod' %s | FileCheck %s
func @mmt_8x4x8_i8i8i32_aarch64_dotprod_inline_asm(
%lhs: vector<8x4xi8>,
diff --git a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_intrinsics.mlir b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_intrinsics.mlir
index 9193ca9..360bcd3 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_intrinsics.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_intrinsics.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -iree-llvmcpu-vector-contract-custom-kernels='aarch64 dotprod intrinsics' %s | FileCheck %s
+// RUN: iree-opt -iree-llvmcpu-vector-contract-custom-kernels='arch=aarch64 features=+dotprod intrinsics' %s | FileCheck %s
// CHECK-LABEL: @vector_i8i8i32matmul(
// CHECK-SAME: %[[LHS:[^:[:space:]]+]]
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 9da0bba..5dd54f1 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Utils/CustomKernelsTargetInfo.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
@@ -182,32 +183,6 @@
// LLVMCPU Codegen specific patterns.
//------------------------------------------------------------------------------
-// Some codegen patterns need to know target CPU information. They can receive
-// such information by means of this struct, which can be populated from either
-// pass options (e.g. in lit tests,
-// -iree-llvmcpu-vector-contract-custom-kernels='aarch64 dotprod')
-// or from global state (see InferCustomKernelsTargetInfoFromGlobals below).
-//
-// It would be interesting to find an opportunity to de-duplicate this with
-// other data structures containing similar information, but a difficulty here
-// is that in the case of lit tests, where we need to populate this from
-// a minimal set of custom boolean options passed to a pass such as
-// -iree-llvmcpu-vector-contract-custom-kernels, we do not have enough
-// information to populate all the other fields of existing, larger data
-// structures. That's the motivation for this custom, minimal struct.
-struct CustomKernelsTargetInfo {
- // Indicates that the target ISA is Aarch64
- bool aarch64 = false;
- // Under aarch64: indicates dot-product extension (SDOT, UDOT)
- bool dotprod = false;
- // Indicates that intrinsics should be used rather than inline asm
- bool intrinsics = false;
-};
-
-// Populate target_info fields from the parent HAL::ExecutableVariantOp.
-LogicalResult InferCustomKernelsTargetInfoFromParent(
- FuncOp entryPointFn, CustomKernelsTargetInfo &target_info);
-
/// Populates `patterns` to convert certain vector.contract ops to special
/// "kernels" written either in SIMD intrinsics or inline assembly.
void populateVectorContractCustomKernelsPatterns(
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 7b90474..f9ebd17 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -171,15 +171,15 @@
let summary = "Enable custom kernels (inline assembly or intrinsics) for some vector.contract ops";
let constructor = "mlir::iree_compiler::createVectorContractCustomKernelsPass()";
let options = [
- Option<"aarch64", "aarch64", "bool",
- /*default=*/"false",
- "Enable aarch64 kernels">,
- Option<"dotprod", "dotprod", "bool",
- /*default=*/"false",
- "Under aarch64, enable kernels that use dotprod instructions">,
+ Option<"arch", "arch", "std::string",
+ /*default=*/"",
+ "Target architecture, e.g. aarch64">,
+ Option<"features", "features", "std::string",
+ /*default=*/"",
+ "Additional CPU feature flags, e.g. +dotprod">,
Option<"intrinsics", "intrinsics", "bool",
/*default=*/"false",
- "Under aarch64, enable kernels that use dotprod instructions">,
+ "Use intrinsics over inline assembly where applicable">,
];
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
index a10e778..08719c5 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
@@ -8,6 +8,8 @@
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Utils/CustomKernelsTargetInfo.h"
+#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
@@ -215,18 +217,34 @@
y.getType().cast<ShapedType>().getDimSize(i);
}
+class Mmt4DTileParams {
+ public:
+ Mmt4DTileParams(int64_t M0, int64_t K0, int64_t N0, std::string comment)
+ : M0(M0), K0(K0), N0(N0), comment(comment) {}
+ std::array<int64_t, 2> lhs() const { return {M0, K0}; }
+ std::array<int64_t, 2> rhs() const { return {K0, N0}; }
+ std::array<int64_t, 2> acc() const { return {M0, N0}; }
+ const std::string &getComment() const { return comment; }
+
+ private:
+ const int64_t M0;
+ const int64_t K0;
+ const int64_t N0;
+ const std::string comment;
+};
+
// Converts linalg.matmul to an equivalent subgraph using linalg.mmt4d.
// Currently, M0, N0, K0 are compile time constants.
// TODO(ataei): Move this pattern to linalg transforms upstream.
class LinalgMatmulOpToLinalgMmt4DOpPattern
: public OpRewritePattern<linalg::MatmulOp> {
public:
- LinalgMatmulOpToLinalgMmt4DOpPattern(MLIRContext *context, int M0, int K0,
- int N0, PatternBenefit benefit = 1)
- : OpRewritePattern<linalg::MatmulOp>(context, benefit),
- M0(M0),
- K0(K0),
- N0(N0) {}
+ LinalgMatmulOpToLinalgMmt4DOpPattern(
+ MLIRContext *context, const CustomKernelsTargetInfo &target_info,
+ bool enable_generic_slow)
+ : OpRewritePattern<linalg::MatmulOp>(context),
+ target_info(target_info),
+ enable_generic_slow(enable_generic_slow) {}
LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
PatternRewriter &rewriter) const override {
@@ -249,23 +267,33 @@
return failure();
}
- Value paddedLhs = pad(loc, rewriter, lhs, {M0, K0});
- Value paddedRhs = pad(loc, rewriter, rhs, {K0, N0});
- Value paddedAcc = pad(loc, rewriter, acc, {M0, N0});
+ const auto &maybe_tile_params = chooseTileParams(lhs, rhs, acc);
+ if (!maybe_tile_params) {
+ // No good tiling is known for the given element types and target, and
+ // the slow generic fallback (for tests only) is not enabled.
+ return failure();
+ }
+ const Mmt4DTileParams &tile_params = maybe_tile_params.getValue();
- Value lhs4D = expandTo4D(loc, rewriter, paddedLhs, {M0, K0});
- Value rhs4D = expandTo4D(loc, rewriter, paddedRhs, {K0, N0});
- Value acc4D = expandTo4D(loc, rewriter, paddedAcc, {M0, N0});
+ Value paddedLhs = pad(loc, rewriter, lhs, tile_params.lhs());
+ Value paddedRhs = pad(loc, rewriter, rhs, tile_params.rhs());
+ Value paddedAcc = pad(loc, rewriter, acc, tile_params.acc());
+
+ Value lhs4D = expandTo4D(loc, rewriter, paddedLhs, tile_params.lhs());
+ Value rhs4D = expandTo4D(loc, rewriter, paddedRhs, tile_params.rhs());
+ Value acc4D = expandTo4D(loc, rewriter, paddedAcc, tile_params.acc());
Value lhs4DT = transpose(loc, rewriter, lhs4D, {0, 2, 1, 3});
Value rhs4DT = transpose(loc, rewriter, rhs4D, {2, 0, 3, 1});
Value acc4DT = transpose(loc, rewriter, acc4D, {0, 2, 1, 3});
- auto mmt4dResult = rewriter.create<linalg::Mmt4DOp>(
+ auto mmt4d = rewriter.create<linalg::Mmt4DOp>(
loc, acc4DT.getType(), ValueRange{lhs4DT, rhs4DT}, ValueRange{acc4DT});
+ mmt4d->setAttr(StringAttr::get(getContext(), "comment"),
+ StringAttr::get(getContext(), tile_params.getComment()));
Value mmt4dResultTransposed =
- transpose(loc, rewriter, mmt4dResult.getResult(0), {0, 2, 1, 3});
+ transpose(loc, rewriter, mmt4d.getResult(0), {0, 2, 1, 3});
Value paddedResult =
collapseTo2D(loc, rewriter, mmt4dResultTransposed,
@@ -278,11 +306,32 @@
}
private:
- const int M0;
- const int K0;
- const int N0;
+ llvm::Optional<Mmt4DTileParams> chooseTileParams(Value lhs, Value rhs,
+ Value acc) const;
+
+ CustomKernelsTargetInfo target_info;
+ bool enable_generic_slow;
};
+llvm::Optional<Mmt4DTileParams>
+LinalgMatmulOpToLinalgMmt4DOpPattern::chooseTileParams(Value lhs, Value rhs,
+ Value acc) const {
+ Type lhsElemType = lhs.getType().cast<ShapedType>().getElementType();
+ Type rhsElemType = rhs.getType().cast<ShapedType>().getElementType();
+ Type accElemType = acc.getType().cast<ShapedType>().getElementType();
+ if (lhsElemType.isSignlessInteger(8) && rhsElemType.isSignlessInteger(8) &&
+ accElemType.isSignlessInteger(32) &&
+ target_info.has(CustomKernelTargetFeature::Aarch64Dotprod)) {
+ return Mmt4DTileParams(8, 4, 8, "i8*i8->i32, aarch64 +dotprod");
+ }
+ if (enable_generic_slow) {
+ return Mmt4DTileParams(8, 2, 4,
+ "generic tiling parameters, as no known kernel was "
+ "matched for this matmul and target");
+ }
+ return llvm::None;
+}
+
/// Canonicalizes [linalg.init_tensor -> linalg.fill -> linalg.generic] ->
/// [linalg.init_tensor -> linalg.fill] where linalg.generic does only copy e.g
/// a transpose.
@@ -336,23 +385,7 @@
LogicalResult initializeOptions(StringRef options) override {
if (failed(Pass::initializeOptions(options))) return failure();
- auto failureWithMessage = [=](const char *msg) {
- llvm::errs() << "illegal options `" << options << "` for pass `"
- << getArgument() << "`: " << msg << "\n";
- return failure();
- };
- if (M0 == mlir::ShapedType::kDynamicSize ||
- N0 == mlir::ShapedType::kDynamicSize ||
- K0 == mlir::ShapedType::kDynamicSize) {
- return failureWithMessage(
- "currently all three values M0,K0,N0 must be "
- "specified as a fixed size value, not 'dynamic', as the heuristic to "
- "choose these values is not yet implemented.");
- }
- if (M0 == 0 || N0 == 0 || K0 == 0) {
- return failureWithMessage("all three values M0,K0,N0 must be nonzero.");
- }
- return success();
+ return ParseCustomKernelsTargetInfo(arch, features, target_info);
}
void runOnOperation() override {
@@ -360,8 +393,8 @@
// Main pattern.
{
RewritePatternSet patterns(&getContext());
- patterns.insert<LinalgMatmulOpToLinalgMmt4DOpPattern>(context, M0, K0,
- N0);
+ patterns.insert<LinalgMatmulOpToLinalgMmt4DOpPattern>(
+ context, target_info, enable_generic_slow);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
@@ -380,6 +413,9 @@
}
}
}
+
+ private:
+ CustomKernelsTargetInfo target_info;
};
} // namespace
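
To make the tiling concrete, here is a worked shape example using the generic
fallback parameters (M0, K0, N0) = (8, 2, 4) on a 24x8 * 8x32 -> 24x32 f32
matmul, matching the matmul_to_mmt4d.mlir test below:

  lhs tensor<24x8xf32>  -> pad+expand to (M/M0, M0, K/K0, K0) -> transpose {0,2,1,3} -> tensor<3x4x8x2xf32>
  rhs tensor<8x32xf32>  -> pad+expand to (K/K0, K0, N/N0, N0) -> transpose {2,0,3,1} -> tensor<8x4x4x2xf32>
  acc tensor<24x32xf32> -> pad+expand to (M/M0, M0, N/N0, N0) -> transpose {0,2,1,3} -> tensor<3x8x8x4xf32>

linalg.mmt4d then yields tensor<3x8x8x4xf32>, which is transposed and collapsed
back to tensor<24x32xf32>.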
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 4fcb2dd..6f6502e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -76,8 +76,8 @@
// subtensor_insert. This allows lowering the operation into a single kernel.
std::unique_ptr<Pass> createPadTensorToSubTensorInsertPass();
-// Pass to convert a linalg.matmul into linalg.mmt4d given M0, N0 and K0 are
-// compile time constants.
+// Pass to convert a linalg.matmul into linalg.mmt4d given some target ISA
+// information currently passed as pass options.
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgMatmulToMmt4DPass();
// Creates a pass to fuse Linalg operations on tensors.
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 2e8dc08..7fd855f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -109,12 +109,15 @@
let summary = "Convert linalg.matmul to linalg.mmt4d";
let constructor = "mlir::iree_compiler::IREE::Flow::createConvertLinalgMatmulToMmt4DPass()";
let options = [
- Option<"M0", "M0", "int", /*default=*/"mlir::ShapedType::kDynamicSize",
- "Specifies an explicit M-axis tile size, overriding the default heuristic.">,
- Option<"K0", "K0", "int", /*default=*/"mlir::ShapedType::kDynamicSize",
- "Specifies an explicit K-axis tile size, overriding the default heuristic.">,
- Option<"N0", "N0", "int", /*default=*/"mlir::ShapedType::kDynamicSize",
- "Specifies an explicit N-axis tile size, overriding the default heuristic.">,
+ Option<"arch", "arch", "std::string",
+ /*default=*/"",
+ "Target architecture, e.g. aarch64">,
+ Option<"features", "features", "std::string",
+ /*default=*/"",
+ "Additional CPU feature flags, e.g. +dotprod">,
+ Option<"enable_generic_slow", "enable_generic_slow", "bool",
+ /*default=*/"false",
+ "For tests only. Use mmt4d even for cases that are not expected to compile to efficient code by using some arbitrary generic tile shape.">,
];
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/matmul_to_mmt4d.mlir b/iree/compiler/Dialect/Flow/Transforms/test/matmul_to_mmt4d.mlir
index cf1d826..22138ee 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/matmul_to_mmt4d.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/matmul_to_mmt4d.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file --iree-flow-convert-linalg-matmul-to-mmt4d='M0=8 K0=2 N0=4' %s | FileCheck %s
+// RUN: iree-opt -split-input-file --iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow %s | FileCheck %s
func @check_mmt4d_f32_static_nopad(%arg0: tensor<24x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<24x32xf32>) -> tensor<24x32xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x8xf32>, tensor<8x32xf32>) outs(%arg2 : tensor<24x32xf32>) -> tensor<24x32xf32>
@@ -38,7 +38,9 @@
// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield %arg3 : f32
// CHECK-NEXT: } -> tensor<3x8x8x4xf32>
-// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[LHS4DT]], %[[RHS4DT]] : tensor<3x4x8x2xf32>, tensor<8x4x4x2xf32>) outs(%[[DST4DT]] : tensor<3x8x8x4xf32>) -> tensor<3x8x8x4xf32>
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: {comment = "generic tiling parameters, as no known kernel was matched for this matmul and target"}
+// CHECK-SAME: ins(%[[LHS4DT]], %[[RHS4DT]] : tensor<3x4x8x2xf32>, tensor<8x4x4x2xf32>) outs(%[[DST4DT]] : tensor<3x8x8x4xf32>) -> tensor<3x8x8x4xf32>
// CHECK: %[[MMT4DT_INIT:.+]] = linalg.init_tensor [3, 8, 8, 4] : tensor<3x8x8x4xf32>
// CHECK: %[[MMT4DT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
diff --git a/iree/compiler/Dialect/HAL/Utils/BUILD b/iree/compiler/Dialect/HAL/Utils/BUILD
index 9ea8bb6..a0febfd 100644
--- a/iree/compiler/Dialect/HAL/Utils/BUILD
+++ b/iree/compiler/Dialect/HAL/Utils/BUILD
@@ -12,11 +12,16 @@
cc_library(
name = "Utils",
+ srcs = [
+ "InferCustomKernelsTargetInfoFromParent.cpp",
+ ],
hdrs = [
"DeviceSwitchBuilder.h",
+ "InferCustomKernelsTargetInfoFromParent.h",
],
deps = [
"//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
diff --git a/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt b/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt
index ddb2b88..643b46b 100644
--- a/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt
@@ -15,6 +15,9 @@
Utils
HDRS
"DeviceSwitchBuilder.h"
+ "InferCustomKernelsTargetInfoFromParent.h"
+ SRCS
+ "InferCustomKernelsTargetInfoFromParent.cpp"
DEPS
LLVMSupport
MLIRIR
@@ -22,6 +25,7 @@
MLIRSupport
MLIRTransforms
iree::compiler::Dialect::HAL::IR
+ iree::compiler::Utils
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Utils/InferCustomKernelsTargetInfoFromParent.cpp b/iree/compiler/Dialect/HAL/Utils/InferCustomKernelsTargetInfoFromParent.cpp
new file mode 100644
index 0000000..accabdd
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Utils/InferCustomKernelsTargetInfoFromParent.cpp
@@ -0,0 +1,54 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Utils/InferCustomKernelsTargetInfoFromParent.h"
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Utils/CustomKernelsTargetInfo.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+LogicalResult InferCustomKernelsTargetInfoFromParent(
+ FuncOp entryPointFn, CustomKernelsTargetInfo &target_info) {
+ // Set the out-value to defaults early so that early returns produce
+ // consistent results and so that we can write simpler code below
+ // (for loop OR-ing booleans, assuming initial 'false' value).
+ target_info = CustomKernelsTargetInfo();
+
+ // Try to find the parent ExecutableVariantOp and its relevant attributes.
+ auto variantOp =
+ entryPointFn->getParentOfType<IREE::HAL::ExecutableVariantOp>();
+ if (!variantOp) {
+ return failure();
+ }
+ IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.target();
+ if (!targetAttr) {
+ return failure();
+ }
+ auto config = targetAttr.getConfiguration();
+ if (!config) {
+ return failure();
+ }
+ auto tripleAttr = config.getAs<StringAttr>("target_triple");
+ if (!tripleAttr) {
+ return failure();
+ }
+ auto cpuFeaturesAttr = config.getAs<StringAttr>("cpu_features");
+ if (!cpuFeaturesAttr) {
+ return failure();
+ }
+
+ // Exactly the implementation of llvm::Triple::getArchName, skipping all the
+ // parsing work of constructing a llvm::Triple from a string.
+ llvm::StringRef archName(tripleAttr.getValue().split('-').first);
+ llvm::StringRef featuresStr(cpuFeaturesAttr.getValue());
+ return ParseCustomKernelsTargetInfo(archName, featuresStr, target_info);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
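
(For illustration: a hypothetical target_triple of "aarch64-none-linux-android29"
splits to archName "aarch64", which ParseCustomKernelsTargetInfo then matches.)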
diff --git a/iree/compiler/Dialect/HAL/Utils/InferCustomKernelsTargetInfoFromParent.h b/iree/compiler/Dialect/HAL/Utils/InferCustomKernelsTargetInfoFromParent.h
new file mode 100644
index 0000000..e7fe181
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Utils/InferCustomKernelsTargetInfoFromParent.h
@@ -0,0 +1,27 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_INFERCUSTOMKERNELSTARGETINFOFROMPARENT_H_
+#define IREE_COMPILER_DIALECT_HAL_UTILS_INFERCUSTOMKERNELSTARGETINFOFROMPARENT_H_
+
+#include <stdint.h>
+
+#include <cassert>
+
+#include "iree/compiler/Utils/CustomKernelsTargetInfo.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+LogicalResult InferCustomKernelsTargetInfoFromParent(
+ FuncOp entryPointFn, CustomKernelsTargetInfo &target_info);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_UTILS_INFERCUSTOMKERNELSTARGETINFOFROMPARENT_H_
diff --git a/iree/compiler/Utils/BUILD b/iree/compiler/Utils/BUILD
index c1cb107..78a9829 100644
--- a/iree/compiler/Utils/BUILD
+++ b/iree/compiler/Utils/BUILD
@@ -16,6 +16,7 @@
name = "Utils",
srcs = [
"ConversionUtils.cpp",
+ "CustomKernelsTargetInfo.cpp",
"FlatbufferUtils.cpp",
"GraphUtils.cpp",
"ModuleUtils.cpp",
@@ -26,6 +27,7 @@
],
hdrs = [
"ConversionUtils.h",
+ "CustomKernelsTargetInfo.h",
"FlatbufferUtils.h",
"GraphUtils.h",
"IndexSet.h",
diff --git a/iree/compiler/Utils/CMakeLists.txt b/iree/compiler/Utils/CMakeLists.txt
index 1bbfa25..eb73468 100644
--- a/iree/compiler/Utils/CMakeLists.txt
+++ b/iree/compiler/Utils/CMakeLists.txt
@@ -15,6 +15,7 @@
Utils
HDRS
"ConversionUtils.h"
+ "CustomKernelsTargetInfo.h"
"FlatbufferUtils.h"
"GraphUtils.h"
"IndexSet.h"
@@ -26,6 +27,7 @@
"TracingUtils.h"
SRCS
"ConversionUtils.cpp"
+ "CustomKernelsTargetInfo.cpp"
"FlatbufferUtils.cpp"
"GraphUtils.cpp"
"ModuleUtils.cpp"
diff --git a/iree/compiler/Utils/CustomKernelsTargetInfo.cpp b/iree/compiler/Utils/CustomKernelsTargetInfo.cpp
new file mode 100644
index 0000000..184fdb8
--- /dev/null
+++ b/iree/compiler/Utils/CustomKernelsTargetInfo.cpp
@@ -0,0 +1,51 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Utils/CustomKernelsTargetInfo.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Triple.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+LogicalResult ParseCustomKernelTargetFeaturesForAarch64(
+ const llvm::SmallVector<llvm::StringRef> &features,
+ CustomKernelsTargetInfo &target_info) {
+ for (auto f : features) {
+ if (f == "+dotprod") {
+ target_info.add(CustomKernelTargetFeature::Aarch64Dotprod);
+ } else {
+ return failure();
+ }
+ }
+ return success();
+}
+
+LogicalResult ParseCustomKernelsTargetInfo(
+ llvm::StringRef archStr, llvm::StringRef featuresStr,
+ CustomKernelsTargetInfo &target_info) {
+ // Set the out-value to defaults early so that early returns produce
+ // consistent results and so that we can write simpler code below.
+ target_info = CustomKernelsTargetInfo();
+
+ if (archStr.empty()) {
+ return success();
+ }
+
+ llvm::SmallVector<llvm::StringRef> features;
+ featuresStr.split(features, ',');
+
+ if (archStr == "aarch64") {
+ target_info.init(CustomKernelTargetArch::Aarch64);
+ return ParseCustomKernelTargetFeaturesForAarch64(features, target_info);
+ }
+
+ return failure();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Utils/CustomKernelsTargetInfo.h b/iree/compiler/Utils/CustomKernelsTargetInfo.h
new file mode 100644
index 0000000..f456b71
--- /dev/null
+++ b/iree/compiler/Utils/CustomKernelsTargetInfo.h
@@ -0,0 +1,85 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_CUSTOMKERNELSTARGETINFO_H_
+#define IREE_COMPILER_UTILS_CUSTOMKERNELSTARGETINFO_H_
+
+#include <stdint.h>
+
+#include <cassert>
+
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Enumerates target ISAs that we care about.
+enum class CustomKernelTargetArch { None, Aarch64 };
+
+// Enumerates arch-specific target features that we care about.
+// We explicitly want to stick to the default enumeration values (0, 1, 2, ...,
+// no greater than 63) because this is going to be indexing a uint64 bitfield.
+// Intentionally not reusing bits across architectures to be able to catch
+// most bugs. 64 is enough across all target architectures for now.
+enum class CustomKernelTargetFeature {
+ // Indicates a preference for intrinsics over inline asm. Unlike other bits,
+ // this is generic, not tied to a particular architecture or CPU feature, and
+ // it has to be passed through some custom boolean flag or option, not as
+ // part of the target CPU features.
+ Intrinsics,
+ // Aarch64 features.
+ Aarch64Dotprod,
+};
+
+inline bool isFeatureForArch(CustomKernelTargetFeature feature,
+ CustomKernelTargetArch arch) {
+ switch (feature) {
+ case CustomKernelTargetFeature::Intrinsics:
+ return true;
+ case CustomKernelTargetFeature::Aarch64Dotprod:
+ return arch == CustomKernelTargetArch::Aarch64;
+ }
+ assert(false && "Unhandled CustomKernelTargetFeature value");
+ return false;
+}
+
+// Class used to pass some target information to patterns/passes that need it.
+// The information could come from pass options, e.g.
+// -iree-llvmcpu-vector-contract-custom-kernels='arch=aarch64
+// features=+dotprod intrinsics'
+// or from a parent HAL::ExecutableVariantOp and/or be complemented by a
+// global flag like clMmt4dUseIntrinsics.
+class CustomKernelsTargetInfo {
+ public:
+ void init(CustomKernelTargetArch a) {
+ assert(arch == CustomKernelTargetArch::None);
+ arch = a;
+ }
+ bool has(CustomKernelTargetFeature f) const {
+ if (!isFeatureForArch(f, arch)) {
+ return false;
+ }
+ return features & (1ull << static_cast<int>(f));
+ }
+ void add(CustomKernelTargetFeature f) {
+ assert(isFeatureForArch(f, arch));
+ features |= (1ull << static_cast<int>(f));
+ }
+
+ private:
+ CustomKernelTargetArch arch = CustomKernelTargetArch::None;
+ // Bitfield, with bits indexed by CustomKernelTargetFeature.
+ uint64_t features = 0;
+};
+
+LogicalResult ParseCustomKernelsTargetInfo(
+ llvm::StringRef archStr, llvm::StringRef featuresStr,
+ CustomKernelsTargetInfo &target_info);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_UTILS_CUSTOMKERNELSTARGETINFO_H_
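
A minimal usage sketch of the API above (hypothetical calling code, not part of
this patch):

  // Parse "arch"/"features" strings, e.g. coming from pass options.
  CustomKernelsTargetInfo info;
  if (failed(ParseCustomKernelsTargetInfo("aarch64", "+dotprod", info))) {
    // Unknown arch or feature string.
  }
  // The Intrinsics bit is arch-independent and is added separately, e.g. from
  // the clMmt4dUseIntrinsics flag.
  info.add(CustomKernelTargetFeature::Intrinsics);
  if (info.has(CustomKernelTargetFeature::Aarch64Dotprod)) {
    // Safe to select an aarch64 dotprod kernel.
  }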
diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD
index c57be78..cc78f00 100644
--- a/iree/test/e2e/regression/BUILD
+++ b/iree/test/e2e/regression/BUILD
@@ -139,16 +139,13 @@
"--shapes=small",
],
opt_flags = [
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=%d N0=8" % (4 if lhs_rhs_type == "i8" else 1),
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#",
],
target_backends_and_drivers = [
("dylib-llvm-aot", "dylib"),
- ("vmvx", "vmvx"),
],
- target_cpu_features_variants = [
- "default",
- "aarch64:+dotprod",
- ],
+ target_cpu_features_variants = ["default"] +
+ (["aarch64:+dotprod"] if lhs_rhs_type == "i8" else []),
trace_runner = "//iree/tools:iree-e2e-matmul-test",
) for lhs_rhs_type in [
"i8",
@@ -163,16 +160,13 @@
"--shapes=large",
],
opt_flags = [
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=%d N0=8" % (4 if lhs_rhs_type == "i8" else 1),
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#",
],
target_backends_and_drivers = [
("dylib-llvm-aot", "dylib"),
- # TODO: enable VMVX. Skipped for now: it's very slow for these large matmul tests.
],
- target_cpu_features_variants = [
- "default",
- "aarch64:+dotprod",
- ],
+ target_cpu_features_variants = ["default"] +
+ (["aarch64:+dotprod"] if lhs_rhs_type == "i8" else []),
trace_runner = "//iree/tools:iree-e2e-matmul-test",
) for lhs_rhs_type in [
"i8",
@@ -180,7 +174,8 @@
]]
# Test intrinsics. No need to run vmvx again, since it isn't affected by this
-# codegen flag.
+# codegen flag. No need to run "large" sizes, since this only differs from other
+# tests in ways that are orthogonal to problem sizes.
[iree_generated_trace_runner_test(
name = "e2e_matmul_mmt4d_%s_intrinsics_%s" % (lhs_rhs_type, size),
compiler_flags = ["--iree-codegen-mmt4d-use-intrinsics"],
@@ -190,20 +185,17 @@
"--shapes=%s" % size,
],
opt_flags = [
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=%d N0=8" % (4 if lhs_rhs_type == "i8" else 1),
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#",
],
target_backends_and_drivers = [
("dylib-llvm-aot", "dylib"),
],
- target_cpu_features_variants = [
- "default",
- "aarch64:+dotprod",
- ],
+ target_cpu_features_variants = ["default"] +
+ (["aarch64:+dotprod"] if lhs_rhs_type == "i8" else []),
trace_runner = "//iree/tools:iree-e2e-matmul-test",
) for lhs_rhs_type in [
"i8",
"f32",
] for size in [
"small",
- "large",
]]
diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt
index d3aedfd..9942708 100644
--- a/iree/test/e2e/regression/CMakeLists.txt
+++ b/iree/test/e2e/regression/CMakeLists.txt
@@ -144,12 +144,10 @@
iree_tools_iree-e2e-matmul-test
TARGET_BACKENDS
"dylib-llvm-aot"
- "vmvx"
DRIVERS
"dylib"
- "vmvx"
OPT_FLAGS
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#"
TARGET_CPU_FEATURES_VARIANTS
"default"
"aarch64:+dotprod"
@@ -167,15 +165,12 @@
iree_tools_iree-e2e-matmul-test
TARGET_BACKENDS
"dylib-llvm-aot"
- "vmvx"
DRIVERS
"dylib"
- "vmvx"
OPT_FLAGS
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#"
TARGET_CPU_FEATURES_VARIANTS
"default"
- "aarch64:+dotprod"
)
iree_generated_trace_runner_test(
@@ -193,7 +188,7 @@
DRIVERS
"dylib"
OPT_FLAGS
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#"
TARGET_CPU_FEATURES_VARIANTS
"default"
"aarch64:+dotprod"
@@ -214,10 +209,9 @@
DRIVERS
"dylib"
OPT_FLAGS
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#"
TARGET_CPU_FEATURES_VARIANTS
"default"
- "aarch64:+dotprod"
)
iree_generated_trace_runner_test(
@@ -237,30 +231,7 @@
COMPILER_FLAGS
"--iree-codegen-mmt4d-use-intrinsics"
OPT_FLAGS
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
- "aarch64:+dotprod"
-)
-
-iree_generated_trace_runner_test(
- NAME
- e2e_matmul_mmt4d_i8_intrinsics_large
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=i8"
- "--shapes=large"
- TRACE_RUNNER
- iree_tools_iree-e2e-matmul-test
- TARGET_BACKENDS
- "dylib-llvm-aot"
- DRIVERS
- "dylib"
- COMPILER_FLAGS
- "--iree-codegen-mmt4d-use-intrinsics"
- OPT_FLAGS
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#"
TARGET_CPU_FEATURES_VARIANTS
"default"
"aarch64:+dotprod"
@@ -283,33 +254,9 @@
COMPILER_FLAGS
"--iree-codegen-mmt4d-use-intrinsics"
OPT_FLAGS
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=enable_generic_slow #pass_options_variant#"
TARGET_CPU_FEATURES_VARIANTS
"default"
- "aarch64:+dotprod"
-)
-
-iree_generated_trace_runner_test(
- NAME
- e2e_matmul_mmt4d_f32_intrinsics_large
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=f32"
- "--shapes=large"
- TRACE_RUNNER
- iree_tools_iree-e2e-matmul-test
- TARGET_BACKENDS
- "dylib-llvm-aot"
- DRIVERS
- "dylib"
- COMPILER_FLAGS
- "--iree-codegen-mmt4d-use-intrinsics"
- OPT_FLAGS
- "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
- "aarch64:+dotprod"
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###