Teach ResolveBufferDescriptors how to handle memref.get_global. (#10459)

Also adds some missing tests for alloca.

This allows the invocation listed in the bug to compile. Just
spot-checking the code, there is a lot of work to do to make it good,
but one step at a time.

Fixes #10427
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
index 1483925..55ca07e 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
@@ -184,8 +184,8 @@
   }
 };
 
-// Allocations (and anything else the returns a non-offset identity memref)
-// are matched by this pattern.
+// Allocations always return a non-offset memref and are matched by this
+// pattern.
 struct FromAllocation : public OpRewritePattern<GetBufferDescriptorOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
@@ -252,6 +252,70 @@
   }
 };
 
+// MemRef globals are always static shaped and reference a non-offset
+// buffer.
+struct FromGlobal : public OpRewritePattern<GetBufferDescriptorOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
+                                PatternRewriter &rewriter) const override {
+    auto global = op.getSource().getDefiningOp<memref::GetGlobalOp>();
+    if (!global) return failure();
+    auto memRefType = global.getResult().getType().cast<MemRefType>();
+    if (!memRefType.getLayout().isIdentity()) {
+      return rewriter.notifyMatchFailure(op, "not identity allocation");
+    }
+
+    auto loc = op.getLoc();
+    IndexSet indexSet(loc, rewriter);
+
+    // Replace the op with values:
+    //   base_buffer: The subspan result
+    //   offset: byte offset from subspan divided by element type size
+    //   sizes: static and dynamic sizes from the subspan
+    //   strides: identity strides
+    int rank = memRefType.getRank();
+
+    // Compute sizes.
+    SmallVector<Value> sizes;
+    for (int i = 0; i < rank; ++i) {
+      assert(!memRefType.isDynamicDim(i) &&
+             "memref.get_global does not support dynamic dims");
+      sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
+          loc, memRefType.getDimSize(i)));
+
+      // Replace as we go.
+      op.getSizes()[i].replaceAllUsesWith(sizes.back());
+    }
+
+    // Strides (just creates identity strides).
+    if (rank > 0) {
+      SmallVector<Value> strides;
+      strides.resize(rank);
+      strides[rank - 1] = indexSet.get(1);
+      for (int i = rank - 2; i >= 0; --i) {
+        strides[i] = rewriter.createOrFold<arith::MulIOp>(loc, strides[i + 1],
+                                                          sizes[i + 1]);
+      }
+      for (int i = 0; i < rank; ++i) {
+        op.getStrides()[i].replaceAllUsesWith(strides[i]);
+      }
+    }
+
+    // Offset.
+    op.getOffset().replaceAllUsesWith(indexSet.get(0));
+
+    // Base buffer.
+    op.getBaseBuffer().replaceAllUsesWith(
+        rewriter
+            .create<UnrealizedConversionCastOp>(
+                loc, op.getBaseBuffer().getType(), global.getResult())
+            .getResult(0));
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 class ResolveBufferDescriptorsPass
     : public ResolveBufferDescriptorsBase<ResolveBufferDescriptorsPass> {
  public:
@@ -263,7 +327,7 @@
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    patterns.insert<FromAllocation, FromHalInterfaceBindingSubspan,
+    patterns.insert<FromAllocation, FromGlobal, FromHalInterfaceBindingSubspan,
                     FromMemRefSubView>(&getContext());
 
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
index d4fda74..96db626 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
@@ -111,3 +111,49 @@
   %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<?x?xindex> -> !util.buffer, index, index, index, index, index
   return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
 }
+
+// -----
+
+// CHECK-LABEL: @resolve_alloca_static
+func.func @resolve_alloca_static() -> (!util.buffer, index, index, index, index, index) {
+  // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
+  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  //     CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
+  //     CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+  %0 = memref.alloca() : memref<512x384xf32>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @resolve_alloca_dynamic
+func.func @resolve_alloca_dynamic(%arg0 : index) -> (!util.buffer, index, index, index, index, index) {
+  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  //     CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
+  //     CHECK: return %[[CAST]], %[[C0]], %arg0, %[[C384]], %[[C384]], %[[C1]]
+  %0 = memref.alloca(%arg0) : memref<?x384xf32>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<?x384xf32> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @resolve_global
+memref.global "private" constant @__constant_2xi32 : memref<512x384xf32> = dense<0.0>
+
+func.func @resolve_global() -> (!util.buffer, index, index, index, index, index) {
+  // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
+  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  //     CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
+  //     CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+  %0 = memref.get_global @__constant_2xi32 : memref<512x384xf32>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}