Track hal.allocators through hal.buffer.subspan.
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 969de1b..882cc86 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -160,7 +160,7 @@
namespace {
-/// Skips a hal.buffer_view.buffer accessor when the buffer view was created in
+/// Skips a hal.buffer.allocator accessor when the buffer view was created in
/// the same scope and we know the origin buffer.
struct SkipBufferAllocatorOp : public OpRewritePattern<BufferAllocatorOp> {
using OpRewritePattern<BufferAllocatorOp>::OpRewritePattern;
@@ -175,6 +175,11 @@
op.buffer().getDefiningOp())) {
rewriter.replaceOp(op, allocateOp.allocator());
return success();
+ } else if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+ op.buffer().getDefiningOp())) {
+ rewriter.replaceOpWithNewOp<BufferAllocatorOp>(op,
+ subspanOp.source_buffer());
+ return success();
}
return failure();
}
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir
index 656ad7a..257df6d 100644
--- a/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir
@@ -12,3 +12,19 @@
// CHECK: return %[[AL]]
return %1 : !hal.allocator
}
+
+// -----
+
+// CHECK-LABEL: @skip_subspan_buffer_allocator
+func @skip_subspan_buffer_allocator() -> !hal.allocator {
+ %c0 = constant 0 : index
+ %c184 = constant 184 : index
+ %c384 = constant 384 : index
+ // CHECK-DAG: %[[AL:.+]] = "test_hal.allocator"
+ %allocator = "test_hal.allocator"() : () -> !hal.allocator
+ %source_buffer = hal.allocator.allocate %allocator, "HostVisible|HostCoherent", "Transfer", %c384 : !hal.buffer
+ %span_buffer = hal.buffer.subspan %source_buffer, %c0, %c184 : !hal.buffer
+ %1 = hal.buffer.allocator %span_buffer : !hal.allocator
+ // CHECK: return %[[AL]]
+ return %1 : !hal.allocator
+}