Handle reshape ops in FlattenMemRefSubspanPass (#7649)

These reshape ops are essentially no-ops, given that we flatten both the
source and target memref types into the same rank-1 type.
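
For example (a minimal sketch mirroring the lit test added below), a
collapse_shape feeding a load:

    %collapse = memref.collapse_shape %subspan [[0, 1], [2, 3]]
        : memref<4x5x6x7xf32> into memref<20x42xf32>
    %value = memref.load %collapse[%i0, %i1] : memref<20x42xf32>

folds away entirely: after flattening, the load reads directly from the
rank-1 subspan (memref<?xf32>) through a single linearized index, where
the "s2 floordiv 4" term converts the byte offset %offset into an f32
element index:

    %index = affine.apply affine_map<()[s0, s1, s2]
        -> (s0 * 42 + s1 + s2 floordiv 4)>()[%i0, %i1, %offset]
    %value = memref.load %subspan[%index] : memref<?xf32>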
diff --git a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
index 7f86260..fb1fe5b 100644
--- a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
+++ b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -231,6 +231,10 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Linearizing Patterns
+//===----------------------------------------------------------------------===//
+
 /// Generates IR to perform index linearization with the given `indices`
 /// indexing into the given memref `sourceValue`.
 static Value linearizeIndices(Value sourceValue, ValueRange indices,
@@ -312,7 +316,7 @@
       ConversionPatternRewriter &rewriter) const override {
     if (!isRankOneMemRef(adaptor.memref().getType())) {
       return rewriter.notifyMatchFailure(
-          loadOp, "expected converted memref of rank <= 1");
+          loadOp, "expected converted memref of rank == 1");
     }
 
     Value linearIndex = linearizeIndices(loadOp.memref(), loadOp.getIndices(),
@@ -337,7 +341,7 @@
       ConversionPatternRewriter &rewriter) const override {
     if (!isRankOneMemRef(adaptor.memref().getType())) {
       return rewriter.notifyMatchFailure(
-          storeOp, "expected converted memref of rank <= 1");
+          storeOp, "expected converted memref of rank == 1");
     }
 
     Value linearIndex = linearizeIndices(storeOp.memref(), storeOp.getIndices(),
@@ -366,7 +370,7 @@
     }
     if (!isRankOneMemRef(adaptor.source().getType())) {
       return rewriter.notifyMatchFailure(
-          transferReadOp, "expected converted memref of rank <= 1");
+          transferReadOp, "expected converted memref of rank == 1");
     }
     Value linearIndex =
         linearizeIndices(transferReadOp.source(), transferReadOp.indices(),
@@ -397,7 +401,7 @@
     }
     if (!isRankOneMemRef(adaptor.source().getType())) {
       return rewriter.notifyMatchFailure(
-          transferWriteOp, "expected converted memref of rank <= 1");
+          transferWriteOp, "expected converted memref of rank == 1");
     }
     Value linearIndex =
         linearizeIndices(transferWriteOp.source(), transferWriteOp.indices(),
@@ -429,7 +433,7 @@
 
     if (!isRankOneMemRef(input.getType())) {
       return rewriter.notifyMatchFailure(
-          castOp, "expected converted memref of rank <= 1");
+          castOp, "expected converted memref of rank == 1");
     }
     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
         castOp, castOp.getResultTypes(), input);
@@ -441,6 +445,24 @@
 // Folding Patterns
 //===----------------------------------------------------------------------===//
 
+/// Removes MemRef reshape ops, given that we linearize both the source and
+/// target types to the same rank-1 type.
+template <typename ReshapeOpTy>
+struct FoldMemRefReshape final : public OpConversionPattern<ReshapeOpTy> {
+  using OpConversionPattern<ReshapeOpTy>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      ReshapeOpTy op, typename ReshapeOpTy::Adaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!isRankOneMemRef(adaptor.src().getType())) {
+      return rewriter.notifyMatchFailure(
+          op, "expected converted memref of rank == 1");
+    }
+    rewriter.replaceOp(op, adaptor.src());
+    return success();
+  }
+};
+
 /// Returns the number of bytes of the given `type`. Returns llvm::None if
 /// cannot deduce.
 ///
@@ -551,15 +573,17 @@
              FlattenGlobal, FlattenGetGlobal, FlattenBindingSubspan,
              LinearizeLoadIndices, LinearizeStoreIndices,
              LinearizeTransferReadIndices, LinearizeTransferWriteIndices,
-             AdjustConversionCast>(typeConverter, &context);
+             AdjustConversionCast, FoldMemRefReshape<memref::CollapseShapeOp>,
+             FoldMemRefReshape<memref::ExpandShapeOp>>(typeConverter, &context);
 
     ConversionTarget target(context);
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
-    target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp,
-                                 memref::AllocaOp, memref::AllocOp,
-                                 memref::GetGlobalOp>([](Operation *op) {
-      return isRankOneMemRef(op->getResultTypes().front());
-    });
+    target.addDynamicallyLegalOp<
+        IREE::HAL::InterfaceBindingSubspanOp, memref::AllocaOp, memref::AllocOp,
+        memref::CollapseShapeOp, memref::ExpandShapeOp, memref::GetGlobalOp>(
+        [](Operation *op) {
+          return isRankOneMemRef(op->getResultTypes().front());
+        });
     target.addDynamicallyLegalOp<memref::GlobalOp>(
         [](memref::GlobalOp op) { return isRankOneMemRef(op.type()); });
     target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp loadOp) {
diff --git a/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir b/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
index a76fca2..fdf8333 100644
--- a/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
+++ b/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
@@ -331,3 +331,47 @@
 //      CHECK:   %[[LOAD:.+]] = memref.load %[[SPAN0]][%[[INDEX0]]] : memref<?xf32>
 //      CHECK:   %[[INDEX1:.+]] = affine.apply #[[MAP]]()[%[[OFFSET]]]
 //      CHECK:   memref.store %[[LOAD]], %[[SPAN1]][%[[INDEX1]]] : memref<?xf32>
+
+// -----
+
+func @collapse_shape(%offset : index, %i0 : index, %i1 : index) -> f32 {
+  %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<4x5x6x7xf32>
+  %collapse = memref.collapse_shape %subspan[[0, 1], [2, 3]] : memref<4x5x6x7xf32> into memref<20x42xf32>
+  %value = memref.load %collapse[%i0, %i1] : memref<20x42xf32>
+  return %value : f32
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+  hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+}
+
+//      CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 42 + s1 + s2 floordiv 4)>
+//      CHECK: func @collapse_shape
+// CHECK-SAME: (%[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index)
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[SIZE:.+]] = arith.constant 840 : index
+//      CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_constant[%[[C0]]] : memref<?xf32>{%[[SIZE]]}
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[I0]], %[[I1]], %[[OFFSET]]]
+//      CHECK:   memref.load %[[SUBSPAN]][%[[INDEX]]]
+
+// -----
+
+func @expand_shape(%offset : index, %i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
+  %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<20x42xf32>
+  %expand = memref.expand_shape %subspan[[0, 1], [2, 3]] : memref<20x42xf32> into memref<4x5x6x7xf32>
+  %value = memref.load %expand[%i0, %i1, %i2, %i3] : memref<4x5x6x7xf32>
+  return %value : f32
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+  hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+}
+
+//      CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 210 + s1 * 42 + s2 * 7 + s3 + s4 floordiv 4)>
+//      CHECK: func @expand_shape
+// CHECK-SAME: (%[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[SIZE:.+]] = arith.constant 840 : index
+//      CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_constant[%[[C0]]] : memref<?xf32>{%[[SIZE]]}
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[I0]], %[[I1]], %[[I2]], %[[I3]], %[[OFFSET]]]
+//      CHECK:   memref.load %[[SUBSPAN]][%[[INDEX]]]