[CPU][DT] SVE: adjust tile sizes for scalable pack ops (#22924)
Adjust tile sizes for scalable `linalg.pack` ops.
---------
Signed-off-by: Ege Beysel <beyselege@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index fb0d273..7dbb37b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -2109,6 +2109,13 @@
int srcRank = op.getSourceRank();
SmallVector<int64_t> innerTiles = op.getStaticTiles();
+ // Try to infer scalable tile sizes. This is a no-op in case of static inner
+ // tiles, or if dynamic tile sizes are found but scalable tile sizes cannot
+ // be inferred.
+ if (auto sizesAndScalableFlags =
+ getScalableTileSizesAndFlags(op.getMixedTiles())) {
+ innerTiles = sizesAndScalableFlags->first;
+ }
ArrayRef<int64_t> dimPos = op.getInnerDimsPos();
int64_t vectorSize = getVectorSize(entryPointFn, op.getSourceType());
@@ -3069,20 +3076,34 @@
///
/// Steps:
/// 1. Divide the tile sizes of inner dimensions by the corresponding inner
-/// tile factors (ignores dynamic sizes).
+/// tile factors. Handles static and scalable sizes but ignores dynamic
+/// sizes.
/// 2. Apply the outer dimension permutation, if present.
static void scaleAndPermutateTilingForPackOp(linalg::PackOp packOp,
SmallVector<int64_t> &tileSizes,
SmallVector<bool> &scalableFlags) {
- ArrayRef<int64_t> innerTiles = packOp.getStaticInnerTiles();
+ SmallVector<int64_t> innerTiles(packOp.getStaticInnerTiles());
+ SmallVector<bool> innerTileScalableFlags(innerTiles.size(), false);
+ // Infer scalable tile sizes and flags if present.
+ if (auto sizesAndScalableFlags =
+ getScalableTileSizesAndFlags(packOp.getMixedTiles())) {
+ innerTiles = sizesAndScalableFlags->first;
+ innerTileScalableFlags = sizesAndScalableFlags->second;
+ }
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
// First scale tile sizes by dividing by the inner tile sizes.
- for (auto [pos, size] : llvm::zip_equal(innerDimPos, innerTiles)) {
+ for (auto [pos, size, scalable] :
+ llvm::zip_equal(innerDimPos, innerTiles, innerTileScalableFlags)) {
+ // Ignore non-scalable dynamic sizes.
if (ShapedType::isDynamic(size)) {
continue;
}
tileSizes[pos] /= size;
+ // Account for the division by vscale by setting the scalable flag to false.
+ if (scalable) {
+ scalableFlags[pos] = false;
+ }
}
// Then apply dimension permutation if present.
if (!outerDimsPerm.empty()) {
@@ -3099,12 +3120,20 @@
/// 1. Undo the outer dimension permutation, if present, by applying the
/// inverted permutation.
/// 2. Multiply the inner dimension tile sizes by the corresponding inner
-/// tile factors (ignores dynamic sizes).
+/// tile factors. Handles static and scalable tile sizes but ignores dynamic
+/// sizes.
static void
undoScaleAndPermutateTilingForPackOp(linalg::PackOp packOp,
SmallVector<int64_t> &tileSizes,
SmallVector<bool> &scalableFlags) {
- ArrayRef<int64_t> innerTiles = packOp.getStaticInnerTiles();
+ SmallVector<int64_t> innerTiles(packOp.getStaticInnerTiles());
+ SmallVector<bool> innerTileScalableFlags(innerTiles.size(), false);
+ // Infer scalable tile sizes and flags if present.
+ if (auto sizesAndScalableFlags =
+ getScalableTileSizesAndFlags(packOp.getMixedTiles())) {
+ innerTiles = sizesAndScalableFlags->first;
+ innerTileScalableFlags = sizesAndScalableFlags->second;
+ }
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
// First undo dimension permutation if present.
@@ -3113,12 +3142,16 @@
applyPermutationToVector(tileSizes, invertedPerm);
applyPermutationToVector(scalableFlags, invertedPerm);
}
- // Then unscale tile sizes by multiplying the inner tile sizes.
- for (auto [pos, size] : llvm::zip_equal(innerDimPos, innerTiles)) {
+ // Then unscale tile sizes by multiplying by the inner tile sizes and setting
+ // the corresponding scalable flags to true.
+ for (auto [pos, size, scalable] :
+ llvm::zip_equal(innerDimPos, innerTiles, innerTileScalableFlags)) {
+ // Ignore non-scalable dynamic inner tile sizes.
if (ShapedType::isDynamic(size)) {
continue;
}
tileSizes[pos] *= size;
+ scalableFlags[pos] = scalableFlags[pos] || scalable;
}
}
@@ -3205,7 +3238,7 @@
/// As a result, the Pack op expects its producer (potentially the root op) to
/// use tile sizes `[1, 16]` for those two dimensions, enabling tile-and-fuse
/// optimizations.
- SmallVector<int64_t>
+ SizesAndScalableFlags
getVecTileSizesForNonRootPackOp(mlir::FunctionOpInterface entryPointFn,
linalg::PackOp packOp);
@@ -3334,7 +3367,7 @@
continue;
}
if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
- nonRootOpVecTileSizes[op] =
+ std::tie(nonRootOpVecTileSizes[op], nonRootOpScalableFlags[op]) =
getVecTileSizesForNonRootPackOp(entryPointFn, packOp);
} else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
std::tie(nonRootOpVecTileSizes[op], nonRootOpScalableFlags[op]) =
@@ -3383,10 +3416,13 @@
for (auto &[op, vecTileSize] : nonRootOpVecTileSizes) {
if (isa<linalg::PackOp>(op)) {
// For pack op, align the distribution tile size and overwrite the
- // vector parallel tile size.
+ // vector parallel tile size and scalable flag.
adjust(op, vecTileSize, IREE::CPU::TilingLevel::DistributionTiles, align);
adjust(op, vecTileSize, IREE::CPU::TilingLevel::VectorCommonParallelTiles,
overwrite);
+ adjustScalableFlags(op, nonRootOpScalableFlags.lookup(op),
+ IREE::CPU::TilingLevel::VectorCommonParallelTiles,
+ overwrite);
} else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
// For unpack op, just overwrite the vector parallel tile size and the
// scalable flag. However, dimension tracking is expected be broken in the
@@ -3470,6 +3506,12 @@
// Only set the tile size if it hasn't been assigned yet.
if (tile == 0 && size > 0) {
tile = size;
+ auto it = nonRootOpScalableFlags.find(op);
+ if (it != nonRootOpScalableFlags.end() && pos < it->second.size()) {
+ globalScalableTileFlags
+ [IREE::CPU::TilingLevel::VectorCommonParallelTiles]
+ [globalDimIdx] = it->second[pos];
+ }
}
}
}
@@ -3592,7 +3634,7 @@
}
}
-SmallVector<int64_t>
+SizesAndScalableFlags
MultiLoweringConfigGenerator::getVecTileSizesForNonRootPackOp(
mlir::FunctionOpInterface entryPointFn, linalg::PackOp packOp) {
SmallVector<int64_t> vecTileSizes =
@@ -3600,9 +3642,9 @@
SmallVector<bool> scalableFlags(vecTileSizes.size(), false);
// Invert the Pack op's `outer_dims_perm` on `vecTileSizes` and
// `scalableFlags`, then multiply `vecTileSizes` by the Pack op's
- // `inner_tiles`.
+ // `inner_tiles` and set the corresponding `scalableFlags`.
undoScaleAndPermutateTilingForPackOp(packOp, vecTileSizes, scalableFlags);
- return vecTileSizes;
+ return {vecTileSizes, scalableFlags};
}
SizesAndScalableFlags
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
index c304ff2..2fe2788 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
@@ -240,6 +240,172 @@
// -----
#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+v9a,+sve", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34"}>
+#map_pack = affine_map<()[s0] -> (48 ceildiv s0)>
+func.func @pack(%arg0: tensor<20x48xf32>) -> tensor<2x?x16x?xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c16 = arith.constant 16 : index
+ %vscale = vector.vscale
+ %c16_vscale = arith.muli %vscale, %c16 : index
+ %outer1 = affine.apply #map_pack()[%c16_vscale]
+ %empty = tensor.empty(%outer1, %c16_vscale) : tensor<2x?x16x?xf32>
+ %pack = linalg.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, %c16_vscale] into %empty : tensor<20x48xf32> -> tensor<2x?x16x?xf32>
+ return %pack : tensor<2x?x16x?xf32>
+}
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_cpu.lowering_config<distribution = [1, 4], vector_common_parallel = [1, 1]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDataTiling>
+// CHECK: func.func @pack(
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK: linalg.pack
+// CHECK-SAME: lowering_config = #[[CONFIG]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+v9a,+sve", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34"}>
+#map_elem_pack = affine_map<()[s0] -> (384 ceildiv s0)>
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack(%arg0: tensor<128x384xf32>) -> tensor<16x?x8x?xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+ %empty = tensor.empty() : tensor<128x384xf32>
+ %c16 = arith.constant 16 : index
+ %vscale = vector.vscale
+ %c16_vscale = arith.muli %vscale, %c16 : index
+ %outer1 = affine.apply #map_elem_pack()[%c16_vscale]
+ %filled = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<128x384xf32>) outs(%empty : tensor<128x384xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %sum = arith.addf %in, %in : f32
+ linalg.yield %sum : f32
+ } -> tensor<128x384xf32>
+ %dest = tensor.empty(%outer1, %c16_vscale) : tensor<16x?x8x?xf32>
+ %pack = linalg.pack %filled inner_dims_pos = [0, 1] inner_tiles = [8, %c16_vscale] into %dest : tensor<128x384xf32> -> tensor<16x?x8x?xf32>
+ return %pack : tensor<16x?x8x?xf32>
+}
+// CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<distribution = [64, 64], vector_common_parallel = [8, [16]]>
+// CHECK-DAG: #[[CONFIG2:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 1]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
+// CHECK: func.func @elem_pack(
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG1]]
+// CHECK: linalg.pack
+// CHECK-SAME: lowering_config = #[[CONFIG2]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+v9a,+sve", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34"}>
+#map_rb0 = affine_map<(d0, d1) -> (d0, d1)>
+#map_rb1 = affine_map<(d0, d1) -> (d0)>
+#map_rb_outer = affine_map<()[s0] -> (1024 ceildiv s0)>
+func.func @reduction_broadcast_pack(%arg0: tensor<384x1024xf32>, %arg1: tensor<384xf32>) -> tensor<48x?x8x?xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+ %cst = arith.constant 0.0 : f32
+ %c16 = arith.constant 16 : index
+ %vscale = vector.vscale
+ %c16_vscale = arith.muli %vscale, %c16 : index
+ %outer1 = affine.apply #map_rb_outer()[%c16_vscale]
+ %empty0 = tensor.empty() : tensor<384xf32>
+ %empty1 = tensor.empty() : tensor<384x1024xf32>
+ %empty2 = tensor.empty(%outer1, %c16_vscale) : tensor<48x?x8x?xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty0 : tensor<384xf32>) -> tensor<384xf32>
+ %generic0 = linalg.generic {indexing_maps = [#map_rb0, #map_rb1, #map_rb1], iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor<384x1024xf32>, tensor<384xf32>) outs(%fill : tensor<384xf32>) {
+ ^bb0(%in: f32, %in2: f32, %out: f32):
+ %diff = arith.subf %in, %in2 : f32
+ %mul = arith.mulf %diff, %diff : f32
+ %res = arith.addf %out, %mul : f32
+ linalg.yield %res : f32
+ } -> tensor<384xf32>
+ %generic1 = linalg.generic {indexing_maps = [#map_rb1, #map_rb0], iterator_types = ["parallel", "parallel"]} ins(%generic0 : tensor<384xf32>) outs(%empty1 : tensor<384x1024xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<384x1024xf32>
+ %pack = linalg.pack %generic1 inner_dims_pos = [0, 1] inner_tiles = [8, %c16_vscale] into %empty2 : tensor<384x1024xf32> -> tensor<48x?x8x?xf32>
+ return %pack : tensor<48x?x8x?xf32>
+}
+// CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [8]>
+// CHECK-DAG: #[[CONFIG2:.+]] = #iree_cpu.lowering_config<distribution = [32, 0], vector_common_parallel = [8, 0], vector_reduction = [0, 4]>
+// CHECK-DAG: #[[CONFIG3:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [8, 0], vector_inner_parallel = [0, [16]]>
+// CHECK-DAG: #[[CONFIG4:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 0], vector_inner_parallel = [0, 1]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
+// CHECK: func.func @reduction_broadcast_pack(
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK: linalg.fill
+// CHECK-SAME: lowering_config = #[[CONFIG1]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG2]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG3]]
+// CHECK: linalg.pack
+// CHECK-SAME: lowering_config = #[[CONFIG4]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "generic", cpu_features = "+sve", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "aarch64-none-elf", ukernels = false}>
+#map_tp0 = affine_map<(d0, d1) -> (d1, d0)>
+#map_tp1 = affine_map<(d0, d1) -> (d0, d1)>
+#map_tp_outer = affine_map<()[s0] -> (30522 ceildiv s0)>
+func.func @transpose_pack(%arg0: tensor<30522x768xf32>) -> tensor<?x96x?x8xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c16 = arith.constant 16 : index
+ %vscale = vector.vscale
+ %c16_vscale = arith.muli %vscale, %c16 : index
+ %outer_dynamic = affine.apply #map_tp_outer()[%c16_vscale]
+ %empty0 = tensor.empty() : tensor<768x30522xf32>
+ %0 = linalg.generic {indexing_maps = [#map_tp0, #map_tp1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<30522x768xf32>) outs(%empty0 : tensor<768x30522xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<768x30522xf32>
+ %empty1 = tensor.empty(%outer_dynamic, %c16_vscale) : tensor<?x96x?x8xf32>
+ %pack = linalg.pack %0 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%c16_vscale, 8] into %empty1 : tensor<768x30522xf32> -> tensor<?x96x?x8xf32>
+ return %pack : tensor<?x96x?x8xf32>
+}
+// CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<distribution = [64, 64], vector_common_parallel = [8, [16]]>
+// CHECK-DAG: #[[CONFIG2:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 1]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
+// CHECK: func.func @transpose_pack(
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG1]]
+// CHECK: linalg.pack
+// CHECK-SAME: lowering_config = #[[CONFIG2]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "generic", cpu_features = "+sve", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "aarch64-none-elf", ukernels = false}>
+#map_rp0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map_rp1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map_rp_outer = affine_map<()[s0] -> (1024 ceildiv s0)>
+func.func @reduction_pack(%arg0: tensor<384x1024x32xf32>, %arg1: tensor<384x1024xf32>) -> tensor<?x24x16x?xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+ %cst = arith.constant -0.000000e+00 : f32
+ %c16 = arith.constant 16 : index
+ %vscale = vector.vscale
+ %c16_vscale = arith.muli %vscale, %c16 : index
+ %outer0 = affine.apply #map_rp_outer()[%c16_vscale]
+ %empty0 = tensor.empty() : tensor<384x1024xf32>
+ %empty2 = tensor.empty(%outer0, %c16_vscale) : tensor<?x24x16x?xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty0 : tensor<384x1024xf32>) -> tensor<384x1024xf32>
+ %generic = linalg.generic {indexing_maps = [#map_rp0, #map_rp1, #map_rp1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<384x1024x32xf32>, tensor<384x1024xf32>) outs(%fill : tensor<384x1024xf32>) {
+ ^bb0(%in: f32, %in0: f32, %out: f32):
+ %sub = arith.subf %in, %in0 : f32
+ %mul = arith.mulf %sub, %sub : f32
+ %add = arith.addf %out, %mul : f32
+ linalg.yield %add : f32
+ } -> tensor<384x1024xf32>
+ %pack = linalg.pack %generic outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, %c16_vscale] into %empty2 : tensor<384x1024xf32> -> tensor<?x24x16x?xf32>
+ return %pack : tensor<?x24x16x?xf32>
+}
+// CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [16, [16]]>
+// CHECK-DAG: #[[CONFIG2:.+]] = #iree_cpu.lowering_config<distribution = [32, 32, 0], vector_common_parallel = [16, [16], 0], vector_reduction = [0, 0, 4]>
+// CHECK-DAG: #[[CONFIG3:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 1]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
+// CHECK: func.func @reduction_pack(
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK: linalg.fill
+// CHECK-SAME: lowering_config = #[[CONFIG1]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG2]]
+// CHECK: linalg.pack
+// CHECK-SAME: lowering_config = #[[CONFIG3]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+v9a,+sve", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34"}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map2 = affine_map<()[s0] -> (10240 ceildiv s0)>
func.func @mmt4d_generic_unpack_pack(%arg0: tensor<5x4096x16x1xf16>, %arg1: tensor<?x4096x?x1xf16>) -> tensor<5x10240x16x1xf16> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {