Allow dynamic dimensions during folding of `tensor.expand_shape/collapse_shape` into `flow.dispatch.tensor.load/store`. (#18873)
This also cleans up the implementation of these patterns to avoid templated
code that is hard to read and maintain.
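
For example (a sketch adapted from the updated
propagate_reshapes_by_expansion.mlir test below; value names and shapes are
illustrative), a collapse of a dynamically sized full-slice load now folds
all the way into the binding:

    %1 = hal.interface.binding.subspan ... : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%dim}
    %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %dim, 32], strides = [1, 1, 1]
        : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%dim} -> tensor<2x?x32xf32>
    %3 = tensor.collapse_shape %2 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor<?x32xf32>

becomes

    %size = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%dim]
    %1 = hal.interface.binding.subspan ... : !flow.dispatch.tensor<readonly:tensor<?x32xf32>>{%size}
    %3 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%size, 32], strides = [1, 1]
        : !flow.dispatch.tensor<readonly:tensor<?x32xf32>>{%size} -> tensor<?x32xf32>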
---------
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
index 50ca569..6f1cc19 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
@@ -71,26 +71,36 @@
// -----
#pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
+ #hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
-// CHECK-LABEL: func.func @dont_fold_dynamic_reshape()
-func.func @dont_fold_dynamic_reshape() {
+func.func @fold_dynamic_reshape() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
%dim1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
%dim2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<readonly:tensor<?x?x96xf32>>{%dim0, %dim1}
- %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
%3 = flow.dispatch.tensor.load %1, offsets=[0, 0, 0], sizes =[%dim0, %dim1, 96], strides=[1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x96xf32>>{%dim0, %dim1} -> tensor<?x?x96xf32>
- // CHECK: tensor.collapse_shape
- // CHECK: tensor.expand_shape
%4 = tensor.collapse_shape %3 [[0, 1], [2]] : tensor<?x?x96xf32> into tensor<?x96xf32>
%dyn = tensor.dim %4, %c0 : tensor<?x96xf32>
%5 = tensor.expand_shape %4 [[0], [1, 2]] output_shape [%dyn, 12, 8] : tensor<?x96xf32> into tensor<?x12x8xf32>
- flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, 12, 8], strides = [%c1, %c1, %c1] : tensor<?x12x8xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
+ flow.dispatch.tensor.store %5, %2, offsets = [0, 0, 0], sizes = [%dim2, 12, 8], strides = [1, 1, 1] : tensor<?x12x8xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
return
}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: func.func @fold_dynamic_reshape()
+// CHECK-DAG: %[[CST0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(0)
+// CHECK-DAG: %[[CST1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(1)
+// CHECK-DAG: %[[CST2:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(2)
+// CHECK: %[[COLLAPSED:.+]] = affine.apply #[[MAP]]()[%[[CST0]], %[[CST1]]]
+// CHECK: %[[IN_BINDING:.+]] = hal.interface.binding.subspan
+// CHECK-SAME: binding(0) : !flow.dispatch.tensor<readonly:tensor<?x96xf32>>{%[[COLLAPSED]]}
+// CHECK: %[[OUT_BINDING:.+]] = hal.interface.binding.subspan
+// CHECK-SAME: binding(1) : !flow.dispatch.tensor<writeonly:tensor<?x96xf32>>{%[[CST2]]}
+// CHECK: %[[IN:.+]] = flow.dispatch.tensor.load %[[IN_BINDING]]
+// CHECK: flow.dispatch.tensor.store %[[IN]], %[[OUT_BINDING]]
// -----
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
index 7dd745e..fc9e85e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s --mlir-print-local-scope | FileCheck %s
func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> {
%collapse = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16>
@@ -14,3 +14,75 @@
// CHECK: linalg.copy
// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
// CHECK-SAME: ins(%[[COLLAPSE]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+ #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">], flags = Indirect>
+func.func @fold_collapse_into_loads_dynamic() -> tensor<?x32xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+ flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%0}
+ %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %0, 32], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%0} -> tensor<2x?x32xf32>
+ %3 = tensor.collapse_shape %2 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor<?x32xf32>
+ return %3 : tensor<?x32xf32>
+}
+// CHECK-LABEL: func @fold_collapse_into_loads_dynamic()
+// CHECK: %[[CONST:.+]] = hal.interface.constant.load
+// CHECK: %[[SHAPE:.+]] = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%[[CONST]]]
+// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x32xf32>>{%[[SHAPE]]}
+// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]]
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[SHAPE]], 32], strides = [1, 1]
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x32xf32>>{%[[SHAPE]]}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 2, bindings = [
+ #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">], flags = Indirect>
+func.func @fold_expand_into_loads_dynamic() -> tensor<2x?x16x32xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+ flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%0}
+ %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %0, 32], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%0} -> tensor<2x?x32xf32>
+  %3 = affine.apply affine_map<()[s0] -> (s0 floordiv 16)>()[%0]
+ %4 = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [2, %3, 16, 32] : tensor<2x?x32xf32> into tensor<2x?x16x32xf32>
+ return %4 : tensor<2x?x16x32xf32>
+}
+// CHECK-LABEL: func @fold_expand_into_loads_dynamic()
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+// CHECK-DAG: %[[CONST:.+]] = hal.interface.constant.load
+// CHECK: %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C16]]
+// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<2x?x16x32xf32>>{%[[SHAPE]]}
+// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [2, %[[SHAPE]], 16, 32], strides = [1, 1, 1, 1]
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<2x?x16x32xf32>>{%[[SHAPE]]}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+ #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @fold_collapse_into_stores_dynamic(%arg0 : tensor<2x?x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<?x32xf32>>{%0}
+ %2 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor<?x32xf32>
+ flow.dispatch.tensor.store %2, %1, offsets = [0, 0], sizes = [%0, 32], strides = [1, 1]
+ : tensor<?x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x32xf32>>{%0}
+ return
+}
+// CHECK-LABEL: func @fold_collapse_into_stores_dynamic(
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[CONST:.+]] = hal.interface.constant.load
+// CHECK: %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C2]]
+// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<2x?x32xf32>>{%[[SHAPE]]}
+// CHECK: flow.dispatch.tensor.store %{{.+}}, %[[SUBSPAN]]
+// CHECK-SAME: offsets = [0, 0, 0], sizes = [2, %[[SHAPE]], 32], strides = [1, 1, 1]
+// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<2x?x32xf32>>{%[[SHAPE]]}
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 0b8c49c..6d0d052 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -36,20 +36,29 @@
namespace mlir::iree_compiler {
-static bool isAllConstantValue(SmallVector<OpFoldResult> ofrs, int64_t v) {
+static bool isAllConstantValue(ArrayRef<OpFoldResult> ofrs, int64_t v) {
return llvm::all_of(
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, v); });
}
-static bool isFullSlice(SmallVector<OpFoldResult> mixedOffsets,
- SmallVector<OpFoldResult> mixedSizes,
- SmallVector<OpFoldResult> mixedStrides,
- IREE::Flow::DispatchTensorType tensorType) {
- std::optional<SmallVector<int64_t>> constSizes =
- getConstantIntValues(mixedSizes);
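+/// Returns true if the given offsets/sizes/strides describe a full slice of
+/// `tensorType`: all-zero offsets, all-unit strides, and sizes equal to the
+/// tensor's mixed (static and dynamic) shape.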
+static bool isFullSlice(ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes,
+ ArrayRef<OpFoldResult> mixedStrides,
+ IREE::Flow::DispatchTensorType tensorType,
+ ValueRange dynamicDims) {
+ OpBuilder builder(tensorType.getContext());
+ SmallVector<int64_t> tensorShape = llvm::to_vector(tensorType.getShape());
+ SmallVector<OpFoldResult> mixedTensorShape =
+ mlir::getMixedValues(tensorShape, dynamicDims, builder);
return isAllConstantValue(mixedOffsets, 0) &&
- isAllConstantValue(mixedStrides, 1) && constSizes &&
- llvm::equal(tensorType.getShape(), *constSizes);
+ isAllConstantValue(mixedStrides, 1) && mixedTensorShape == mixedSizes;
+}
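+/// Overload that takes the offsets/sizes/strides directly from an op
+/// implementing `OffsetSizeAndStrideOpInterface`.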
+static bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp,
+ IREE::Flow::DispatchTensorType tensorType,
+ ValueRange dynamicDims) {
+ return isFullSlice(
+ sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(),
+ sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims);
}
static bool sliceFilter(Operation *op, ValueRange nonIndexComputationOperands,
@@ -546,14 +555,29 @@
namespace {
-// TODO(antigainst): enable dynamic shape support once they are needed.
-template <typename TensorReshapeOp>
-static std::optional<Value> getStaticReshapeOpSrc(TensorReshapeOp reshapeOp) {
- auto reshapeSrcType = llvm::cast<ShapedType>(reshapeOp.getSrc().getType());
- auto reshapeDstType = llvm::cast<ShapedType>(reshapeOp.getType());
- if (!reshapeSrcType.hasStaticShape() || !reshapeDstType.hasStaticShape())
- return std::nullopt;
- return reshapeOp.getSrc();
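+/// Returns the mixed (static and dynamic) collapsed shape implied by
+/// `reassociations`: each collapsed dimension is the product of the
+/// corresponding expanded dimensions, folded through `affine.apply` when any
+/// of them is dynamic.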
+static SmallVector<OpFoldResult>
+inferCollapsedShape(RewriterBase &rewriter, Location loc,
+ RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociations,
+ ValueRange expandedDynamicDims) {
+ ArrayRef<int64_t> expandedStaticShape = expandedType.getShape();
+ SmallVector<OpFoldResult> expandedMixedShape =
+ mlir::getMixedValues(expandedStaticShape, expandedDynamicDims, rewriter);
+ SmallVector<OpFoldResult> collapsedShape;
+ unsigned expandedShapeDim = 0;
+ for (auto reassociation : reassociations) {
+ AffineExpr mulExpr = rewriter.getAffineSymbolExpr(0);
+ for (auto i : llvm::seq<unsigned>(1, reassociation.size())) {
+ mulExpr = mulExpr * rewriter.getAffineSymbolExpr(i);
+ }
+ auto collapsedDim = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, mulExpr,
+ ArrayRef(expandedMixedShape)
+ .slice(expandedShapeDim, reassociation.size()));
+ collapsedShape.push_back(collapsedDim);
+ expandedShapeDim += reassociation.size();
+ }
+ return collapsedShape;
}
-/// Folds tensor.expand/collapse_shape into the source
+/// Folds tensor.collapse_shape into the source
@@ -576,35 +600,38 @@
/// !flow.dispatch.tensor<readonly:tensor<864xf32>>
/// %0 = flow.dispatch.tensor.load %subspan :
/// !flow.dispatch.tensor<readonly:tensor<864xf32>> -> tensor<864xf32>
-template <typename TensorReshapeOp>
-struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
- using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+struct FoldCollapseShapeIntoInterfaceTensorLoad
+ : OpRewritePattern<tensor::CollapseShapeOp> {
+ using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
- std::optional<Value> reshapeSrc =
- getStaticReshapeOpSrc<TensorReshapeOp>(reshapeOp);
- if (!reshapeSrc)
- return failure();
-
- auto loadOp =
- reshapeSrc->template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ Value reshapeSrc = reshapeOp.getSrc();
+ auto reshapeSrcType = cast<RankedTensorType>(reshapeSrc.getType());
+ auto loadOp = reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
if (!loadOp)
return failure();
// Make sure we are loading the full incoming subspan. Otherwise we cannot
// simply adjust the subspan's resultant type later.
- if (!isFullSlice(loadOp.getMixedOffsets(), loadOp.getMixedSizes(),
- loadOp.getMixedStrides(), loadOp.getSourceType())) {
+ if (!isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
return failure();
}
- auto subspanOp =
- loadOp.getSource()
- .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+ auto subspanOp = loadOp.getSource()
+ .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
if (!subspanOp)
return failure();
- assert(subspanOp.getDynamicDims().empty());
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(subspanOp);
+ SmallVector<OpFoldResult> collapsedShape = inferCollapsedShape(
+ rewriter, subspanOp.getLoc(), reshapeSrcType,
+ reshapeOp.getReassociationIndices(), subspanOp.getDynamicDims());
+ SmallVector<int64_t> collapsedStaticShape;
+ SmallVector<Value> collapsedDynamicShape;
+ dispatchIndexOpFoldResults(collapsedShape, collapsedDynamicShape,
+ collapsedStaticShape);
auto tensorAccess =
llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
@@ -615,12 +642,111 @@
Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
subspanOp.getBinding(), subspanOp.getByteOffset(),
- subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
+ collapsedDynamicShape, subspanOp.getAlignmentAttr(),
subspanOp.getDescriptorFlagsAttr());
+ rewriter.setInsertionPoint(reshapeOp);
rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
reshapeOp, reshapeOp.getResultType(), newSubspanOp,
- loadOp.getSourceDims());
+ collapsedDynamicShape);
+
+ return success();
+ }
+};
+
+/// Folds tensor.expand_shape into the source
+/// hal.interface.binding.subspan.
+///
+/// For example, this matches the following pattern:
+///
+/// %subspan = hal.interface.binding.subspan ... :
+///     !flow.dispatch.tensor<readonly:tensor<864xf32>>
+/// %tensor = flow.dispatch.tensor.load %subspan :
+///     !flow.dispatch.tensor<readonly:tensor<864xf32>> -> tensor<864xf32>
+/// %0 = tensor.expand_shape %tensor [[0, 1, 2, 3]]
+///     : tensor<864xf32> into tensor<3x3x1x96xf32>
+///
+/// And turns it into:
+///
+/// %subspan = hal.interface.binding.subspan ... :
+///     !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>>
+/// %0 = flow.dispatch.tensor.load %subspan :
+///     !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>> ->
+///     tensor<3x3x1x96xf32>
+struct FoldExpandShapeIntoInterfaceTensorLoad
+ : OpRewritePattern<tensor::ExpandShapeOp> {
+ using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ Value reshapeSrc = reshapeOp.getSrc();
+ auto loadOp = reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ if (!loadOp) {
+ return failure();
+ }
+
+ // Make sure we are loading the full incoming subspan. Otherwise we cannot
+ // simply adjust the subspan's resultant type later.
+ if (!isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
+ return failure();
+ }
+
+    // In the corner case where the expand_shape is the source of a store,
+    // don't fold with the load; instead fold with the store to reduce the
+    // dimensionality.
+ if (reshapeOp->hasOneUse()) {
+ if (auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(
+ *reshapeOp->getUsers().begin())) {
+ if (isFullSlice(storeOp, storeOp.getTargetType(),
+ storeOp.getTargetDims())) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "fold with store instead");
+ }
+ }
+ }
+
+ auto subspanOp = loadOp.getSource()
+ .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+ if (!subspanOp)
+ return failure();
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(subspanOp);
+
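+    // Compute the binding's expanded dynamic dims from the load's mixed
+    // static/dynamic shape and the reshape's reassociation.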
+ auto currDynamicDims = subspanOp.getDynamicDims();
+ auto currStaticDims = loadOp.getType().getShape();
+    auto currMixedShape =
+        mlir::getMixedValues(currStaticDims, currDynamicDims, rewriter);
+    std::optional<SmallVector<OpFoldResult>> expandedDims =
+        mlir::inferExpandShapeOutputShape(
+            rewriter, subspanOp.getLoc(), reshapeOp.getType(),
+            reshapeOp.getReassociationIndices(), currMixedShape);
+ if (!expandedDims) {
+      return reshapeOp.emitOpError("failed to infer expanded shape");
+ }
+
+ auto tensorAccess =
+ llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
+ .getAccess();
+ auto newSubspanType = IREE::Flow::DispatchTensorType::get(
+ tensorAccess, reshapeOp.getResultType());
+
+ SmallVector<Value> expandedDynamicDims;
+ SmallVector<int64_t> expandedStaticDims;
+ dispatchIndexOpFoldResults(expandedDims.value(), expandedDynamicDims,
+ expandedStaticDims);
+
+    Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+ subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+ subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicDims,
+ subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
+
+ rewriter.setInsertionPoint(reshapeOp);
+ rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
+ reshapeOp, reshapeOp.getResultType(), newSubspanOp,
+ expandedDynamicDims);
return success();
}
@@ -652,8 +778,8 @@
PatternRewriter &rewriter) const override {
// Make sure we are storing the full incoming subspan. Otherwise we cannot
// simply adjust the subspan's resultant type later.
- if (!isFullSlice(storeOp.getMixedOffsets(), storeOp.getMixedSizes(),
- storeOp.getMixedStrides(), storeOp.getTargetType())) {
+ if (!isFullSlice(storeOp, storeOp.getTargetType(),
+ storeOp.getTargetDims())) {
return failure();
}
@@ -662,38 +788,136 @@
return failure();
}
- // Dynamic shapes are currently unsupported.
- std::optional<Value> reshapeSrc =
- getStaticReshapeOpSrc<tensor::ExpandShapeOp>(reshapeOp);
- if (!reshapeSrc)
- return failure();
+ Value reshapeSrc = reshapeOp.getSrc();
+    // If the source is a `flow.dispatch.tensor.load`, fold with the load
+    // instead to reduce the dimensionality of the problem.
+ if (auto loadOp =
+ reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>()) {
+ if (isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
+ return rewriter.notifyMatchFailure(
+ storeOp, "fold expand_shape with load instead");
+ }
+ }
- auto subspanOp =
- storeOp.getTarget()
- .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+ auto subspanOp = storeOp.getTarget()
+ .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
if (!subspanOp)
return failure();
- assert(subspanOp.getDynamicDims().empty());
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(subspanOp);
+ SmallVector<OpFoldResult> collapsedShape = inferCollapsedShape(
+ rewriter, subspanOp.getLoc(), reshapeOp.getResultType(),
+ reshapeOp.getReassociationIndices(), subspanOp.getDynamicDims());
+ SmallVector<int64_t> collapsedStaticShape;
+ SmallVector<Value> collapsedDynamicShape;
+ dispatchIndexOpFoldResults(collapsedShape, collapsedDynamicShape,
+ collapsedStaticShape);
auto tensorAccess =
llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
.getAccess();
- auto newSubspanType = IREE::Flow::DispatchTensorType::get(
- tensorAccess, reshapeSrc->getType());
+ auto newSubspanType =
+ IREE::Flow::DispatchTensorType::get(tensorAccess, reshapeSrc.getType());
- Value newSubspanOp;
- {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointAfter(subspanOp);
- newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
- subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
- subspanOp.getBinding(), subspanOp.getByteOffset(),
- subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
- subspanOp.getDescriptorFlagsAttr());
+ Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+ subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+ subspanOp.getBinding(), subspanOp.getByteOffset(),
+ collapsedDynamicShape, subspanOp.getAlignmentAttr(),
+ subspanOp.getDescriptorFlagsAttr());
+
+ rewriter.setInsertionPoint(storeOp);
+ rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
+ storeOp, reshapeSrc, newSubspanOp, collapsedDynamicShape);
+
+ return success();
+ }
+};
+
+/// Folds tensor.collapse_shape into the source hal.interface.binding.subspan.
+///
+/// For example, this matches the following pattern:
+///
+/// %subspan = hal.interface.binding.subspan ... :
+///     !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%dim}
+/// %0 = tensor.collapse_shape %tensor [[0, 1, 2, 3]]
+///     : tensor<3x?x?x96xf32> into tensor<?xf32>
+/// flow.dispatch.tensor.store %0, %subspan :
+///     tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%dim}
+///
+/// And turns it into:
+///
+/// %subspan = hal.interface.binding.subspan ... :
+/// !flow.dispatch.tensor<writeonly:tensor<3x?x?x96xf32>>
+/// %0 = flow.dispatch.tensor.store %tensor, %subspan :
+/// tensor<3x?x?x96xf32> ->
+/// !flow.dispatch.tensor<writeonly:tensor<3x?x?x96xf32>>{%d0, %d1}
+///
+/// TODO: This pattern handles full slices (including dynamic shapes). The
+/// pattern below (`FoldCollapseShapeIntoTensorInsertSlice`) handles slices
+/// that are not full slices but requires static shapes. Combine the two if
+/// possible (it is not clear that it is).
+struct FoldCollapseShapeIntoInterfaceTensorStoreFullSlice
+ : OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
+ using OpRewritePattern<IREE::Flow::DispatchTensorStoreOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ // Make sure we are storing the full incoming subspan. Otherwise we cannot
+ // simply adjust the subspan's resultant type later.
+ if (!isFullSlice(storeOp, storeOp.getTargetType(),
+ storeOp.getTargetDims())) {
+ return failure();
}
+ auto reshapeOp =
+ storeOp.getValue().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!reshapeOp) {
+ return failure();
+ }
+ auto subspanOp = storeOp.getTarget()
+ .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+ if (!subspanOp)
+ return failure();
+
+ Value reshapeSrc = reshapeOp.getSrc();
+ auto reshapeSrcType = cast<RankedTensorType>(reshapeSrc.getType());
+
+ // Compute the type and dynamic dims of the interface binding.
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(subspanOp);
+ auto dynamicDims = subspanOp.getDynamicDims();
+ ArrayRef<int64_t> staticShape = reshapeOp.getType().getShape();
+ SmallVector<OpFoldResult> mixedShape =
+ mlir::getMixedValues(staticShape, dynamicDims, rewriter);
+ std::optional<SmallVector<OpFoldResult>> expandedShape =
+ mlir::inferExpandShapeOutputShape(
+ rewriter, subspanOp.getLoc(),
+ cast<ShapedType>(reshapeSrc.getType()),
+ reshapeOp.getReassociationIndices(), mixedShape);
+ if (!expandedShape) {
+ return rewriter.notifyMatchFailure(
+ storeOp, "failed to compute expand shape for interface binding");
+ }
+ SmallVector<int64_t> expandedStaticShape;
+ SmallVector<Value> expandedDynamicShape;
+ dispatchIndexOpFoldResults(*expandedShape, expandedDynamicShape,
+ expandedStaticShape);
+
+ auto tensorAccess =
+ cast<IREE::Flow::DispatchTensorType>(subspanOp.getType()).getAccess();
+ auto newSubspanType =
+ IREE::Flow::DispatchTensorType::get(tensorAccess, reshapeSrcType);
+
+ auto newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+ subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+ subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicShape,
+ subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
+
+ rewriter.setInsertionPoint(storeOp);
rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
- storeOp, *reshapeSrc, newSubspanOp, storeOp.getTargetDims());
+ storeOp, reshapeSrc, newSubspanOp, expandedDynamicShape);
return success();
}
@@ -840,12 +1064,11 @@
} // namespace
void populateReshapeToInterfaceTensorPatterns(RewritePatternSet &patterns) {
- patterns.insert<FoldReshapeIntoInterfaceTensorLoad<tensor::CollapseShapeOp>,
- FoldReshapeIntoInterfaceTensorLoad<tensor::ExpandShapeOp>>(
- patterns.getContext());
- patterns.insert<FoldExpandShapeIntoInterfaceTensorStore>(
- patterns.getContext());
- patterns.insert<FoldCollapseShapeIntoInterfaceTensorStore>(
+ patterns.insert<FoldCollapseShapeIntoInterfaceTensorLoad,
+ FoldCollapseShapeIntoInterfaceTensorStore,
+ FoldCollapseShapeIntoInterfaceTensorStoreFullSlice,
+ FoldExpandShapeIntoInterfaceTensorLoad,
+ FoldExpandShapeIntoInterfaceTensorStore>(
patterns.getContext());
}