[CPU][DT] SVE: adjust tile sizes for scalable pack ops (#22924)

Adjust tile sizes for `linalg.pack` ops with scalable inner tiles.
Previously, a scalable inner tile (e.g. `16 * vector.vscale`) was treated
as an opaque dynamic size and skipped when scaling tile sizes for pack
ops. Now the static base sizes and scalable flags are inferred from the
op's mixed tiles via `getScalableTileSizesAndFlags`, the scale/unscale
helpers propagate the flags, and `getVecTileSizesForNonRootPackOp`
returns scalable flags alongside the vector tile sizes.
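
For illustration, a minimal pack with one scalable inner tile, condensed
from the new `@elem_pack` lowering-strategy test added below (the function
name and passing `%dest` as an argument are adaptations for brevity):

```mlir
func.func @elem_pack_example(%src: tensor<128x384xf32>,
    %dest: tensor<16x?x8x?xf32>) -> tensor<16x?x8x?xf32> {
  %c16 = arith.constant 16 : index
  %vscale = vector.vscale
  // Scalable inner tile: 16 x vscale.
  %c16_vscale = arith.muli %vscale, %c16 : index
  %pack = linalg.pack %src inner_dims_pos = [0, 1]
      inner_tiles = [8, %c16_vscale]
      into %dest : tensor<128x384xf32> -> tensor<16x?x8x?xf32>
  return %pack : tensor<16x?x8x?xf32>
}
```

As the test checks, the pack's producer now receives
`vector_common_parallel = [8, [16]]` (a scalable `[16]` on the packed
dimension) instead of having the dynamic tile skipped, while the pack op
itself keeps `vector_common_parallel = [1, 1]`.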

---------

Signed-off-by: Ege Beysel <beyselege@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index fb0d273..7dbb37b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -2109,6 +2109,13 @@
 
   int srcRank = op.getSourceRank();
   SmallVector<int64_t> innerTiles = op.getStaticTiles();
+  // Try to infer scalable tile sizes. This is a no-op when the inner tiles are
+  // all static, or when dynamic tile sizes are present but cannot be inferred
+  // as scalable.
+  if (auto sizesAndScalableFlags =
+          getScalableTileSizesAndFlags(op.getMixedTiles())) {
+    innerTiles = sizesAndScalableFlags->first;
+  }
   ArrayRef<int64_t> dimPos = op.getInnerDimsPos();
   int64_t vectorSize = getVectorSize(entryPointFn, op.getSourceType());
 
@@ -3069,20 +3076,34 @@
 ///
 /// Steps:
 /// 1. Divide the tile sizes of inner dimensions by the corresponding inner
-///    tile factors (ignores dynamic sizes).
+///    tile factors. Handles static and scalable sizes but ignores
+///    non-scalable dynamic sizes.
 /// 2. Apply the outer dimension permutation, if present.
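+///
+/// For example (illustrative): with `inner_dims_pos = [0, 1]` and
+/// `inner_tiles = [8, 16 * vscale]`, tile sizes `[8, 16]` with scalable flags
+/// `[false, true]` become `[1, 1]` with flags `[false, false]`; dividing by a
+/// scalable inner tile also divides out the vscale factor, which is modeled
+/// by clearing the flag.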
 static void scaleAndPermutateTilingForPackOp(linalg::PackOp packOp,
                                              SmallVector<int64_t> &tileSizes,
                                              SmallVector<bool> &scalableFlags) {
-  ArrayRef<int64_t> innerTiles = packOp.getStaticInnerTiles();
+  SmallVector<int64_t> innerTiles(packOp.getStaticInnerTiles());
+  SmallVector<bool> innerTileScalableFlags(innerTiles.size(), false);
+  // Infer scalable tile sizes and flags if present.
+  if (auto sizesAndScalableFlags =
+          getScalableTileSizesAndFlags(packOp.getMixedTiles())) {
+    innerTiles = sizesAndScalableFlags->first;
+    innerTileScalableFlags = sizesAndScalableFlags->second;
+  }
   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
   // First scale tile sizes by dividing by the inner tile sizes.
-  for (auto [pos, size] : llvm::zip_equal(innerDimPos, innerTiles)) {
+  for (auto [pos, size, scalable] :
+       llvm::zip_equal(innerDimPos, innerTiles, innerTileScalableFlags)) {
+    // Ignore non-scalable dynamic sizes.
     if (ShapedType::isDynamic(size)) {
       continue;
     }
     tileSizes[pos] /= size;
+    // Model the division by vscale by setting the scalable flag to false.
+    if (scalable) {
+      scalableFlags[pos] = false;
+    }
   }
   // Then apply dimension permutation if present.
   if (!outerDimsPerm.empty()) {
@@ -3099,12 +3120,20 @@
 /// 1. Undo the outer dimension permutation, if present, by applying the
 ///    inverted permutation.
 /// 2. Multiply the inner dimension tile sizes by the corresponding inner
-///    tile factors (ignores dynamic sizes).
+///    tile factors. Handles static and scalable tile sizes but ignores
+///    non-scalable dynamic sizes.
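+///
+/// For example (illustrative): with `inner_tiles = [8, 16 * vscale]`, tile
+/// sizes `[1, 1]` become `[8, 16]` with the second scalable flag set, i.e.
+/// `[8, [16]]` in lowering-config notation.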
 static void
 undoScaleAndPermutateTilingForPackOp(linalg::PackOp packOp,
                                      SmallVector<int64_t> &tileSizes,
                                      SmallVector<bool> &scalableFlags) {
-  ArrayRef<int64_t> innerTiles = packOp.getStaticInnerTiles();
+  SmallVector<int64_t> innerTiles(packOp.getStaticInnerTiles());
+  SmallVector<bool> innerTileScalableFlags(innerTiles.size(), false);
+  // Infer scalable tile sizes and flags if present.
+  if (auto sizesAndScalableFlags =
+          getScalableTileSizesAndFlags(packOp.getMixedTiles())) {
+    innerTiles = sizesAndScalableFlags->first;
+    innerTileScalableFlags = sizesAndScalableFlags->second;
+  }
   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
   // First undo dimension permutation if present.
@@ -3113,12 +3142,16 @@
     applyPermutationToVector(tileSizes, invertedPerm);
     applyPermutationToVector(scalableFlags, invertedPerm);
   }
-  // Then unscale tile sizes by multiplying the inner tile sizes.
-  for (auto [pos, size] : llvm::zip_equal(innerDimPos, innerTiles)) {
+  // Then unscale tile sizes by multiplying by the inner tile sizes, setting
+  // the scalable flag wherever the inner tile is scalable.
+  for (auto [pos, size, scalable] :
+       llvm::zip_equal(innerDimPos, innerTiles, innerTileScalableFlags)) {
+    // Ignore non-scalable dynamic inner tile sizes.
     if (ShapedType::isDynamic(size)) {
       continue;
     }
     tileSizes[pos] *= size;
+    scalableFlags[pos] = scalableFlags[pos] || scalable;
   }
 }
 
@@ -3205,7 +3238,7 @@
   /// As a result, the Pack op expects its producer (potentially the root op) to
   /// use tile sizes `[1, 16]` for those two dimensions, enabling tile-and-fuse
   /// optimizations.
-  SmallVector<int64_t>
+  SizesAndScalableFlags
   getVecTileSizesForNonRootPackOp(mlir::FunctionOpInterface entryPointFn,
                                   linalg::PackOp packOp);
 
@@ -3334,7 +3367,7 @@
       continue;
     }
     if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
-      nonRootOpVecTileSizes[op] =
+      std::tie(nonRootOpVecTileSizes[op], nonRootOpScalableFlags[op]) =
           getVecTileSizesForNonRootPackOp(entryPointFn, packOp);
     } else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
       std::tie(nonRootOpVecTileSizes[op], nonRootOpScalableFlags[op]) =
@@ -3383,10 +3416,13 @@
   for (auto &[op, vecTileSize] : nonRootOpVecTileSizes) {
     if (isa<linalg::PackOp>(op)) {
       // For pack op, align the distribution tile size and overwrite the
-      // vector parallel tile size.
+      // vector parallel tile size and scalable flag.
       adjust(op, vecTileSize, IREE::CPU::TilingLevel::DistributionTiles, align);
       adjust(op, vecTileSize, IREE::CPU::TilingLevel::VectorCommonParallelTiles,
              overwrite);
+      adjustScalableFlags(op, nonRootOpScalableFlags.lookup(op),
+                          IREE::CPU::TilingLevel::VectorCommonParallelTiles,
+                          overwrite);
     } else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
       // For unpack op, just overwrite the vector parallel tile size and the
       // scalable flag. However, dimension tracking is expected to be broken in the
@@ -3470,6 +3506,12 @@
       // Only set the tile size if it hasn't been assigned yet.
       if (tile == 0 && size > 0) {
         tile = size;
+        auto it = nonRootOpScalableFlags.find(op);
+        if (it != nonRootOpScalableFlags.end() && pos < it->second.size()) {
+          globalScalableTileFlags
+              [IREE::CPU::TilingLevel::VectorCommonParallelTiles]
+              [globalDimIdx] = it->second[pos];
+        }
       }
     }
   }
@@ -3592,7 +3634,7 @@
   }
 }
 
-SmallVector<int64_t>
+SizesAndScalableFlags
 MultiLoweringConfigGenerator::getVecTileSizesForNonRootPackOp(
     mlir::FunctionOpInterface entryPointFn, linalg::PackOp packOp) {
   SmallVector<int64_t> vecTileSizes =
@@ -3600,9 +3642,9 @@
   SmallVector<bool> scalableFlags(vecTileSizes.size(), false);
   // Invert the Pack op's `outer_dims_perm` on `vecTileSizes` and
   // `scalableFlags`, then multiply `vecTileSizes` by the Pack op's
-  // `inner_tiles`.
+  // `inner_tiles` and set the corresponding `scalableFlags`.
   undoScaleAndPermutateTilingForPackOp(packOp, vecTileSizes, scalableFlags);
-  return vecTileSizes;
+  return {vecTileSizes, scalableFlags};
 }
 
 SizesAndScalableFlags
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
index c304ff2..2fe2788 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
@@ -240,6 +240,172 @@
 // -----
 
 #executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+v9a,+sve", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34"}>
+#map_pack = affine_map<()[s0] -> (48 ceildiv s0)>
+func.func @pack(%arg0: tensor<20x48xf32>) -> tensor<2x?x16x?xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %c16_vscale = arith.muli %vscale, %c16 : index
+  %outer1 = affine.apply #map_pack()[%c16_vscale]
+  %empty = tensor.empty(%outer1, %c16_vscale) : tensor<2x?x16x?xf32>
+  %pack = linalg.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, %c16_vscale] into %empty : tensor<20x48xf32> -> tensor<2x?x16x?xf32>
+  return %pack : tensor<2x?x16x?xf32>
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_cpu.lowering_config<distribution = [1, 4], vector_common_parallel = [1, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDataTiling>
+//      CHECK: func.func @pack(
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK:   linalg.pack
+// CHECK-SAME:       lowering_config = #[[CONFIG]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+v9a,+sve", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34"}>
+#map_elem_pack = affine_map<()[s0] -> (384 ceildiv s0)>
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack(%arg0: tensor<128x384xf32>) -> tensor<16x?x8x?xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+  %empty = tensor.empty() : tensor<128x384xf32>
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %c16_vscale = arith.muli %vscale, %c16 : index
+  %outer1 = affine.apply #map_elem_pack()[%c16_vscale]
+  %filled = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<128x384xf32>) outs(%empty : tensor<128x384xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %sum = arith.addf %in, %in : f32
+    linalg.yield %sum : f32
+  } -> tensor<128x384xf32>
+  %dest = tensor.empty(%outer1, %c16_vscale) : tensor<16x?x8x?xf32>
+  %pack = linalg.pack %filled inner_dims_pos = [0, 1] inner_tiles = [8, %c16_vscale] into %dest : tensor<128x384xf32> -> tensor<16x?x8x?xf32>
+  return %pack : tensor<16x?x8x?xf32>
+}
+//  CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<distribution = [64, 64], vector_common_parallel = [8, [16]]>
+//  CHECK-DAG: #[[CONFIG2:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
+//      CHECK: func.func @elem_pack(
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK:   linalg.generic
+// CHECK-SAME:       lowering_config = #[[CONFIG1]]
+//      CHECK:   linalg.pack
+// CHECK-SAME:       lowering_config = #[[CONFIG2]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+v9a,+sve", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34"}>
+#map_rb0 = affine_map<(d0, d1) -> (d0, d1)>
+#map_rb1 = affine_map<(d0, d1) -> (d0)>
+#map_rb_outer = affine_map<()[s0] -> (1024 ceildiv s0)>
+func.func @reduction_broadcast_pack(%arg0: tensor<384x1024xf32>, %arg1: tensor<384xf32>) -> tensor<48x?x8x?xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+  %cst = arith.constant 0.0 : f32
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %c16_vscale = arith.muli %vscale, %c16 : index
+  %outer1 = affine.apply #map_rb_outer()[%c16_vscale]
+  %empty0 = tensor.empty() : tensor<384xf32>
+  %empty1 = tensor.empty() : tensor<384x1024xf32>
+  %empty2 = tensor.empty(%outer1, %c16_vscale) : tensor<48x?x8x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty0 : tensor<384xf32>) -> tensor<384xf32>
+  %generic0 = linalg.generic {indexing_maps = [#map_rb0, #map_rb1, #map_rb1], iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor<384x1024xf32>, tensor<384xf32>) outs(%fill : tensor<384xf32>) {
+  ^bb0(%in: f32, %in2: f32, %out: f32):
+    %diff = arith.subf %in, %in2 : f32
+    %mul = arith.mulf %diff, %diff : f32
+    %res = arith.addf %out, %mul : f32
+    linalg.yield %res : f32
+  } -> tensor<384xf32>
+  %generic1 = linalg.generic {indexing_maps = [#map_rb1, #map_rb0], iterator_types = ["parallel", "parallel"]} ins(%generic0 : tensor<384xf32>) outs(%empty1 : tensor<384x1024xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<384x1024xf32>
+  %pack = linalg.pack %generic1 inner_dims_pos = [0, 1] inner_tiles = [8, %c16_vscale] into %empty2 : tensor<384x1024xf32> -> tensor<48x?x8x?xf32>
+  return %pack : tensor<48x?x8x?xf32>
+}
+//  CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [8]>
+//  CHECK-DAG: #[[CONFIG2:.+]] = #iree_cpu.lowering_config<distribution = [32, 0], vector_common_parallel = [8, 0], vector_reduction = [0, 4]>
+//  CHECK-DAG: #[[CONFIG3:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [8, 0], vector_inner_parallel = [0, [16]]>
+//  CHECK-DAG: #[[CONFIG4:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 0], vector_inner_parallel = [0, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
+//      CHECK: func.func @reduction_broadcast_pack(
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK:   linalg.fill
+// CHECK-SAME:       lowering_config = #[[CONFIG1]]
+//      CHECK:   linalg.generic
+// CHECK-SAME:       lowering_config = #[[CONFIG2]]
+//      CHECK:   linalg.generic
+// CHECK-SAME:       lowering_config = #[[CONFIG3]]
+//      CHECK:   linalg.pack
+// CHECK-SAME:       lowering_config = #[[CONFIG4]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "generic", cpu_features = "+sve", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "aarch64-none-elf", ukernels = false}>
+#map_tp0 = affine_map<(d0, d1) -> (d1, d0)>
+#map_tp1 = affine_map<(d0, d1) -> (d0, d1)>
+#map_tp_outer = affine_map<()[s0] -> (30522 ceildiv s0)>
+func.func @transpose_pack(%arg0: tensor<30522x768xf32>) -> tensor<?x96x?x8xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %c16_vscale = arith.muli %vscale, %c16 : index
+  %outer_dynamic = affine.apply #map_tp_outer()[%c16_vscale]
+  %empty0 = tensor.empty() : tensor<768x30522xf32>
+  %0 = linalg.generic {indexing_maps = [#map_tp0, #map_tp1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<30522x768xf32>) outs(%empty0 : tensor<768x30522xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<768x30522xf32>
+  %empty1 = tensor.empty(%outer_dynamic, %c16_vscale) : tensor<?x96x?x8xf32>
+  %pack = linalg.pack %0 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%c16_vscale, 8] into %empty1 : tensor<768x30522xf32> -> tensor<?x96x?x8xf32>
+  return %pack : tensor<?x96x?x8xf32>
+}
+//  CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<distribution = [64, 64], vector_common_parallel = [8, [16]]>
+//  CHECK-DAG: #[[CONFIG2:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
+//      CHECK: func.func @transpose_pack(
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK:   linalg.generic
+// CHECK-SAME:       lowering_config = #[[CONFIG1]]
+//      CHECK:   linalg.pack
+// CHECK-SAME:       lowering_config = #[[CONFIG2]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "generic", cpu_features = "+sve", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "aarch64-none-elf", ukernels = false}>
+#map_rp0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map_rp1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map_rp_outer = affine_map<()[s0] -> (1024 ceildiv s0)>
+func.func @reduction_pack(%arg0: tensor<384x1024x32xf32>, %arg1: tensor<384x1024xf32>) -> tensor<?x24x16x?xf32> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {
+  %cst = arith.constant -0.000000e+00 : f32
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %c16_vscale = arith.muli %vscale, %c16 : index
+  %outer0 = affine.apply #map_rp_outer()[%c16_vscale]
+  %empty0 = tensor.empty() : tensor<384x1024xf32>
+  %empty2 = tensor.empty(%outer0, %c16_vscale) : tensor<?x24x16x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty0 : tensor<384x1024xf32>) -> tensor<384x1024xf32>
+  %generic = linalg.generic {indexing_maps = [#map_rp0, #map_rp1, #map_rp1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<384x1024x32xf32>, tensor<384x1024xf32>) outs(%fill : tensor<384x1024xf32>) {
+  ^bb0(%in: f32, %in0: f32, %out: f32):
+    %sub = arith.subf %in, %in0 : f32
+    %mul = arith.mulf %sub, %sub : f32
+    %add = arith.addf %out, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<384x1024xf32>
+  %pack = linalg.pack %generic outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, %c16_vscale] into %empty2 : tensor<384x1024xf32> -> tensor<?x24x16x?xf32>
+  return %pack : tensor<?x24x16x?xf32>
+}
+//  CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [16, [16]]>
+//  CHECK-DAG: #[[CONFIG2:.+]] = #iree_cpu.lowering_config<distribution = [32, 32, 0], vector_common_parallel = [16, [16], 0], vector_reduction = [0, 0, 4]>
+//  CHECK-DAG: #[[CONFIG3:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
+//      CHECK: func.func @reduction_pack(
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK:   linalg.fill
+// CHECK-SAME:       lowering_config = #[[CONFIG1]]
+//      CHECK:   linalg.generic
+// CHECK-SAME:       lowering_config = #[[CONFIG2]]
+//      CHECK:   linalg.pack
+// CHECK-SAME:       lowering_config = #[[CONFIG3]]
+
+// -----
+
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+v9a,+sve", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34"}>
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map2 = affine_map<()[s0] -> (10240 ceildiv s0)>
 func.func @mmt4d_generic_unpack_pack(%arg0: tensor<5x4096x16x1xf16>, %arg1: tensor<?x4096x?x1xf16>) -> tensor<5x10240x16x1xf16> attributes {hal.executable.target = #executable_target_system_elf_arm_64_} {