[Codegen] Support dynamic/scalable sizes when folding insert_slice into xfer_write (#17963)

This enables further optimizations which are currently missed when
targeting scalable vectors.
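
For example, after this change a transfer_write of a scalable vector that provably
overwrites its whole destination slice can be folded through the enclosing
tensor.insert_slice. A sketch taken from the new tests added below:

```mlir
// Before: the vector<4x[4]xf32> write fully overwrites the 4x(4*vscale) slice,
// which is now proven via value-bounds analysis even though the slice is dynamic.
%slice = tensor.extract_slice %t1[0, 0] [4, %c4_vscale] [1, 1]
    : tensor<?x?xf32> to tensor<4x?xf32>
%w = vector.transfer_write %v, %slice[%c0, %c0] {in_bounds = [true, true]}
    : vector<4x[4]xf32>, tensor<4x?xf32>
%r = tensor.insert_slice %w into %t2[%a, %b] [4, %c4_vscale] [1, 1]
    : tensor<4x?xf32> into tensor<?x?xf32>

// After: the insert_slice is folded away and the write targets %t2 directly.
%r = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
    : vector<4x[4]xf32>, tensor<?x?xf32>
```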

---------

Signed-off-by: Benjamin Maxwell <benjamin.maxwell@arm.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
index 7f67b76..be6fd7a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
@@ -5,6 +5,8 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -117,6 +119,63 @@
   }
 };
 
+/// Returns true if `writeOp` fully overwrites its destination.
+///
+/// Example:
+///
+/// ```
+/// vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
+///    : vector<4x5xf32>, tensor<4x5xf32>
+/// ```
+///
+/// This is an easy case: `vector<4x5xf32>` fully overwrites `tensor<4x5xf32>`
+/// since the vector is the same size as the tensor. This check also supports
+/// dynamic tensors, resolving the tensor sizes via value-bounds analysis and
+/// then checking whether the vector type fully overwrites the tensor.
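+///
+/// For example, with a scalable vector (sketch based on the tests added in
+/// this change):
+///
+/// ```
+/// %c4_vscale = arith.muli %c4, %vscale : index
+/// %slice = tensor.extract_slice %t[0, 0] [4, %c4_vscale] [1, 1]
+///    : tensor<?x?xf32> to tensor<4x?xf32>
+/// vector.transfer_write %vec, %slice[%c0, %c0] {in_bounds = [true, true]}
+///    : vector<4x[4]xf32>, tensor<4x?xf32>
+/// ```
+///
+/// Here value-bounds analysis bounds the dynamic dimension by 4 x vscale,
+/// which matches the scalable dimension `[4]` of the vector, so the write
+/// fully overwrites the destination.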
+static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) {
+  if (writeOp.hasOutOfBoundsDim())
+    return false;
+  if (writeOp.getVectorType().getRank() != writeOp.getShapedType().getRank())
+    return false;
+  if (writeOp.getMask())
+    return false;
+
+  std::optional<iree_compiler::VscaleRange> vscaleRange;
+  auto vecType = writeOp.getVectorType();
+  if (vecType.isScalable()) {
+    auto targetAttr =
+        iree_compiler::IREE::HAL::ExecutableTargetAttr::lookup(writeOp);
+    vscaleRange = iree_compiler::getDefaultVscaleRange(targetAttr);
+  }
+
+  Value dest = writeOp.getSource();
+  ArrayRef<int64_t> destShape = writeOp.getShapedType().getShape();
+
+  // Attempts to resolve the size of a dim within the destination.
+  auto resolveDestinationDimSize =
+      [&](unsigned dimIndex) -> FailureOr<iree_compiler::DimBoundSize> {
+    auto size = destShape[dimIndex];
+    // Fixed-size dimensions are simply included in the shape.
+    if (size != ShapedType::kDynamic)
+      return iree_compiler::DimBoundSize{size};
+    // (Attempt to) resolve dynamic dimensions via value-bounds analysis.
+    return iree_compiler::computeDimUpperBound(dest, dimIndex, vscaleRange);
+  };
+
+  ArrayRef<int64_t> vecShape = vecType.getShape();
+  ArrayRef<bool> vecScalableFlags = vecType.getScalableDims();
+  for (unsigned d = 0, e = destShape.size(); d < e; ++d) {
+    auto dimSize = resolveDestinationDimSize(d);
+    if (failed(dimSize))
+      return false;
+    if (dimSize->scalable && !vecScalableFlags[d])
+      return false;
+    if (vecShape[d] != dimSize->baseSize)
+      return false;
+  }
+  return true;
+}
+
 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
 /// could directly write to the insert_slice's destination. E.g.:
 ///
@@ -150,20 +209,12 @@
     // TODO: support 0-d corner case.
     if (xferOp.getTransferRank() == 0)
       return failure();
-
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
-      return failure();
-    if (xferOp.getMask())
+    if (!xferOp.getPermutationMap().isIdentity())
       return failure();
     // Fold only if the TransferWriteOp completely overwrites the `source` with
     // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
     // content is the data of the vector.
-    if (!llvm::equal(xferOp.getVectorType().getShape(),
-                     xferOp.getShapedType().getShape()))
-      return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
+    if (!isDestinationFullyOverwritten(xferOp))
       return failure();
 
     // Bail on illegal rank-reduction: we need to check that the rank-reduced
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
index b07b2b5..1255453 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
@@ -63,6 +63,115 @@
 
 // -----
 
+func.func @fold_insert_slice_into_transfer_write_static(%v: vector<4x5xf32>, %t1: tensor<4x5xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} : vector<4x5xf32>, tensor<4x5xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_static
+// CHECK-SAME:    %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[T1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[T2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[A:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[B:[a-zA-Z0-9]+]]
+// CHECK-NEXT:    %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x5xf32>, tensor<?x?xf32>
+// CHECK-NEXT:    return %[[WRITE]]
+
+// -----
+
+#aarch64_sve = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", target_triple = "aarch64-none-elf"}>
+
+func.func @fold_insert_slice_into_transfer_write_scalable(%v: vector<4x[4]xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32>
+  attributes {hal.executable.target = #aarch64_sve}
+{
+  %vscale = vector.vscale
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c4_vscale = arith.muli %c4, %vscale : index
+  %extract_slice = tensor.extract_slice %t1[0, 0] [4, %c4_vscale] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+  %0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x[4]xf32>, tensor<4x?xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %c4_vscale] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_scalable
+// CHECK-SAME:    %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[T1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[T2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[A:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[B:[a-zA-Z0-9]+]]
+// CHECK-NEXT:    %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x[4]xf32>, tensor<?x?xf32>
+// CHECK-NEXT:    return %[[WRITE]]
+
+// -----
+
+func.func @fold_insert_slice_into_transfer_write_dynamic(%v: vector<4x8xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index, %size: index) -> tensor<?x?xf32>
+{
+  %c0 = arith.constant 0 : index
+  %slice_size = affine.min affine_map<(d0) -> (d0, 8)>(%size)
+  %extract_slice = tensor.extract_slice %t1[0, 0] [4, %slice_size] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+  %0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x?xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %slice_size] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_dynamic
+// CHECK-SAME:    %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[T1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[T2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[A:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[B:[a-zA-Z0-9]+]]
+// CHECK-NEXT:    %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32>
+// CHECK-NEXT:    return %[[WRITE]]
+
+// -----
+
+func.func @negative_fold_insert_slice_into_transfer_write_static(%v: vector<3x5xf32>, %t1: tensor<4x5xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf32>, tensor<4x5xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_static
+// CHECK: %[[WRITE:.*]] = vector.transfer_write
+// CHECK: tensor.insert_slice %[[WRITE]]
+
+// -----
+
+#aarch64_sve = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", target_triple = "aarch64-none-elf"}>
+
+func.func @negative_fold_insert_slice_into_transfer_write_scalable(%v: vector<4x[2]xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32>
+  attributes {hal.executable.target = #aarch64_sve}
+{
+  %vscale = vector.vscale
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c4_vscale = arith.muli %c4, %vscale : index
+  %extract_slice = tensor.extract_slice %t1[0, 0] [4, %c4_vscale] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+  %0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x[2]xf32>, tensor<4x?xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %c4_vscale] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_scalable
+// CHECK: %[[WRITE:.*]] = vector.transfer_write
+// CHECK: tensor.insert_slice %[[WRITE]]
+
+// -----
+
+func.func @negative_fold_insert_slice_into_transfer_write_dynamic(%v: vector<4x7xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index, %size: index) -> tensor<?x?xf32>
+{
+  %c0 = arith.constant 0 : index
+  %slice_size = affine.min affine_map<(d0) -> (d0, 8)>(%size)
+  %extract_slice = tensor.extract_slice %t1[0, 0] [4, %slice_size] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+  %0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x7xf32>, tensor<4x?xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %slice_size] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_dynamic
+// CHECK: %[[WRITE:.*]] = vector.transfer_write
+// CHECK: tensor.insert_slice %[[WRITE]]
+
+// -----
+
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -70,6 +179,7 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
+
 #map = affine_map<()[s0] -> (s0 * 64)>
 #map1 = affine_map<()[s0] -> (s0 * 128)>
 #map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)>