Teach ResolveBufferDescriptors how to handle memref.get_global. (#10459)
Also adds some missing tests for alloca.
This allows the invocation listed in the bug to compile. Just
spot-checking the code, there is a lot of work to do to make it good,
but one step at a time.
Fixes #10427
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
index 1483925..55ca07e 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
@@ -184,8 +184,8 @@
}
};
-// Allocations (and anything else the returns a non-offset identity memref)
-// are matched by this pattern.
+// Allocations always return a non-offset memref and are matched by this
+// pattern.
struct FromAllocation : public OpRewritePattern<GetBufferDescriptorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
@@ -252,6 +252,70 @@
}
};
+// MemRef globals are always static shaped and reference a non-offset
+// buffer.
+struct FromGlobal : public OpRewritePattern<GetBufferDescriptorOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
+ PatternRewriter &rewriter) const override {
+ auto global = op.getSource().getDefiningOp<memref::GetGlobalOp>();
+ if (!global) return failure();
+ auto memRefType = global.getResult().getType().cast<MemRefType>();
+ if (!memRefType.getLayout().isIdentity()) {
+      return rewriter.notifyMatchFailure(op, "not identity layout");
+ }
+
+ auto loc = op.getLoc();
+ IndexSet indexSet(loc, rewriter);
+
+    // Replace the op with values:
+    //   base_buffer: the global memref, cast to the descriptor buffer type
+    //   offset: always zero (globals reference a non-offset buffer)
+    //   sizes: the static dims from the memref type
+    //   strides: identity strides computed from the static sizes
+ int rank = memRefType.getRank();
+
+ // Compute sizes.
+ SmallVector<Value> sizes;
+ for (int i = 0; i < rank; ++i) {
+ assert(!memRefType.isDynamicDim(i) &&
+ "memref.get_global does not support dynamic dims");
+ sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
+ loc, memRefType.getDimSize(i)));
+
+ // Replace as we go.
+ op.getSizes()[i].replaceAllUsesWith(sizes.back());
+ }
+
+ // Strides (just creates identity strides).
+ if (rank > 0) {
+ SmallVector<Value> strides;
+ strides.resize(rank);
+ strides[rank - 1] = indexSet.get(1);
+ for (int i = rank - 2; i >= 0; --i) {
+ strides[i] = rewriter.createOrFold<arith::MulIOp>(loc, strides[i + 1],
+ sizes[i + 1]);
+ }
+ for (int i = 0; i < rank; ++i) {
+ op.getStrides()[i].replaceAllUsesWith(strides[i]);
+ }
+ }
+
+ // Offset.
+ op.getOffset().replaceAllUsesWith(indexSet.get(0));
+
+ // Base buffer.
+ op.getBaseBuffer().replaceAllUsesWith(
+ rewriter
+ .create<UnrealizedConversionCastOp>(
+ loc, op.getBaseBuffer().getType(), global.getResult())
+ .getResult(0));
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
class ResolveBufferDescriptorsPass
: public ResolveBufferDescriptorsBase<ResolveBufferDescriptorsPass> {
public:
@@ -263,7 +327,7 @@
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- patterns.insert<FromAllocation, FromHalInterfaceBindingSubspan,
+ patterns.insert<FromAllocation, FromGlobal, FromHalInterfaceBindingSubspan,
FromMemRefSubView>(&getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation(),
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
index d4fda74..96db626 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
@@ -111,3 +111,49 @@
%base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<?x?xindex> -> !util.buffer, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
}
+
+// -----
+
+// CHECK-LABEL: @resolve_alloca_static
+func.func @resolve_alloca_static() -> (!util.buffer, index, index, index, index, index) {
+ // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
+ // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
+ // CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+ %0 = memref.alloca() : memref<512x384xf32>
+ %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
+ return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @resolve_alloca_dynamic
+func.func @resolve_alloca_dynamic(%arg0 : index) -> (!util.buffer, index, index, index, index, index) {
+ // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
+ // CHECK: return %[[CAST]], %[[C0]], %arg0, %[[C384]], %[[C384]], %[[C1]]
+ %0 = memref.alloca(%arg0) : memref<?x384xf32>
+ %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<?x384xf32> -> !util.buffer, index, index, index, index, index
+ return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @resolve_global
+memref.global "private" constant @__constant_512x384xf32 : memref<512x384xf32> = dense<0.0>
+
+func.func @resolve_global() -> (!util.buffer, index, index, index, index, index) {
+  // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
+  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+  %0 = memref.get_global @__constant_512x384xf32 : memref<512x384xf32>
+ %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
+ return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}