[Codegen] Add support for memref.expand_shape to propagation util (#18202)
Similar to `memref.subview`, `memref.expand_shape` needs to have its
result type recomputed when type changes are propagated through it. This
adds `memref.expand_shape` support to the propagation util so that passes
like GPUReduceBankConflicts can handle `memref.expand_shape`.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir
index 1e9d647..befb244 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir
@@ -25,6 +25,30 @@
// -----
+// CHECK-LABEL: func.func @pad_alloc_expand_shape
+// CHECK: %[[A:.*]] = memref.alloc() : memref<4x32x66xf32, #gpu.address_space<workgroup>>
+// CHECK: %[[S1:.*]] = memref.subview %[[A]][0, 0, 0] [4, 32, 64] [1, 1, 1] :
+// CHECK-SAME: memref<4x32x66xf32, #gpu.address_space<workgroup>> to memref<4x32x64xf32, strided<[2112, 66, 1]>, #gpu.address_space<workgroup>>
+// CHECK: %[[E:.*]] = memref.expand_shape %[[S1]] {{\[}}[0], [1, 2], [3, 4]] output_shape [4, 2, 16, 8, 8]
+// CHECK-SAME: memref<4x32x64xf32, strided<[2112, 66, 1]>, #gpu.address_space<workgroup>> into
+// CHECK-SAME: memref<4x2x16x8x8xf32, strided<[2112, 1056, 66, 8, 1]>, #gpu.address_space<workgroup>>
+// CHECK: vector.transfer_write %{{.*}}, %[[E]][%{{.*}}, %{{.*}}, %{{.*}}] {in_bounds = [true]} :
+// CHECK-SAME: vector<4xf32>, memref<4x2x16x8x8xf32, strided<[2112, 1056, 66, 8, 1]>, #gpu.address_space<workgroup>
+func.func @pad_alloc_expand_shape(%a: memref<1024x1024xf32>) { // Case: workgroup alloc whose only use goes through memref.expand_shape.
+ %0 = memref.alloc() : memref<4x32x64xf32, #gpu.address_space<workgroup>> // Expected output above allocates 4x32x66 and takes a 4x32x64 subview of it.
+ %1 = memref.expand_shape %0 [[0], [1, 2], [3, 4]] output_shape [4, 2, 16, 8, 8]
+ : memref<4x32x64xf32, #gpu.address_space<workgroup>> into memref<4x2x16x8x8xf32, #gpu.address_space<workgroup>> // Splits dim 1 into 2x16 and dim 2 into 8x8; expected result type gains strides [2112, 1056, 66, 8, 1].
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %3 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} :
+ memref<1024x1024xf32>, vector<4xf32> // Loads the 4-element vector that is stored below.
+ vector.transfer_write %3, %1[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true]} :
+ vector<4xf32>, memref<4x2x16x8x8xf32, #gpu.address_space<workgroup>> // Store through the expanded view; expected output uses the strided type here as well.
+ return
+}
+
+// -----
+
// CHECK-LABEL: func.func @pad_alloc_negative
// CHECK: memref.alloc(%{{.*}}) : memref<?x32x64xf32, #gpu.address_space<workgroup>
func.func @pad_alloc_negative(%a: memref<1024x1024xf32>, %i: index, %v: vector<4xf32>) {
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 18ba760..2812fc4 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -988,8 +988,30 @@
newSubviewOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
- return SmallVector<Value>(newSubviewOp->result_begin(),
- newSubviewOp->result_end());
+ return llvm::to_vector_of<Value>(newSubviewOp->getResults());
+ }
+ if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(user)) {
+ auto currResultType =
+ llvm::cast<MemRefType>(expandOp.getResult().getType());
+ auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
+
+ FailureOr<MemRefType> newResultType =
+ memref::ExpandShapeOp::computeExpandedType(
+ newSourceType, currResultType.getShape(),
+ expandOp.getReassociationIndices());
+ if (failed(newResultType)) {
+ return std::nullopt;
+ }
+
+ auto newExpandOp = rewriter.create<memref::ExpandShapeOp>(
+ loc, *newResultType, replacement, expandOp.getReassociation(),
+ expandOp.getOutputShape(), expandOp.getStaticOutputShape());
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t\tNew user : ";
+ newExpandOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
+ llvm::dbgs() << "\n";
+ });
+ return llvm::to_vector_of<Value>(newExpandOp->getResults());
}
return std::nullopt;
}