[LLVMGPU] Enhance TensorPad pass to handle tensor.unpack ops. (#12458)

The unpack op can have extract_slice semantics. The pass pads the
destinations of tensor.unpack ops, which gets rid of the extract_slice.
This helps keep vector transfer ops in bounds.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
index 171fd17..a2847a1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
@@ -22,6 +22,24 @@
 
 namespace {
 
+static FailureOr<SmallVector<int64_t>> getPaddedShapeFromTensorLoad(
+    IREE::Flow::DispatchTensorLoadOp tensorLoad, ArrayRef<int64_t> origShape) {
+  // Determine the padded shape from the load.
+  SmallVector<int64_t> paddedShape(origShape.begin(), origShape.end());
+  for (const auto &[index, size] :
+       llvm::enumerate(tensorLoad.getMixedSizes())) {
+    if (Optional<int64_t> cst = getConstantIntValue(size)) {
+      paddedShape[index] = cst.value();
+    } else {
+      FailureOr<int64_t> upperBound =
+          linalg::getConstantUpperBoundForIndex(size.get<Value>());
+      if (failed(upperBound)) return failure();
+      paddedShape[index] = *upperBound;
+    }
+  }
+  return paddedShape;
+}
+
 static FailureOr<SmallVector<Value>> rewriteAsPaddedOp(
     IRRewriter &rewriter, linalg::LinalgOp linalgOp,
     linalg::LinalgOp &paddedOp) {
@@ -38,30 +56,15 @@
   paddedOperands.reserve(linalgOp.getNumDpsInputs() +
                          linalgOp.getNumDpsInits());
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-    // Find DispatchTensorLoadOp's feeding into the linalg or abort.
-    auto tensorLoad = dyn_cast_or_null<IREE::Flow::DispatchTensorLoadOp>(
-        opOperand.get().getDefiningOp());
+    auto tensorLoad =
+        opOperand.get().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
     if (!tensorLoad) {
       return rewriter.notifyMatchFailure(linalgOp, "does not have tensor load");
     }
-
-    // Determine the padded shape from the load
-    ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
-    SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
-    for (const auto &[index, size] :
-         llvm::enumerate(tensorLoad.getMixedSizes())) {
-      if (Optional<int64_t> cst = getConstantIntValue(size)) {
-        paddedShape[index] = cst.value();
-      } else {
-        FailureOr<int64_t> upperBound =
-            linalg::getConstantUpperBoundForIndex(size.get<Value>());
-        if (failed(upperBound)) {
-          return rewriter.notifyMatchFailure(
-              linalgOp, "No constant bounding box can be found for padding");
-        }
-        paddedShape[index] = *upperBound;
-      }
-    }
+    FailureOr<SmallVector<int64_t>> maybePaddedShape =
+        getPaddedShapeFromTensorLoad(tensorLoad, linalgOp.getShape(&opOperand));
+    if (failed(maybePaddedShape)) return failure();
+    auto paddedShape = *maybePaddedShape;
 
     Value paddingValue = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(getElementTypeOrSelf(tensorLoad)));
@@ -104,6 +107,63 @@
   return paddedSubviewResults;
 }
 
+static FailureOr<Value> rewriteAsPaddedOp(IRRewriter &rewriter,
+                                          tensor::UnPackOp op,
+                                          tensor::UnPackOp &paddedOp) {
+  Location loc = op.getLoc();
+
+  // Set IP after op because we also take the dims of the original output.
+  IRRewriter::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(op);
+  auto tensorLoad =
+      op.getDest().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+  if (!tensorLoad) {
+    return failure();
+  }
+
+  FailureOr<SmallVector<int64_t>> maybePaddedShape =
+      getPaddedShapeFromTensorLoad(tensorLoad, op.getDestType().getShape());
+  if (failed(maybePaddedShape)) return failure();
+  auto paddedShape = *maybePaddedShape;
+
+  // Pad to the shape that makes tensor.unpack ops produce full tiles.
+  SmallVector<int64_t> innerTiles = op.getStaticTiles();
+  ArrayRef<int64_t> dimPos = op.getInnerDimsPos();
+  for (auto [pos, size] : llvm::zip_equal(dimPos, innerTiles)) {
+    paddedShape[pos] = llvm::divideCeil(paddedShape[pos], size) * size;
+  }
+
+  Value paddingValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(getElementTypeOrSelf(tensorLoad)));
+  auto paddedTensorType =
+      RankedTensorType::get(paddedShape, getElementTypeOrSelf(tensorLoad));
+  Value paddedValue = linalg::makeComposedPadHighOp(
+      rewriter, loc, paddedTensorType, tensorLoad, paddingValue,
+      /*nofold=*/false);
+
+  SmallVector<Value> paddedOperands = {op.getSource(), paddedValue};
+  paddedOperands.append(op.getInnerTiles().begin(), op.getInnerTiles().end());
+  paddedOp = rewriter.create<tensor::UnPackOp>(
+      loc, TypeRange{paddedValue.getType()}, paddedOperands, op->getAttrs());
+
+  // Slice out the original shape from the padded result to pass on to
+  // consumers.
+  SmallVector<SmallVector<Value>> reifiedResultShapes;
+  if (failed(op.reifyResultShapes(rewriter, reifiedResultShapes))) {
+    return failure();
+  }
+
+  Value paddedSubviewResults;
+  int64_t rank = paddedOp.getDestRank();
+  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> sizes =
+      getAsOpFoldResult(ValueRange(reifiedResultShapes[0]));
+  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+  paddedSubviewResults = rewriter.create<tensor::ExtractSliceOp>(
+      loc, paddedOp.getResult(), offsets, sizes, strides);
+  return paddedSubviewResults;
+}
+
 static bool hasTwoOrThreeLoopsInfo(linalg::LinalgOp linalgOp) {
   return linalgOp.getNumParallelLoops() >= 2 &&
          linalgOp.getNumParallelLoops() <= 3;
@@ -139,6 +199,18 @@
       // Replace the original operation to pad.
       rewriter.replaceOp(linalgOp, *newResults);
     });
