[Codegen] Allow padding of dynamic allocas (#19399)
This PR adds support for padding `memref.alloca` ops in the
PadDynamicAllocsPass. Padding works the same way for `memref.alloca`
as it does for `memref.alloc`.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
index db1819b..2a9d2d3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
@@ -65,7 +65,8 @@
return failure();
}
-static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp,
+template <typename AllocLikeOp>
+static LogicalResult padAlloc(MLIRContext *context, AllocLikeOp allocOp,
const DataFlowSolver &solver) {
IRRewriter rewriter(context);
rewriter.setInsertionPoint(allocOp);
@@ -94,7 +95,7 @@
MemRefType allocType = MemRefType::get(shape, elType, AffineMap(),
allocOp.getType().getMemorySpace());
Location loc = allocOp.getLoc();
- Value paddedAlloc = rewriter.create<memref::AllocOp>(loc, allocType);
+ Value paddedAlloc = rewriter.create<AllocLikeOp>(loc, allocType);
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
Value subview = rewriter.create<memref::SubViewOp>(loc, paddedAlloc, offsets,
@@ -111,7 +112,6 @@
void runOnOperation() override {
auto funcOp = getOperation();
MLIRContext *context = &getContext();
- SmallVector<memref::AllocOp> sharedMemAllocs;
DataFlowSolver solver;
solver.load<dataflow::DeadCodeAnalysis>();
@@ -122,12 +122,21 @@
}
// Collect all the alloc operations.
- funcOp.walk(
- [&](memref::AllocOp allocOp) { sharedMemAllocs.push_back(allocOp); });
- for (memref::AllocOp alloc : sharedMemAllocs) {
+ SmallVector<memref::AllocOp> allocs;
+ funcOp.walk([&](memref::AllocOp allocOp) { allocs.push_back(allocOp); });
+ for (memref::AllocOp alloc : allocs) {
if (failed(padAlloc(context, alloc, solver)))
return signalPassFailure();
}
+
+ // Collect all the alloca operations.
+ SmallVector<memref::AllocaOp> allocas;
+ funcOp.walk(
+ [&](memref::AllocaOp allocaOp) { allocas.push_back(allocaOp); });
+ for (memref::AllocaOp alloca : allocas) {
+ if (failed(padAlloc(context, alloca, solver)))
+ return signalPassFailure();
+ }
}
};
} // namespace
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
index 0b56bd2..e9d4d7b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
@@ -37,4 +37,14 @@
return
}
// CHECK-LABEL: func @dynamic_bound_alloc(
-// CHECK: %alloc = memref.alloc() : memref<4088xf32, 3>
+// CHECK: memref.alloc() : memref<4088xf32, 3>
+
+// -----
+
+func.func @dynamic_bound_alloca(%id : index) {
+ %0 = util.assume.int %id<umin = 0, umax = 4088> : index
+ %1 = memref.alloca(%0) : memref<?xf32, 3>
+ return
+}
+// CHECK-LABEL: func @dynamic_bound_alloca(
+// CHECK: memref.alloca() : memref<4088xf32, 3>