[CPU] Add a specialized pipeline for LinalgExt::AttentionOp. (#16577)
This revision adds a new pipeline for LinalgExt ops. It is an
experimental pipeline and should eventually be merged into
MultiTilingPipeline.
The new pipeline introduces a vector level of tiling for LinalgExt ops,
plus vectorization. Some dimensions of the attention op cannot be tiled
at the moment, so we set all the vector tile sizes to 1; this avoids
huge vectors, since the reduction dimension of the matmuls is not
tiled. Here is a selected IR dump:
https://gist.githubusercontent.com/hanhanW/db4511da681d4932cb81dd68cc98976f/raw/08c3cc42c9d7fb86b769f60dc712fecb9fb10700/dump.mlir
Towards https://github.com/openxla/iree/issues/16421
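As a concrete illustration (taken from the new lit test in this patch), the strategy
selected for a 20x4096x64 f16 attention dispatch is roughly the following pair of
attributes, with the distribution-level tile sizes first and the all-ones vector-level
tile sizes second:

  #iree_codegen.lowering_config<tile_sizes = [[20, 64], [1, 1]]>
  #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>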
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 60fe213..d942e33 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -25,6 +25,8 @@
: I32EnumAttrCase<"CPUBufferOpsTileAndVectorize", 5>;
def CPU_DataTiling
: I32EnumAttrCase<"CPUDataTiling", 6>;
+def CPU_LinalgExtTileAndVectorize
+ : I32EnumAttrCase<"CPULinalgExtTileAndVectorize", 7>;
def LLVMGPU_Default
: I32EnumAttrCase<"LLVMGPUDefault", 100>;
@@ -81,7 +83,7 @@
CPU_Default, CPU_DoubleTilingExpert,
CPU_DoubleTilingPeelingExpert, CPU_ConvTileAndDecomposeExpert,
CPU_Mmt4dTilingExpert, CPU_BufferOpsTileAndVectorize,
- CPU_DataTiling,
+ CPU_DataTiling, CPU_LinalgExtTileAndVectorize,
// LLVMGPU CodeGen pipelines
LLVMGPU_Default, LLVMGPU_BaseLowering, LLVMGPU_SimpleDistribute,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index d2fe474..38bfbcf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1447,6 +1447,21 @@
DispatchLoweringPassPipeline::CPUDataTiling);
}
+static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
+ IREE::LinalgExt::AttentionOp attnOp) {
+ SmallVector<int64_t> distTileSizes = getDefaultDistributedLevelTileSizes(
+ attnOp, DistributionHeuristicConfig{});
+ int64_t iterationDomainRank = attnOp.getIterationDomainRank();
+ // Some dimensions are not tiled. Set the vector tile sizes to 1 to avoid
+ // huge vectors.
+ // TODO: We should be able to tile other dimensions.
+ SmallVector<int64_t> vecTileSizes(iterationDomainRank, 1);
+ TileSizesListType tileSizes = {distTileSizes, vecTileSizes};
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPointFn, attnOp, tileSizes,
+ DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize);
+}
+
/// Sets the lowering configuration for dispatch region for linalg_ext.fft
/// root op.
static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
@@ -2032,8 +2047,9 @@
return setRootConfig(entryPointFn, op, LinalgOpInfo(op),
targetMLTransInfo);
})
- .Case<IREE::LinalgExt::FftOp, tensor::PackOp, tensor::PadOp,
- tensor::UnPackOp, linalg::Mmt4DOp, linalg::BatchMmt4DOp>(
+ .Case<IREE::LinalgExt::AttentionOp, IREE::LinalgExt::FftOp,
+ tensor::PackOp, tensor::PadOp, tensor::UnPackOp, linalg::Mmt4DOp,
+ linalg::BatchMmt4DOp>(
[&](auto op) { return setRootConfig(entryPointFn, op); })
.Case<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNchwFchwOp,
linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index 0231589..7e3c1fa 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -185,6 +185,12 @@
addCPUDataTilingPipeline(pipeline, tilingConfig, enableVectorMasking);
break;
}
+ case IREE::Codegen::DispatchLoweringPassPipeline::
+ CPULinalgExtTileAndVectorize: {
+ TilingConfig tilingConfig = getTilingConfigForPipeline(moduleOp);
+ addCPULinalgExtTileAndVectorizePipeline(pipeline, tilingConfig);
+ break;
+ }
// Transform-dialect pipelines.
case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen: {
SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 777e16b..959ed69 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -100,9 +100,6 @@
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
- // TODO(#16421): Disable decomposition due to failure in bufferization.
- // nestedModulePM.addNestedPass<func::FuncOp>(
- // IREE::LinalgExt::createTileAndDecomposeAttentionPass());
nestedModulePM.addNestedPass<func::FuncOp>(
IREE::LinalgExt::createTileAndDecomposeWinogradTransformPass());
}
@@ -577,6 +574,35 @@
}
}
+void addCPULinalgExtTileAndVectorizePipeline(OpPassManager &passManager,
+ TilingConfig &tilingConfig) {
+ addTileAndDistributePasses(passManager);
+ OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
+ // TODO: Should we only apply the decomposition here?
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ IREE::LinalgExt::createTileAndDecomposeAttentionPass());
+
+ {
+ GenericVectorizationPassOptions options;
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ createGenericVectorizationPass(options));
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ createOptimizeTensorInsertExtractSlicesPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
+ }
+
+ addCPUBufferizePasses(nestedModulePM);
+
+ {
+ LLVMCPUVectorLoweringPassOptions options;
+ options.splitVectorTransfersTo = "linalg-copy";
+ buildLLVMCPUVectorLoweringPipeline(nestedModulePM, options);
+ }
+}
+
void addCPUDefaultPassPipeline(OpPassManager &passManager) {
addTileAndDistributePasses(passManager);
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index d9b7efa..ca2a063 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -144,6 +144,9 @@
TilingConfig &tilingConfig,
bool enableVectorMasking);
+void addCPULinalgExtTileAndVectorizePipeline(OpPassManager &passManager,
+ TilingConfig &tilingConfig);
+
/// Populates the passes to lower to scalars operations for linalg based
/// code-generation. This pipeline does not vectorize, but instead just
/// converts to memrefs
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index d6e3a32..e00542e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -2318,3 +2318,44 @@
// CHECK: tensor.pad {{.+}} {
// CHECK: tensor.yield
// CHECK-NEXT: } {lowering_config = #[[CONFIG]]}
+
+// -----
+
+hal.executable private @attention {
+ hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {
+ cpu = "generic", cpu_features = "",
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 64 : index, target_triple = "x86_64-none-elf"}>) {
+ hal.executable.export public @attention ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>, #hal.interface.binding<0, 3>]} {
+ ^bb0(%arg0: !hal.device):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @attention() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
+ %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
+ %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
+ %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
+ %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
+ %7 = tensor.empty() : tensor<20x4096x64xf16>
+ %8 = iree_linalg_ext.attention
+ ins(%4, %5, %6 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>)
+ outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+ flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
+ return
+ }
+ }
+}
+}
+
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64], [1, 1]]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
+// CHECK: hal.executable.export public @attention
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK: func.func @attention()
+// CHECK: iree_linalg_ext.attention
+// CHECK-SAME: {lowering_config = #[[CONFIG]]}
diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel
index 8f2aae0..9c1b67a 100644
--- a/tests/e2e/linalg_ext_ops/BUILD.bazel
+++ b/tests/e2e/linalg_ext_ops/BUILD.bazel
@@ -26,6 +26,7 @@
],
include = ["*.mlir"],
exclude = [
+ "attention.mlir",
"winograd_input.mlir",
"winograd_output.mlir",
],
@@ -91,6 +92,7 @@
srcs = enforce_glob(
# keep sorted
[
+ "attention.mlir",
"reverse.mlir",
"scan.mlir",
"scatter.mlir",
@@ -125,6 +127,7 @@
],
include = ["*.mlir"],
exclude = [
+ "attention.mlir",
"winograd_input.mlir",
"winograd_output.mlir",
],
@@ -147,6 +150,7 @@
],
include = ["*.mlir"],
exclude = [
+ "attention.mlir",
"reverse.mlir", #TODO(#12415): disabled due to miscompilation on Pixel 6.
# TODO(antiagainst): scan fails on Adreno GPUs due to driver bug.
# Re-enable this once we have new devices with up-to-date drivers.
diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt
index 675ad5e..c83cec3 100644
--- a/tests/e2e/linalg_ext_ops/CMakeLists.txt
+++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt
@@ -76,6 +76,7 @@
NAME
check_llvm-cpu_local-task
SRCS
+ "attention.mlir"
"reverse.mlir"
"scan.mlir"
"scatter.mlir"
diff --git a/tests/e2e/linalg_ext_ops/attention.mlir b/tests/e2e/linalg_ext_ops/attention.mlir
new file mode 100644
index 0000000..cbb2ded
--- /dev/null
+++ b/tests/e2e/linalg_ext_ops/attention.mlir
@@ -0,0 +1,13 @@
+func.func @attention() {
+ %init = tensor.empty() : tensor<1x4x4xf32>
+ %query = util.unfoldable_constant dense<1.0> : tensor<1x4x4xf32>
+ %key = util.unfoldable_constant dense<0.5> : tensor<1x4x4xf32>
+ %value = util.unfoldable_constant dense<2.0> : tensor<1x4x4xf32>
+ %1 = iree_linalg_ext.attention ins(%query, %key, %value : tensor<1x4x4xf32>,
+ tensor<1x4x4xf32>, tensor<1x4x4xf32>) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+ check.expect_almost_eq_const(
+ %1,
+ dense<[[[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]]]> : tensor<1x4x4xf32>
+ ) : tensor<1x4x4xf32>
+ return
+}