Handle reshape ops in FlattenMemRefSubspanPass (#7649)
These are essentially no-ops, given that we'll flatten both the
source and target memref types into the same rank-1 form.
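
For illustration, a minimal before/after sketch of the folding, modeled
on the new collapse_shape test below (the "after" IR is an assumption
reconstructed from the test's CHECK lines, not verbatim pass output):

  // Before: a reshape sits between the subspan and the load.
  %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<4x5x6x7xf32>
  %collapse = memref.collapse_shape %subspan [[0, 1], [2, 3]]
      : memref<4x5x6x7xf32> into memref<20x42xf32>
  %value = memref.load %collapse[%i0, %i1] : memref<20x42xf32>

  // After: source and target flatten to the same rank-1 memref, so the
  // reshape folds away and the load is linearized against the flat
  // buffer. The byte %offset is divided by 4 to index f32 elements.
  %c0 = arith.constant 0 : index
  %c840 = arith.constant 840 : index
  %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c0] : memref<?xf32>{%c840}
  %index = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 42 + s1 + s2 floordiv 4)>()[%i0, %i1, %offset]
  %value = memref.load %subspan[%index] : memref<?xf32>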
diff --git a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
index 7f86260..fb1fe5b 100644
--- a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
+++ b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -231,6 +231,10 @@
}
};
+//===----------------------------------------------------------------------===//
+// Linearizing Patterns
+//===----------------------------------------------------------------------===//
+
/// Generates IR to perform index linearization with the given `indices`
/// indexing into the given memref `sourceValue`.
static Value linearizeIndices(Value sourceValue, ValueRange indices,
@@ -312,7 +316,7 @@
ConversionPatternRewriter &rewriter) const override {
if (!isRankOneMemRef(adaptor.memref().getType())) {
return rewriter.notifyMatchFailure(
- loadOp, "expected converted memref of rank <= 1");
+ loadOp, "expected converted memref of rank == 1");
}
Value linearIndex = linearizeIndices(loadOp.memref(), loadOp.getIndices(),
@@ -337,7 +341,7 @@
ConversionPatternRewriter &rewriter) const override {
if (!isRankOneMemRef(adaptor.memref().getType())) {
return rewriter.notifyMatchFailure(
- storeOp, "expected converted memref of rank <= 1");
+ storeOp, "expected converted memref of rank == 1");
}
Value linearIndex = linearizeIndices(storeOp.memref(), storeOp.getIndices(),
@@ -366,7 +370,7 @@
}
if (!isRankOneMemRef(adaptor.source().getType())) {
return rewriter.notifyMatchFailure(
- transferReadOp, "expected converted memref of rank <= 1");
+ transferReadOp, "expected converted memref of rank == 1");
}
Value linearIndex =
linearizeIndices(transferReadOp.source(), transferReadOp.indices(),
@@ -397,7 +401,7 @@
}
if (!isRankOneMemRef(adaptor.source().getType())) {
return rewriter.notifyMatchFailure(
- transferWriteOp, "expected converted memref of rank <= 1");
+ transferWriteOp, "expected converted memref of rank == 1");
}
Value linearIndex =
linearizeIndices(transferWriteOp.source(), transferWriteOp.indices(),
@@ -429,7 +433,7 @@
if (!isRankOneMemRef(input.getType())) {
return rewriter.notifyMatchFailure(
- castOp, "expected converted memref of rank <= 1");
+ castOp, "expected converted memref of rank == 1");
}
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
castOp, castOp.getResultTypes(), input);
@@ -441,6 +445,24 @@
// Folding Patterns
//===----------------------------------------------------------------------===//
+/// Removes MemRef reshape ops, given that we'll linearize both the source
+/// and target types to the same rank-1 type.
+template <typename ReshapeOpTy>
+struct FoldMemRefReshape final : public OpConversionPattern<ReshapeOpTy> {
+ using OpConversionPattern<ReshapeOpTy>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ ReshapeOpTy op, typename ReshapeOpTy::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isRankOneMemRef(adaptor.src().getType())) {
+ return rewriter.notifyMatchFailure(
+ op, "expected converted memref of rank == 1");
+ }
+ rewriter.replaceOp(op, adaptor.src());
+ return success();
+  }
+};
+
/// Returns the number of bytes of the given `type`. Returns llvm::None if
/// cannot deduce.
///
@@ -551,15 +573,17 @@
FlattenGlobal, FlattenGetGlobal, FlattenBindingSubspan,
LinearizeLoadIndices, LinearizeStoreIndices,
LinearizeTransferReadIndices, LinearizeTransferWriteIndices,
- AdjustConversionCast>(typeConverter, &context);
+ AdjustConversionCast, FoldMemRefReshape<memref::CollapseShapeOp>,
+ FoldMemRefReshape<memref::ExpandShapeOp>>(typeConverter, &context);
ConversionTarget target(context);
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
- target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp,
- memref::AllocaOp, memref::AllocOp,
- memref::GetGlobalOp>([](Operation *op) {
- return isRankOneMemRef(op->getResultTypes().front());
- });
+ target.addDynamicallyLegalOp<
+ IREE::HAL::InterfaceBindingSubspanOp, memref::AllocaOp, memref::AllocOp,
+ memref::CollapseShapeOp, memref::ExpandShapeOp, memref::GetGlobalOp>(
+ [](Operation *op) {
+ return isRankOneMemRef(op->getResultTypes().front());
+ });
target.addDynamicallyLegalOp<memref::GlobalOp>(
[](memref::GlobalOp op) { return isRankOneMemRef(op.type()); });
target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp loadOp) {
diff --git a/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir b/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
index a76fca2..fdf8333 100644
--- a/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
+++ b/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
@@ -331,3 +331,47 @@
// CHECK: %[[LOAD:.+]] = memref.load %[[SPAN0]][%[[INDEX0]]] : memref<?xf32>
// CHECK: %[[INDEX1:.+]] = affine.apply #[[MAP]]()[%[[OFFSET]]]
// CHECK: memref.store %[[LOAD]], %[[SPAN1]][%[[INDEX1]]] : memref<?xf32>
+
+// -----
+
+func @collapse_shape(%offset : index, %i0 : index, %i1 : index) -> f32 {
+ %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<4x5x6x7xf32>
+ %collapse = memref.collapse_shape %subspan[[0, 1], [2, 3]] : memref<4x5x6x7xf32> into memref<20x42xf32>
+ %value = memref.load %collapse[%i0, %i1] : memref<20x42xf32>
+ return %value : f32
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 42 + s1 + s2 floordiv 4)>
+// CHECK: func @collapse_shape
+// CHECK-SAME: (%[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SIZE:.+]] = arith.constant 840 : index
+// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_constant[%[[C0]]] : memref<?xf32>{%[[SIZE]]}
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[I0]], %[[I1]], %[[OFFSET]]]
+// CHECK: memref.load %[[SUBSPAN]][%[[INDEX]]]
+
+// -----
+
+func @expand_shape(%offset : index, %i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
+ %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<20x42xf32>
+ %expand = memref.expand_shape %subspan[[0, 1], [2, 3]] : memref<20x42xf32> into memref<4x5x6x7xf32>
+ %value = memref.load %expand[%i0, %i1, %i2, %i3] : memref<4x5x6x7xf32>
+ return %value : f32
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 210 + s1 * 42 + s2 * 7 + s3 + s4 floordiv 4)>
+// CHECK: func @expand_shape
+// CHECK-SAME: (%[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SIZE:.+]] = arith.constant 840 : index
+// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_constant[%[[C0]]] : memref<?xf32>{%[[SIZE]]}
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[I0]], %[[I1]], %[[I2]], %[[I3]], %[[OFFSET]]]
+// CHECK: memref.load %[[SUBSPAN]][%[[INDEX]]]
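
For reference, the new cases run under the test file's existing RUN
line, which should look roughly like this (the exact pass flag is an
assumption on my part and is not part of this diff):

  // RUN: iree-opt -split-input-file -iree-codegen-flatten-memref-subspan %s | FileCheck %s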