+
+    funcOp.walk([&](tensor::UnPackOp unpackOp) {
+      tensor::UnPackOp paddedOp;
+      FailureOr<Value> newResult =
+          rewriteAsPaddedOp(rewriter, unpackOp, paddedOp);
+      if (failed(newResult)) {
+        return;
+      }
+
+      // Replace the original operation to pad.
+      rewriter.replaceOp(unpackOp, *newResult);
+    });
   }
 };
 }  // namespace
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_pad.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_pad.mlir
index 67a75c5..0aba90f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_pad.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_pad.mlir
@@ -64,3 +64,58 @@
 //       CHECK:      } -> tensor<32x32xf32>
 //       CHECK:      %[[D13:.*]] = tensor.extract_slice %[[D12:.*]][0, 0] {{\[}}%[[D4]], 32] [1, 1] : tensor<32x32xf32> to tensor<?x32xf32>
 //       CHECK:      flow.dispatch.tensor.store %[[D13]], %[[D1]], offsets = {{\[}}%[[ARG0]], %[[ARG1]]], sizes = {{\[}}%[[D4]], 32], strides = [1, 1] : tensor<?x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<48x32xf32>>
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 * 16)>
+#map1 = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
+#map2 = affine_map<(d0) -> (d0 ceildiv 2)>
+#map3 = affine_map<(d0) -> (d0 floordiv 2)>
+func.func @unpack_dynamic() {
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %0 = hal.interface.constant.load[0] : i32
+  %1 = hal.interface.constant.load[1] : i32
+  %2 = hal.interface.constant.load[2] : i32
+  %3 = hal.interface.constant.load[3] : i32
+  %4 = arith.index_castui %0 : i32 to index
+  %5 = arith.index_castui %1 : i32 to index
+  %6 = arith.index_castui %2 : i32 to index
+  %7 = arith.index_castui %3 : i32 to index
+  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c64) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x2x2xi32>>{%4, %5}
+  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_count_x = hal.interface.workgroup.count[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %workgroup_count_y = hal.interface.workgroup.count[1] : index
+  %10 = affine.apply #map()[%workgroup_id_y]
+  %11 = affine.apply #map()[%workgroup_count_y]
+  scf.for %arg0 = %10 to %6 step %11 {
+    %12 = affine.min #map1(%arg0)[%6]
+    %13 = affine.apply #map()[%workgroup_id_x]
+    %14 = affine.apply #map()[%workgroup_count_x]
+    scf.for %arg1 = %13 to %7 step %14 {
+      %15 = affine.min #map1(%arg1)[%7]
+      %16 = flow.dispatch.tensor.load %9, offsets = [%arg0, %arg1], sizes = [%12, %15], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7} -> tensor<?x?xi32>
+      %17 = affine.apply #map2(%12)
+      %18 = affine.apply #map2(%15)
+      %19 = affine.apply #map3(%arg0)
+      %20 = affine.apply #map3(%arg1)
+      %21 = flow.dispatch.tensor.load %8, offsets = [%19, %20, 0, 0], sizes = [%17, %18, 2, 2], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x2x2xi32>>{%4, %5} -> tensor<?x?x2x2xi32>
+      %c16 = arith.constant 16 : index
+      %c0_i32 = arith.constant 0 : i32
+      %22 = arith.subi %c16, %12 : index
+      %23 = arith.subi %c16, %15 : index
+      %unpack = tensor.unpack %21 inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %16 : tensor<?x?x2x2xi32> -> tensor<?x?xi32>
+      flow.dispatch.tensor.store %unpack, %9, offsets = [%arg0, %arg1], sizes = [%12, %15], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
+    }
+  }
+  return
+}
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK:         %[[DEST_BUF:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK:           %[[LOAD:.+]] = flow.dispatch.tensor.load %[[DEST_BUF]]
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[LOAD]]
+// CHECK:           %[[UNPACK:.+]] = tensor.unpack {{.+}} into %[[PAD]]
+// CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK:           flow.dispatch.tensor.store %[[SLICE]], %[[DEST_BUF]]