[LLVMGPU] Enhance TensorPad pass to handle tensor.unpack ops. (#12458)
The unpack op can have extract_slice semantics. The pass pads the
tensor.unpack ops, which gets rid of the extract_slice. This helps keep
vector transfer ops in bounds.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
index 171fd17..a2847a1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
@@ -22,6 +22,24 @@
namespace {
+static FailureOr<SmallVector<int64_t>> getPaddedShapeFromTensorLoad(
+ IREE::Flow::DispatchTensorLoadOp tensorLoad, ArrayRef<int64_t> origShape) {
+ // Determine the padded shape from the load.
+ SmallVector<int64_t> paddedShape(origShape.begin(), origShape.end());
+ for (const auto &[index, size] :
+ llvm::enumerate(tensorLoad.getMixedSizes())) {
+ if (Optional<int64_t> cst = getConstantIntValue(size)) {
+ paddedShape[index] = cst.value();
+ } else {
+ FailureOr<int64_t> upperBound =
+ linalg::getConstantUpperBoundForIndex(size.get<Value>());
+ if (failed(upperBound)) return failure();
+ paddedShape[index] = *upperBound;
+ }
+ }
+ return paddedShape;
+}
+
static FailureOr<SmallVector<Value>> rewriteAsPaddedOp(
IRRewriter &rewriter, linalg::LinalgOp linalgOp,
linalg::LinalgOp &paddedOp) {
@@ -38,30 +56,15 @@
paddedOperands.reserve(linalgOp.getNumDpsInputs() +
linalgOp.getNumDpsInits());
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
- // Find DispatchTensorLoadOp's feeding into the linalg or abort.
- auto tensorLoad = dyn_cast_or_null<IREE::Flow::DispatchTensorLoadOp>(
- opOperand.get().getDefiningOp());
+ auto tensorLoad =
+ opOperand.get().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
if (!tensorLoad) {
return rewriter.notifyMatchFailure(linalgOp, "does not have tensor load");
}
-
- // Determine the padded shape from the load
- ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
- SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
- for (const auto &[index, size] :
- llvm::enumerate(tensorLoad.getMixedSizes())) {
- if (Optional<int64_t> cst = getConstantIntValue(size)) {
- paddedShape[index] = cst.value();
- } else {
- FailureOr<int64_t> upperBound =
- linalg::getConstantUpperBoundForIndex(size.get<Value>());
- if (failed(upperBound)) {
- return rewriter.notifyMatchFailure(
- linalgOp, "No constant bounding box can be found for padding");
- }
- paddedShape[index] = *upperBound;
- }
- }
+ FailureOr<SmallVector<int64_t>> maybePaddedShape =
+ getPaddedShapeFromTensorLoad(tensorLoad, linalgOp.getShape(&opOperand));
+ if (failed(maybePaddedShape)) return failure();
+ auto paddedShape = *maybePaddedShape;
Value paddingValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(getElementTypeOrSelf(tensorLoad)));
@@ -104,6 +107,63 @@
return paddedSubviewResults;
}
+static FailureOr<Value> rewriteAsPaddedOp(IRRewriter &rewriter,
+ tensor::UnPackOp op,
+ tensor::UnPackOp &paddedOp) {
+ Location loc = op.getLoc();
+
+ // Set IP after op because we also take the dims of the original output.
+ IRRewriter::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(op);
+ auto tensorLoad =
+ op.getDest().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ if (!tensorLoad) {
+ return failure();
+ }
+
+ FailureOr<SmallVector<int64_t>> maybePaddedShape =
+ getPaddedShapeFromTensorLoad(tensorLoad, op.getDestType().getShape());
+ if (failed(maybePaddedShape)) return failure();
+ auto paddedShape = *maybePaddedShape;
+
+ // Pad to the shape that makes tensor.unpack ops produce full tiles.
+ SmallVector<int64_t> innerTiles = op.getStaticTiles();
+ ArrayRef<int64_t> dimPos = op.getInnerDimsPos();
+ for (auto [pos, size] : llvm::zip_equal(dimPos, innerTiles)) {
+ paddedShape[pos] = llvm::divideCeil(paddedShape[pos], size) * size;
+ }
+
+ Value paddingValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(getElementTypeOrSelf(tensorLoad)));
+ auto paddedTensorType =
+ RankedTensorType::get(paddedShape, getElementTypeOrSelf(tensorLoad));
+ Value paddedValue = linalg::makeComposedPadHighOp(
+ rewriter, loc, paddedTensorType, tensorLoad, paddingValue,
+ /*nofold=*/false);
+
+ SmallVector<Value> paddedOperands = {op.getSource(), paddedValue};
+ paddedOperands.append(op.getInnerTiles().begin(), op.getInnerTiles().end());
+ paddedOp = rewriter.create<tensor::UnPackOp>(
+ loc, TypeRange{paddedValue.getType()}, paddedOperands, op->getAttrs());
+
+ // Slice out the original shape from the padded result to pass on to
+ // consumers.
+ SmallVector<SmallVector<Value>> reifiedResultShapes;
+ if (failed(op.reifyResultShapes(rewriter, reifiedResultShapes))) {
+ return failure();
+ }
+
+ Value paddedSubviewResults;
+ int64_t rank = paddedOp.getDestRank();
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes =
+ getAsOpFoldResult(ValueRange(reifiedResultShapes[0]));
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ paddedSubviewResults = rewriter.create<tensor::ExtractSliceOp>(
+ loc, paddedOp.getResult(), offsets, sizes, strides);
+ return paddedSubviewResults;
+}
+
static bool hasTwoOrThreeLoopsInfo(linalg::LinalgOp linalgOp) {
return linalgOp.getNumParallelLoops() >= 2 &&
linalgOp.getNumParallelLoops() <= 3;
@@ -139,6 +199,18 @@
// Replace the original operation to pad.
rewriter.replaceOp(linalgOp, *newResults);
});
+
+ funcOp.walk([&](tensor::UnPackOp unpackOp) {
+ tensor::UnPackOp paddedOp;
+ FailureOr<Value> newResult =
+ rewriteAsPaddedOp(rewriter, unpackOp, paddedOp);
+ if (failed(newResult)) {
+ return;
+ }
+
+ // Replace the original operation to pad.
+ rewriter.replaceOp(unpackOp, *newResult);
+ });
}
};
} // namespace
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_pad.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_pad.mlir
index 67a75c5..0aba90f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_pad.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_pad.mlir
@@ -64,3 +64,58 @@
// CHECK: } -> tensor<32x32xf32>
// CHECK: %[[D13:.*]] = tensor.extract_slice %[[D12:.*]][0, 0] {{\[}}%[[D4]], 32] [1, 1] : tensor<32x32xf32> to tensor<?x32xf32>
// CHECK: flow.dispatch.tensor.store %[[D13]], %[[D1]], offsets = {{\[}}%[[ARG0]], %[[ARG1]]], sizes = {{\[}}%[[D4]], 32], strides = [1, 1] : tensor<?x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<48x32xf32>>
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 * 16)>
+#map1 = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
+#map2 = affine_map<(d0) -> (d0 ceildiv 2)>
+#map3 = affine_map<(d0) -> (d0 floordiv 2)>
+func.func @unpack_dynamic() {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %0 = hal.interface.constant.load[0] : i32
+ %1 = hal.interface.constant.load[1] : i32
+ %2 = hal.interface.constant.load[2] : i32
+ %3 = hal.interface.constant.load[3] : i32
+ %4 = arith.index_castui %0 : i32 to index
+ %5 = arith.index_castui %1 : i32 to index
+ %6 = arith.index_castui %2 : i32 to index
+ %7 = arith.index_castui %3 : i32 to index
+ %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c64) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x2x2xi32>>{%4, %5}
+ %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %10 = affine.apply #map()[%workgroup_id_y]
+ %11 = affine.apply #map()[%workgroup_count_y]
+ scf.for %arg0 = %10 to %6 step %11 {
+ %12 = affine.min #map1(%arg0)[%6]
+ %13 = affine.apply #map()[%workgroup_id_x]
+ %14 = affine.apply #map()[%workgroup_count_x]
+ scf.for %arg1 = %13 to %7 step %14 {
+ %15 = affine.min #map1(%arg1)[%7]
+ %16 = flow.dispatch.tensor.load %9, offsets = [%arg0, %arg1], sizes = [%12, %15], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7} -> tensor<?x?xi32>
+ %17 = affine.apply #map2(%12)
+ %18 = affine.apply #map2(%15)
+ %19 = affine.apply #map3(%arg0)
+ %20 = affine.apply #map3(%arg1)
+ %21 = flow.dispatch.tensor.load %8, offsets = [%19, %20, 0, 0], sizes = [%17, %18, 2, 2], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x2x2xi32>>{%4, %5} -> tensor<?x?x2x2xi32>
+ %c16 = arith.constant 16 : index
+ %c0_i32 = arith.constant 0 : i32
+ %22 = arith.subi %c16, %12 : index
+ %23 = arith.subi %c16, %15 : index
+ %unpack = tensor.unpack %21 inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %16 : tensor<?x?x2x2xi32> -> tensor<?x?xi32>
+ flow.dispatch.tensor.store %unpack, %9, offsets = [%arg0, %arg1], sizes = [%12, %15], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
+ }
+ }
+ return
+}
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK: %[[DEST_BUF:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[DEST_BUF]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[LOAD]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack {{.+}} into %[[PAD]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK: flow.dispatch.tensor.store %[[SLICE]], %[[DEST_BUF]]