Update `ResolveBufferDescriptors` to handle `memref.extract_strided_metadata` (#12205)

The pass currently handles `vmvx.get_buffer_descriptor`. This operation is very similar to `memref.extract_strided_metadata`, and the logic can be reused to handle that operation as well.
This allows using the ukernel path being added to IREE, which is intended to work for both the VMVX
and LLVM CPU code-generation paths.
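
Both operations expose the same metadata (a base buffer plus an offset, sizes, and strides) and differ only in the type of the base buffer. A minimal side-by-side sketch, with types taken from the updated tests:

```
// vmvx.get_buffer_descriptor returns the base as a !util.buffer.
%buf, %off, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %src
    : memref<512x384xf32> -> !util.buffer, index, index, index, index, index

// memref.extract_strided_metadata returns the base as a rank-0 memref.
%base, %off2, %sizes2:2, %strides2:2 = memref.extract_strided_metadata %src
    : memref<512x384xf32> -> memref<f32>, index, index, index, index, index
```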

The only pattern that is a bit harder to reuse is

```
%0 = hal.interface.binding.subspan
.. = vmvx.get_buffer_descriptor %0
```

since the base buffer produced by `vmvx.get_buffer_descriptor` (a `!util.buffer`) is not the same type as the one produced by `memref.extract_strided_metadata` (a rank-0 `memref`). That pattern is not reused directly; a dedicated pattern handles the `memref.extract_strided_metadata` case instead.
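
The dedicated `ResolveExtractMetadataFromHalInterfaceBindingSubspan` pattern re-creates the subspan with the base-buffer memref type; a sketch distilled from the new `resolve_binding_subspan_zero_offset_memref` test:

```
// Before: metadata is extracted directly from a binding subspan.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
         alignment(64) offset(%c0) : memref<512x384xf32>
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
    : memref<512x384xf32> -> memref<f32>, index, index, index, index, index

// After: the base buffer becomes a fresh subspan of the base-buffer type, and
// the offset, sizes, and strides fold to constants.
%base = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
            alignment(64) offset(%c0) : memref<f32>
```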
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/BUILD
index 47ad0ff..26cb330 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/BUILD
@@ -73,6 +73,7 @@
         "@llvm-project//mlir:AffineTransforms",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ArithTransforms",
+        "@llvm-project//mlir:ArithUtils",
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FuncTransforms",
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/CMakeLists.txt
index 254fab3..5c15c48 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/CMakeLists.txt
@@ -52,6 +52,7 @@
     MLIRAffineTransforms
     MLIRArithDialect
     MLIRArithTransforms
+    MLIRArithUtils
     MLIRBufferizationDialect
     MLIRFuncDialect
     MLIRFuncTransforms
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
index 3e59503..a75a6e4 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
@@ -8,7 +8,10 @@
 #include "iree/compiler/Dialect/VMVX/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/VMVX/Transforms/Passes.h"
 #include "iree/compiler/Utils/IndexSet.h"
+#include "llvm/Support/MathExtras.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -17,20 +20,186 @@
 namespace mlir::iree_compiler::IREE::VMVX {
 
 namespace {
+/// Helper struct to return the offset, sizes and strides
+/// of a `source` of a `memref.extract_strided_metadata` op.
+struct DescriptorInfo {
+  OpFoldResult offset;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+};
+}  // namespace
 
-struct FromMemRefSubView : public OpRewritePattern<GetBufferDescriptorOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
+/// Returns an AffineMap for an add or a mul.
+static AffineMap getAddMap(MLIRContext *context) {
+  AffineExpr s0, s1;
+  bindSymbols(context, s0, s1);
+  return AffineMap::get(0, 2, s0 + s1);
+}
+static AffineMap getMulMap(MLIRContext *context) {
+  AffineExpr s0, s1;
+  bindSymbols(context, s0, s1);
+  return AffineMap::get(0, 2, s0 * s1);
+}
+
+static FailureOr<DescriptorInfo> resolveBufferDescriptorForSubview(
+    memref::SubViewOp subview, RewriterBase &rewriter, Location loc,
+    Value sourceOffset, ValueRange sourceSizes, ValueRange sourceStrides) {
+  DescriptorInfo resultDescriptor;
+
+  // For sizes, we just use the new ones.
+  resultDescriptor.sizes = subview.getMixedSizes();
+
+  // Apply stride multipliers.
+  AffineMap mulMap = getMulMap(rewriter.getContext());
+  for (auto [index, stride] : llvm::enumerate(subview.getMixedStrides())) {
+    OpFoldResult currentStride = makeComposedFoldedAffineApply(
+        rewriter, loc, mulMap, {sourceStrides[index], stride});
+    resultDescriptor.strides.push_back(currentStride);
+  }
+
+  // Offsets.
+  resultDescriptor.offset = sourceOffset;
+  AffineMap addMap = getAddMap(rewriter.getContext());
+  for (auto [index, offset] : llvm::enumerate(subview.getMixedOffsets())) {
+    OpFoldResult physicalOffset = makeComposedFoldedAffineApply(
+        rewriter, loc, mulMap, {offset, resultDescriptor.strides[index]});
+    resultDescriptor.offset = makeComposedFoldedAffineApply(
+        rewriter, loc, addMap, {resultDescriptor.offset, physicalOffset});
+  }
+  return resultDescriptor;
+}
+
+/// Returns the strides based on the sizes assuming that the `memref`
+/// has default layout, i.e. it is not a result of a subview.
+static SmallVector<OpFoldResult> getStridesFromSizes(
+    RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> sizes) {
+  if (sizes.size() == 0) {
+    return {};
+  }
+  SmallVector<OpFoldResult> strides(sizes.size());
+  strides.back() = rewriter.getIndexAttr(1);
+  if (sizes.size() == 1) {
+    return strides;
+  }
+  AffineMap mulMap = getMulMap(rewriter.getContext());
+  for (int i = sizes.size() - 2; i >= 0; --i) {
+    strides[i] = makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                               {strides[i + 1], sizes[i + 1]});
+  }
+  return strides;
+}
+
+static FailureOr<DescriptorInfo> resolveBufferDescriptorForAllocation(
+    memref::AllocaOp alloca, RewriterBase &rewriter, Location loc) {
+  DescriptorInfo resultDescriptor;
+
+  // Resolve the descriptor as:
+  //   base_buffer: the allocation result
+  //   offset: 0
+  //   sizes: static and dynamic sizes from the alloca
+  //   strides: identity strides
+  auto memRefType = alloca.getResult().getType().cast<MemRefType>();
+  int rank = memRefType.getRank();
+
+  // Compute sizes.
+  auto dynamicDimIt = alloca.getDynamicSizes().begin();
+  for (int i = 0; i < rank; ++i) {
+    if (memRefType.isDynamicDim(i)) {
+      resultDescriptor.sizes.push_back(*dynamicDimIt);
+      dynamicDimIt++;
+    } else {
+      resultDescriptor.sizes.push_back(
+          rewriter.getIndexAttr(memRefType.getDimSize(i)));
+    }
+  }
+
+  // Strides (just creates identity strides).
+  resultDescriptor.strides =
+      getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
+
+  resultDescriptor.offset = rewriter.getIndexAttr(0);
+  return resultDescriptor;
+}
+
+static FailureOr<DescriptorInfo> resolveBufferDescriptorForGetGlobalOp(
+    memref::GetGlobalOp global, RewriterBase &rewriter, Location loc) {
+  IndexSet indexSet(loc, rewriter);
+  DescriptorInfo resultDescriptor;
+
+  // Resolve the descriptor as:
+  //   base_buffer: the global result
+  //   offset: 0
+  //   sizes: static sizes from the global's memref type
+  //   strides: identity strides
+  auto memRefType = global.getResult().getType().cast<MemRefType>();
+  int rank = memRefType.getRank();
+
+  // Compute sizes.
+  for (int i = 0; i < rank; ++i) {
+    if (memRefType.isDynamicDim(i)) {
+      return rewriter.notifyMatchFailure(
+          global, "memref.get_global does not support dynamic dims");
+    }
+    resultDescriptor.sizes.push_back(
+        rewriter.getIndexAttr(memRefType.getDimSize(i)));
+  }
+
+  // Strides (just creates identity strides).
+  resultDescriptor.strides =
+      getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
+
+  // Offset.
+  resultDescriptor.offset = rewriter.getIndexAttr(0);
+  return resultDescriptor;
+}
+
+/// Replaces the offsets, sizes and strides based on values provided
+/// by `DescriptorInfo` object.
+template <typename OpTy>
+static void replaceOffsetSizesAndStridesWith(
+    RewriterBase &rewriter, OpTy op, const DescriptorInfo &resultDescriptor) {
+  int rank = resultDescriptor.sizes.size();
+  assert(rank == resultDescriptor.strides.size() &&
+         "expected number of sizes and strides to match");
+  assert(op.getSizes().size() == rank &&
+         "expected as many size replacements as the number of sizes in the "
+         "original operation");
+  assert(op.getStrides().size() == rank &&
+         "expected as many strides replacements as the number of strides in "
+         "the original operation");
+  Location loc = op.getLoc();
+  for (int i = 0; i < rank; ++i) {
+    // Sizes
+    rewriter.replaceAllUsesWith(op.getSizes()[i],
+                                getValueOrCreateConstantIndexOp(
+                                    rewriter, loc, resultDescriptor.sizes[i]));
+    // Strides
+    rewriter.replaceAllUsesWith(
+        op.getStrides()[i], getValueOrCreateConstantIndexOp(
+                                rewriter, loc, resultDescriptor.strides[i]));
+  }
+  // Offset
+  rewriter.replaceAllUsesWith(
+      op.getOffset(),
+      getValueOrCreateConstantIndexOp(rewriter, loc, resultDescriptor.offset));
+}
+
+namespace {
+
+template <typename OpTy>
+struct FromMemRefSubView : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    auto subview = op.getSource().getDefiningOp<memref::SubViewOp>();
+    auto subview = op.getSource().template getDefiningOp<memref::SubViewOp>();
     if (!subview) return failure();
     auto loc = op.getLoc();
     IndexSet indexSet(loc, rewriter);
 
     // Get types.
-    auto subType = subview.getResult().getType().cast<MemRefType>();
+    auto subType = subview.getResult().getType().template cast<MemRefType>();
     Value source = subview.getSource();
-    auto sourceType = source.getType().cast<MemRefType>();
+    auto sourceType = source.getType().template cast<MemRefType>();
     int sourceRank = sourceType.getRank();
     int subRank = subType.getRank();
     (void)subRank;
@@ -41,75 +210,40 @@
     for (int i = 0; i < sourceRank; i++) {
       sizeStrideTypes.push_back(indexType);
     }
-    auto sourceDesc = rewriter.create<IREE::VMVX::GetBufferDescriptorOp>(
-        loc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-        sizeStrideTypes, subview.getSource());
+    auto sourceDesc =
+        rewriter.create<OpTy>(loc, op.getBaseBuffer().getType(), indexType,
+                              sizeStrideTypes, sizeStrideTypes, source);
 
-    // For sizes, we just use the new ones.
+    FailureOr<DescriptorInfo> resultDescriptor =
+        resolveBufferDescriptorForSubview(
+            subview, rewriter, loc, sourceDesc.getOffset(),
+            sourceDesc.getSizes(), sourceDesc.getStrides());
+
+    if (failed(resultDescriptor)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve descriptor with source being a subview op");
+    }
+
     llvm::SmallBitVector droppedDims = subview.getDroppedDims();
-    unsigned insertedDims = 0;
-    SmallVector<Value> newSizes;
+    int targetIndex = 0;
     for (int i = 0; i < sourceRank; ++i) {
-      // Skip the sizes that don't show up in the final type.
       if (droppedDims.test(i)) continue;
-
-      if (subview.isDynamicSize(i)) {
-        newSizes.push_back(subview.getDynamicSize(i));
-      } else {
-        newSizes.push_back(indexSet.get(subview.getStaticSize(i)));
-      }
-      op.getSizes()[insertedDims++].replaceAllUsesWith(newSizes.back());
+      rewriter.replaceAllUsesWith(
+          op.getSizes()[targetIndex],
+          getValueOrCreateConstantIndexOp(rewriter, loc,
+                                          resultDescriptor->sizes[i]));
+      rewriter.replaceAllUsesWith(
+          op.getStrides()[targetIndex],
+          getValueOrCreateConstantIndexOp(rewriter, loc,
+                                          resultDescriptor->strides[i]));
+      targetIndex++;
     }
-    assert(insertedDims == subRank &&
-           "Should have populated all the non-reduced sizes");
-
-    // Apply stride multipliers.
-    SmallVector<Value> strides;
-    insertedDims = 0;
-    for (int i = 0; i < sourceRank; ++i) {
-      Value currentStride;
-      if (subview.isDynamicStride(i)) {
-        currentStride = subview.getDynamicStride(i);
-      } else {
-        currentStride = indexSet.get(subview.getStaticStride(i));
-      }
-      currentStride = rewriter.createOrFold<arith::MulIOp>(
-          loc, sourceDesc.getStrides()[i], currentStride);
-      strides.push_back(currentStride);
-
-      // Don't replace the value of dropped dimensions.
-      // Although the new stride will be used in the computation of the final
-      // offset, there's no value to replace.
-      if (droppedDims.test(i)) continue;
-
-      op.getStrides()[insertedDims++].replaceAllUsesWith(currentStride);
-    }
-    assert(insertedDims == subRank &&
-           "Should have populated all the non-reduced strides");
-
-    // Offsets.
-    Value offset = sourceDesc.getOffset();
-    for (int i = 0; i < sourceRank; ++i) {
-      Value logicalOffset;
-      if (subview.isDynamicOffset(i)) {
-        logicalOffset = subview.getDynamicOffset(i);
-      } else {
-        int64_t staticOffset = subview.getStaticOffset(i);
-        if (staticOffset == 0) {
-          // Since added, just omit (will multiply to 0).
-          continue;
-        }
-        logicalOffset = indexSet.get(staticOffset);
-      }
-      Value physicalOffset =
-          rewriter.createOrFold<arith::MulIOp>(loc, logicalOffset, strides[i]);
-      offset =
-          rewriter.createOrFold<arith::AddIOp>(loc, offset, physicalOffset);
-    }
-    op.getOffset().replaceAllUsesWith(offset);
+    rewriter.replaceAllUsesWith(op.getOffset(),
+                                getValueOrCreateConstantIndexOp(
+                                    rewriter, loc, resultDescriptor->offset));
 
     // Base.
-    op.getBaseBuffer().replaceAllUsesWith(sourceDesc.getBaseBuffer());
+    rewriter.replaceAllUsesWith(op.getBaseBuffer(), sourceDesc.getBaseBuffer());
     rewriter.eraseOp(op);
     return success();
   }
@@ -117,62 +251,47 @@
 
 struct FromHalInterfaceBindingSubspan
     : public OpRewritePattern<GetBufferDescriptorOp> {
-  using OpRewritePattern::OpRewritePattern;
+  using OpRewritePattern<GetBufferDescriptorOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
                                 PatternRewriter &rewriter) const override {
     auto binding =
-        op.getSource().getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+        op.getSource()
+            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
     if (!binding) return failure();
 
     auto loc = op.getLoc();
-    IndexSet indexSet(loc, rewriter);
 
-    // Replace the op with values:
-    //   base_buffer: The subspan result
-    //   offset: byte offset from subspan divided by element type size
-    //   sizes: static and dynamic sizes from the subspan
-    //   strides: identity strides
     auto memRefType = binding.getResult().getType().cast<MemRefType>();
     int rank = memRefType.getRank();
+    DescriptorInfo resultDescriptor;
 
     // Compute sizes.
-    SmallVector<Value> sizes;
     auto dynamicDimIt = binding.getDynamicDims().begin();
     for (int i = 0; i < rank; ++i) {
       if (memRefType.isDynamicDim(i)) {
-        sizes.push_back(*dynamicDimIt);
+        resultDescriptor.sizes.push_back(*dynamicDimIt);
         dynamicDimIt++;
       } else {
-        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, memRefType.getDimSize(i)));
+        resultDescriptor.sizes.push_back(
+            rewriter.getIndexAttr(memRefType.getDimSize(i)));
       }
-
-      // Replace as we go.
-      op.getSizes()[i].replaceAllUsesWith(sizes.back());
     }
 
     // Strides.
-    if (rank > 0) {
-      SmallVector<Value> strides;
-      strides.resize(rank);
-      strides[rank - 1] = indexSet.get(1);
-      for (int i = rank - 2; i >= 0; --i) {
-        strides[i] = rewriter.createOrFold<arith::MulIOp>(loc, strides[i + 1],
-                                                          sizes[i + 1]);
-      }
-      for (int i = 0; i < rank; ++i) {
-        op.getStrides()[i].replaceAllUsesWith(strides[i]);
-      }
-    }
+    resultDescriptor.strides =
+        getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
 
     // Offset.
     auto elementSize =
         rewriter.create<IREE::Util::SizeOfOp>(loc, memRefType.getElementType());
-    op.getOffset().replaceAllUsesWith(rewriter.createOrFold<arith::DivUIOp>(
-        loc, binding.getByteOffset(), elementSize));
+    resultDescriptor.offset = rewriter.createOrFold<arith::DivUIOp>(
+        loc, binding.getByteOffset(), elementSize);
+
+    replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor);
 
     // Base buffer.
-    op.getBaseBuffer().replaceAllUsesWith(
+    rewriter.replaceAllUsesWith(
+        op.getBaseBuffer(),
         rewriter
             .create<IREE::VMVX::GetRawInterfaceBindingBufferOp>(
                 loc, op.getBaseBuffer().getType(), binding.getSetAttr(),
@@ -184,66 +303,95 @@
   }
 };
 
+struct ResolveExtractMetadataFromHalInterfaceBindingSubspan
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto binding =
+        op.getSource()
+            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    if (!binding) return failure();
+    auto memRefType = binding.getResult().getType().template cast<MemRefType>();
+    if (memRefType.getRank() < 1) return failure();
+
+    auto loc = op.getLoc();
+    int rank = memRefType.getRank();
+    DescriptorInfo resultDescriptor;
+
+    // Compute sizes.
+    auto dynamicDimIt = binding.getDynamicDims().begin();
+    for (int i = 0; i < rank; ++i) {
+      if (memRefType.isDynamicDim(i)) {
+        resultDescriptor.sizes.push_back(*dynamicDimIt);
+        dynamicDimIt++;
+      } else {
+        resultDescriptor.sizes.push_back(
+            rewriter.getIndexAttr(memRefType.getDimSize(i)));
+      }
+    }
+
+    // Strides.
+    resultDescriptor.strides =
+        getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
+    resultDescriptor.offset = rewriter.getIndexAttr(0);
+
+    replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor);
+
+    // Base buffer. Use a 1D memref for hal.interface.binding.subspan.
+    AffineMap mulMap = getMulMap(rewriter.getContext());
+    OpFoldResult linearizedMemrefSize = rewriter.getIndexAttr(1);
+    for (auto size : resultDescriptor.sizes) {
+      linearizedMemrefSize = makeComposedFoldedAffineApply(
+          rewriter, loc, mulMap, {linearizedMemrefSize, size});
+    }
+    SmallVector<int64_t> staticLinearShape;
+    SmallVector<Value> dynamicLinearShape;
+    dispatchIndexOpFoldResult(linearizedMemrefSize, dynamicLinearShape,
+                              staticLinearShape);
+    Value linearInterfaceBinding =
+        rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+            loc, op.getBaseBuffer().getType(), binding.getSetAttr(),
+            binding.getBindingAttr(), binding.getDescriptorTypeAttr(),
+            binding.getByteOffset(),
+            /*dynamicDims =*/ValueRange{}, binding.getAlignmentAttr(),
+            binding.getDescriptorFlagsAttr());
+    rewriter.replaceAllUsesWith(op.getBaseBuffer(), linearInterfaceBinding);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 // Allocations always return a non-offset memref and are matched by this
 // pattern.
-struct FromAllocation : public OpRewritePattern<GetBufferDescriptorOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
+template <typename OpTy>
+struct FromAllocation : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    auto alloca = op.getSource().getDefiningOp<memref::AllocaOp>();
+    auto alloca = op.getSource().template getDefiningOp<memref::AllocaOp>();
     if (!alloca) return failure();
-    auto memRefType = alloca.getResult().getType().cast<MemRefType>();
+    auto memRefType = alloca.getResult().getType().template cast<MemRefType>();
     if (!memRefType.getLayout().isIdentity()) {
       return rewriter.notifyMatchFailure(op, "not identity allocation");
     }
 
     auto loc = op.getLoc();
-    IndexSet indexSet(loc, rewriter);
-
-    // Replace the op with values:
-    //   base_buffer: The subspan result
-    //   offset: byte offset from subspan divided by element type size
-    //   sizes: static and dynamic sizes from the subspan
-    //   strides: identity strides
-    int rank = memRefType.getRank();
-
-    // Compute sizes.
-    SmallVector<Value> sizes;
-    auto dynamicDimIt = alloca.getDynamicSizes().begin();
-    for (int i = 0; i < rank; ++i) {
-      if (memRefType.isDynamicDim(i)) {
-        sizes.push_back(*dynamicDimIt);
-        dynamicDimIt++;
-      } else {
-        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, memRefType.getDimSize(i)));
-      }
-
-      // Replace as we go.
-      op.getSizes()[i].replaceAllUsesWith(sizes.back());
+    FailureOr<DescriptorInfo> resultDescriptor =
+        resolveBufferDescriptorForAllocation(alloca, rewriter, loc);
+    if (failed(resultDescriptor)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve descriptor for memref.alloca op");
     }
 
-    // Strides (just creates identity strides).
-    if (rank > 0) {
-      SmallVector<Value> strides;
-      strides.resize(rank);
-      strides[rank - 1] = indexSet.get(1);
-      for (int i = rank - 2; i >= 0; --i) {
-        strides[i] = rewriter.createOrFold<arith::MulIOp>(loc, strides[i + 1],
-                                                          sizes[i + 1]);
-      }
-      for (int i = 0; i < rank; ++i) {
-        op.getStrides()[i].replaceAllUsesWith(strides[i]);
-      }
-    }
-
-    // Offset.
-    op.getOffset().replaceAllUsesWith(indexSet.get(0));
+    replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
 
     // Base buffer.
-    op.getBaseBuffer().replaceAllUsesWith(
+    rewriter.replaceAllUsesWith(
+        op.getBaseBuffer(),
         rewriter
-            .create<UnrealizedConversionCastOp>(
+            .template create<UnrealizedConversionCastOp>(
                 loc, op.getBaseBuffer().getType(), alloca.getResult())
             .getResult(0));
 
@@ -254,60 +402,33 @@
 
 // MemRef globals are always static shaped and reference a non-offset
 // buffer.
-struct FromGlobal : public OpRewritePattern<GetBufferDescriptorOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
+template <typename OpTy>
+struct FromGlobal : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    auto global = op.getSource().getDefiningOp<memref::GetGlobalOp>();
+    auto global = op.getSource().template getDefiningOp<memref::GetGlobalOp>();
     if (!global) return failure();
-    auto memRefType = global.getResult().getType().cast<MemRefType>();
+    auto memRefType = global.getResult().getType().template cast<MemRefType>();
     if (!memRefType.getLayout().isIdentity()) {
       return rewriter.notifyMatchFailure(op, "not identity allocation");
     }
 
     auto loc = op.getLoc();
-    IndexSet indexSet(loc, rewriter);
-
-    // Replace the op with values:
-    //   base_buffer: The subspan result
-    //   offset: byte offset from subspan divided by element type size
-    //   sizes: static and dynamic sizes from the subspan
-    //   strides: identity strides
-    int rank = memRefType.getRank();
-
-    // Compute sizes.
-    SmallVector<Value> sizes;
-    for (int i = 0; i < rank; ++i) {
-      assert(!memRefType.isDynamicDim(i) &&
-             "memref.get_global does not support dynamic dims");
-      sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, memRefType.getDimSize(i)));
-
-      // Replace as we go.
-      op.getSizes()[i].replaceAllUsesWith(sizes.back());
+    FailureOr<DescriptorInfo> resultDescriptor =
+        resolveBufferDescriptorForGetGlobalOp(global, rewriter, loc);
+    if (failed(resultDescriptor)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve descriptor for memref.get_global source");
     }
 
-    // Strides (just creates identity strides).
-    if (rank > 0) {
-      SmallVector<Value> strides;
-      strides.resize(rank);
-      strides[rank - 1] = indexSet.get(1);
-      for (int i = rank - 2; i >= 0; --i) {
-        strides[i] = rewriter.createOrFold<arith::MulIOp>(loc, strides[i + 1],
-                                                          sizes[i + 1]);
-      }
-      for (int i = 0; i < rank; ++i) {
-        op.getStrides()[i].replaceAllUsesWith(strides[i]);
-      }
-    }
-
-    // Offset.
-    op.getOffset().replaceAllUsesWith(indexSet.get(0));
+    replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
 
     // Base buffer.
-    op.getBaseBuffer().replaceAllUsesWith(
+    rewriter.replaceAllUsesWith(
+        op.getBaseBuffer(),
         rewriter
-            .create<UnrealizedConversionCastOp>(
+            .template create<UnrealizedConversionCastOp>(
                 loc, op.getBaseBuffer().getType(), global.getResult())
             .getResult(0));
 
@@ -322,13 +443,20 @@
   ResolveBufferDescriptorsPass() = default;
   ResolveBufferDescriptorsPass(const ResolveBufferDescriptorsPass &) {}
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<IREE::VMVX::VMVXDialect>();
+    registry.insert<AffineDialect, IREE::VMVX::VMVXDialect>();
   }
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    patterns.insert<FromAllocation, FromGlobal, FromHalInterfaceBindingSubspan,
-                    FromMemRefSubView>(&getContext());
+    patterns.insert<FromAllocation<GetBufferDescriptorOp>,
+                    FromAllocation<memref::ExtractStridedMetadataOp>,
+                    FromGlobal<GetBufferDescriptorOp>,
+                    FromGlobal<memref::ExtractStridedMetadataOp>,
+                    FromHalInterfaceBindingSubspan,
+                    FromMemRefSubView<GetBufferDescriptorOp>,
+                    FromMemRefSubView<memref::ExtractStridedMetadataOp>,
+                    ResolveExtractMetadataFromHalInterfaceBindingSubspan>(
+        &getContext());
 
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
@@ -338,8 +466,10 @@
     // If any get_buffer_descriptor patterns remain, we fail.
     if (!allowUnresolved) {
       SmallVector<Operation *> remaining;
-      getOperation()->walk([&](IREE::VMVX::GetBufferDescriptorOp op) {
-        remaining.push_back(op);
+      getOperation()->walk([&](Operation *op) {
+        if (isa<GetBufferDescriptorOp, memref::ExtractStridedMetadataOp>(op)) {
+          remaining.push_back(op);
+        }
       });
 
       if (!remaining.empty()) {
@@ -355,7 +485,8 @@
 
   Option<bool> allowUnresolved{
       *this, "allow-unresolved",
-      llvm::cl::desc("Allow unresolved descriptors (for testing)")};
+      llvm::cl::desc("Allow unresolved descriptors (for testing)"),
+      llvm::cl::init(false)};
 };
 
 }  // namespace
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
index ed49efe..8dc786e 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
@@ -2,39 +2,186 @@
 // RUN:   --iree-vmvx-resolve-buffer-descriptors="allow-unresolved=true" \
 // RUN:   --canonicalize %s | FileCheck %s
 
-// CHECK-LABEL: @resolve_subview
 #map0 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
-module @resolve_subview{
-  func.func @f(%arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index) -> (!util.buffer, index, index, index, index, index) {
-    // CHECK-DAG: %[[BASE_BUFFER:.*]], %[[BASE_OFFSET:.*]], %[[BASE_SIZES:.*]]:2, %[[BASE_STRIDES:.*]]:2 = vmvx.get_buffer_descriptor %arg0
-    // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
-    // CHECK-DAG: %[[I0:.*]] = arith.muli %arg1, %[[BASE_STRIDES]]#0 : index
-    // CHECK-DAG: %[[I1:.*]] = arith.addi %[[BASE_OFFSET]], %0 : index
-    // CHECK-DAG: %[[I2:.*]] = arith.muli %arg2, %[[BASE_STRIDES]]#1 : index
-    // CHECK-DAG: %[[SUB_OFFSET:.*]] = arith.addi %[[I1]], %[[I2]] : index
-    //     CHECK: return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C64]], %[[C64]], %[[BASE_STRIDES]]#0, %[[BASE_STRIDES]]#1
+  func.func @resolve_subview(%arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index) -> (!util.buffer, index, index, index, index, index) {
     %0 = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] : memref<384x128xf32> to memref<64x64xf32, #map0>
     %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<64x64xf32, #map0> -> !util.buffer, index, index, index, index, index
     return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
   }
-}
+//     CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 + s1 * s2 + s3 * s4)>
+//     CHECK: func @resolve_subview(
+// CHECK-DAG:   %[[BASE_BUFFER:.+]], %[[BASE_OFFSET:.+]], %[[BASE_SIZES:.+]]:2, %[[BASE_STRIDES:.+]]:2 = vmvx.get_buffer_descriptor %arg0
+// CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG:   %[[SUB_OFFSET:.+]] = affine.apply #[[MAP]]()[%[[BASE_OFFSET]], %arg1, %[[BASE_STRIDES]]#0, %arg2, %[[BASE_STRIDES]]#1]
+//     CHECK:   return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C64]], %[[C64]], %[[BASE_STRIDES]]#0, %[[BASE_STRIDES]]#1
 
 // -----
 
-// CHECK-LABEL: @resolve_subview_rankreducing
 #map0 = affine_map<(d0)[s0] -> (d0 * 128 + s0)>
 func.func @resolve_subview_rankreducing(%arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index) -> (!util.buffer, index, index, index) {
-  // CHECK-DAG: %[[BASE_BUFFER:.*]], %[[BASE_OFFSET:.*]], %[[BASE_SIZES:.*]]:2, %[[BASE_STRIDES:.*]]:2 = vmvx.get_buffer_descriptor %arg0
-  // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
-  // CHECK-DAG: %[[I0:.*]] = arith.muli %arg1, %[[BASE_STRIDES]]#0 : index
-  // CHECK:     %[[I1:.*]] = arith.addi %[[BASE_OFFSET]], %[[I0]] : index
-  // CHECK:     %[[I2:.*]] = arith.muli %arg2, %[[BASE_STRIDES]]#1 : index
-  // CHECK:     %[[I3:.*]] = arith.addi %[[I1]], %[[I2]] : index
-  // CHECK:     return %[[BASE_BUFFER]], %[[I3]], %[[C64]], %[[BASE_STRIDES]]#0
   %0 = memref.subview %arg0[%arg1, %arg2] [64, 1] [1, 1] : memref<384x128xf32> to memref<64xf32, #map0>
   %base_buffer, %offset, %size, %stride = vmvx.get_buffer_descriptor %0 : memref<64xf32, #map0> -> !util.buffer, index, index, index
   return %base_buffer, %offset, %size, %stride : !util.buffer, index, index, index
 }
+//     CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 + s1 * s2 + s3 * s4)>
+//     CHECK: @resolve_subview_rankreducing(
+// CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG:   %[[BASE_BUFFER:.+]], %[[BASE_OFFSET:.+]], %[[BASE_SIZES:.+]]:2, %[[BASE_STRIDES:.+]]:2 = vmvx.get_buffer_descriptor %arg0
+//     CHECK:   %[[SUB_OFFSET:.+]] = affine.apply #[[MAP]]()[%[[BASE_OFFSET]], %arg1, %[[BASE_STRIDES]]#0, %arg2, %[[BASE_STRIDES]]#1]
+//     CHECK:   return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C64]], %[[BASE_STRIDES]]#0
+
+// -----
+
+// Check that we properly resolve subview with rankreducing when the dropped
+// rank is not the last one.
+// Orig strides: [%strides#0, %strides#1, %strides#2]
+// Sub strides: [1, 1, 1]
+// => New strides: [%strides#0, %strides#1, %strides#2]
+// Final strides == filterOutReducedDim(new strides, 0) == [%strides#1 , %strides#2]
+//
+// Orig offset: %offset
+// Sub offsets: [%arg1, %arg2, 0]
+// => Final offset: %arg1 * %strides#0 + %arg2 * %strides#1 + 0 * %strides#2 + %offset
+//
+// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3]
+
+func.func @resolve_subview_rankreducing_not_at_the_end(%arg0: memref<8x16x4xf32>, %arg1 : index, %arg2 : index) -> (!util.buffer, index, index, index, index, index) {
+  %0 = memref.subview %arg0[%arg1, %arg2, 0] [1, 6, 3] [1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4,1], offset : ?>>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<6x3xf32, strided<[4,1], offset : ?>> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+//     CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 + s1 * s2 + s3 * s4)>
+//     CHECK: func @resolve_subview_rankreducing_not_at_the_end(
+// CHECK-DAG:   %[[C6:.+]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG:   %[[BASE_BUFFER:.+]], %[[BASE_OFFSET:.+]], %[[BASE_SIZES:.+]]:3, %[[BASE_STRIDES:.+]]:3 = vmvx.get_buffer_descriptor %arg0
+//     CHECK:   %[[SUB_OFFSET:.+]] = affine.apply #[[MAP]]()[%[[BASE_OFFSET]], %arg1, %[[BASE_STRIDES]]#0, %arg2, %[[BASE_STRIDES]]#1]
+//     CHECK:   return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C6]], %[[C3]], %[[BASE_STRIDES]]#1, %[[BASE_STRIDES]]#2
+
+// -----
+
+func.func @resolve_binding_subspan_zero_offset() -> (!util.buffer, index, index, index, index, index) {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<512x384xf32>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+//     CHECK: func @resolve_binding_subspan_zero_offset(
+// CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+// CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//     CHECK:   %[[CAST:.+]] = vmvx.get_raw_interface_binding_buffer set(0) binding(0)
+//     CHECK:   return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+
+// -----
+
+func.func @resolve_binding_subspan_offset_index(%arg0 : index) -> (!util.buffer, index, index, index, index, index) {
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%arg0) : memref<512x384xindex>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xindex> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+//     CHECK: func @resolve_binding_subspan_offset_index(
+// CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+// CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[INDEX_SIZE:.+]] = util.sizeof index
+// CHECK-DAG:   %[[OFFSET:.+]] = arith.divui %arg0, %[[INDEX_SIZE]] : index
+//     CHECK:   %[[CAST:.+]] = vmvx.get_raw_interface_binding_buffer set(0) binding(0)
+//     CHECK:   return %[[CAST]], %[[OFFSET]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+
+// -----
+
+func.func @resolve_binding_subspan_dyn_dims(%arg0 : index, %arg1 : index) -> (!util.buffer, index, index, index, index, index) {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xindex>{%arg0, %arg1}
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<?x?xindex> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+//     CHECK: func @resolve_binding_subspan_dyn_dims(
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//     CHECK:   %[[CAST:.+]] = vmvx.get_raw_interface_binding_buffer set(0) binding(0)
+//     CHECK:   return %[[CAST]], %{{.+}}, %arg0, %arg1, %arg1, %[[C1]]
+
+// -----
+
+func.func @resolve_alloca_static() -> (!util.buffer, index, index, index, index, index) {
+  %0 = memref.alloca() : memref<512x384xf32>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+//     CHECK: func @resolve_alloca_static()
+// CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+// CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//     CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast
+//     CHECK:   return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+
+
+// -----
+
+func.func @resolve_alloca_dynamic(%arg0 : index) -> (!util.buffer, index, index, index, index, index) {
+  %0 = memref.alloca(%arg0) : memref<?x384xf32>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<?x384xf32> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+//     CHECK: func @resolve_alloca_dynamic(
+// CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//     CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast
+//     CHECK:   return %[[CAST]], %[[C0]], %arg0, %[[C384]], %[[C384]], %[[C1]]
+
+// -----
+
+memref.global "private" constant @__constant_2xi32 : memref<512x384xf32> = dense<0.0>
+
+func.func @resolve_global() -> (!util.buffer, index, index, index, index, index) {
+  %0 = memref.get_global @__constant_2xi32 : memref<512x384xf32>
+  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+//     CHECK: func @resolve_global(
+// CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+// CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//     CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast
+//     CHECK:   return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+
+// -----
+
+#map0 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
+func.func @resolve_subview_memref(%arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index) -> (memref<f32>, index, index, index, index, index) {
+    %0 = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] : memref<384x128xf32> to memref<64x64xf32, #map0>
+    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<64x64xf32, #map0> -> memref<f32>, index, index, index, index, index
+    return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
+  }
+//     CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 128 + s1)>
+//     CHECK: func @resolve_subview_memref(
+// CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[BASE_BUFFER:.+]], %[[BASE_OFFSET:.+]], %[[BASE_SIZES:.+]]:2, %[[BASE_STRIDES:.+]]:2 = memref.extract_strided_metadata %arg0
+//     CHECK:   %[[SUB_OFFSET:.+]] = affine.apply #[[MAP]]()[%arg1, %arg2]
+//     CHECK:   return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C64]], %[[C64]], %[[C128]], %[[C1]]
+
+// -----
+
+#map0 = affine_map<(d0)[s0] -> (d0 * 128 + s0)>
+func.func @resolve_subview_rankreducing_memref(%arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index) -> (memref<f32>, index, index, index) {
+  %0 = memref.subview %arg0[%arg1, %arg2] [64, 1] [1, 1] : memref<384x128xf32> to memref<64xf32, #map0>
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %0 : memref<64xf32, #map0> -> memref<f32>, index, index, index
+  return %base_buffer, %offset, %size, %stride : memref<f32>, index, index, index
+}
+//     CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 128 + s1)>
+//     CHECK: func @resolve_subview_rankreducing_memref(
+// CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG:   %[[BASE_BUFFER:.+]], %[[BASE_OFFSET:.+]], %[[BASE_SIZES:.+]]:2, %[[BASE_STRIDES:.+]]:2 = memref.extract_strided_metadata %arg0
+// CHECK-DAG:   %[[SUB_OFFSET:.+]] = affine.apply #[[MAP]]()[%arg1, %arg2]
+//     CHECK:   return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C64]], %[[C128]]
 
 // -----
 
@@ -51,109 +198,110 @@
 //
 // Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3]
 //
-// CHECK-LABEL: @resolve_subview_rankreducing_not_at_the_end
-func.func @resolve_subview_rankreducing_not_at_the_end(%arg0: memref<8x16x4xf32>, %arg1 : index, %arg2 : index) -> (!util.buffer, index, index, index, index, index) {
-  // CHECK-DAG: %[[BASE_BUFFER:.*]], %[[BASE_OFFSET:.*]], %[[BASE_SIZES:.*]]:3, %[[BASE_STRIDES:.*]]:3 = vmvx.get_buffer_descriptor %arg0
-  // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
-  // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-  // CHECK-DAG: %[[I0:.*]] = arith.muli %arg1, %[[BASE_STRIDES]]#0 : index
-  // CHECK:     %[[I1:.*]] = arith.addi %[[BASE_OFFSET]], %[[I0]] : index
-  // CHECK:     %[[I2:.*]] = arith.muli %arg2, %[[BASE_STRIDES]]#1 : index
-  // CHECK:     %[[I3:.*]] = arith.addi %[[I1]], %[[I2]] : index
-  // CHECK:     return %[[BASE_BUFFER]], %[[I3]], %[[C6]], %[[C3]], %[[BASE_STRIDES]]#1, %[[BASE_STRIDES]]#2
+func.func @resolve_subview_rankreducing_not_at_the_end_memref(%arg0: memref<8x16x4xf32>, %arg1 : index, %arg2 : index) -> (memref<f32>, index, index, index, index, index) {
 
   %0 = memref.subview %arg0[%arg1, %arg2, 0] [1, 6, 3] [1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4,1], offset : ?>>
-  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<6x3xf32, strided<[4,1], offset : ?>> -> !util.buffer, index, index, index, index, index
-  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<6x3xf32, strided<[4,1], offset : ?>> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
 }
+//     CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 64 + s1 * 4)>
+//     CHECK: func @resolve_subview_rankreducing_not_at_the_end_memref(
+// CHECK-DAG:   %[[C6:.+]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[BASE_BUFFER:.+]], %[[BASE_OFFSET:.+]], %[[BASE_SIZES:.+]]:3, %[[BASE_STRIDES:.+]]:3 = memref.extract_strided_metadata %arg0
+//     CHECK:   %[[SUB_OFFSET:.+]] = affine.apply #[[MAP]]()[%arg1, %arg2]
+//     CHECK:   return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C6]], %[[C3]], %[[C4]], %[[C1]]
+
 
 // -----
 
-// CHECK-LABEL: @resolve_binding_subspan_zero_offset
-func.func @resolve_binding_subspan_zero_offset() -> (!util.buffer, index, index, index, index, index) {
-  // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
-  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  //     CHECK: %[[CAST:.*]] = vmvx.get_raw_interface_binding_buffer set(0) binding(0)
-  //     CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+func.func @resolve_binding_subspan_zero_offset_memref() -> (memref<f32>, index, index, index, index, index) {
   %c0 = arith.constant 0 : index
   %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<512x384xf32>
-  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
-  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<512x384xf32> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
 }
+//     CHECK: func @resolve_binding_subspan_zero_offset_memref(
+// CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+// CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//     CHECK:   %[[BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%[[C0]]) : memref<f32>
+//     CHECK:   return %[[BINDING]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
 
 // -----
 
-// CHECK-LABEL: @resolve_binding_subspan_offset_index
-func.func @resolve_binding_subspan_offset_index(%arg0 : index) -> (!util.buffer, index, index, index, index, index) {
-  // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
-  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[INDEX_SIZE:.*]] = util.sizeof index
-  // CHECK-DAG: %[[OFFSET:.*]] = arith.divui %arg0, %[[INDEX_SIZE]] : index
-  //     CHECK: %[[CAST:.*]] = vmvx.get_raw_interface_binding_buffer set(0) binding(0)
-  //     CHECK: return %[[CAST]], %[[OFFSET]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
-  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%arg0) : memref<512x384xindex>
-  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xindex> -> !util.buffer, index, index, index, index, index
-  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+func.func @resolve_binding_subspan_offset_index_memref(%arg0 : index) -> (memref<index>, index, index, index, index, index) {
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%arg0) : memref<512x384xindex, strided<[384, 1], offset:?>>
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<512x384xindex, strided<[384, 1], offset:?>> -> memref<index>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<index>, index, index, index, index, index
 }
+//     CHECK: func @resolve_binding_subspan_offset_index_memref(
+// CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+// CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//     CHECK:   %[[BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%arg0) : memref<index>
+//     CHECK:   return %[[BINDING]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
 
 // -----
 
-// CHECK-LABEL: @resolve_binding_subspan_dyn_dims
-func.func @resolve_binding_subspan_dyn_dims(%arg0 : index, %arg1 : index) -> (!util.buffer, index, index, index, index, index) {
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  //     CHECK: %[[CAST:.*]] = vmvx.get_raw_interface_binding_buffer set(0) binding(0)
-  //     CHECK: return %[[CAST]], %{{.*}}, %arg0, %arg1, %arg1, %[[C1]]
+func.func @resolve_binding_subspan_dyn_dims_memref(%arg0 : index, %arg1 : index) -> (memref<index>, index, index, index, index, index) {
   %c0 = arith.constant 0 : index
   %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xindex>{%arg0, %arg1}
-  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<?x?xindex> -> !util.buffer, index, index, index, index, index
-  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<?x?xindex> -> memref<index>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<index>, index, index, index, index, index
 }
+//     CHECK: func @resolve_binding_subspan_dyn_dims_memref(
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//     CHECK:   %[[BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%[[C0]]) : memref<index>
+//     CHECK:   return %[[BINDING]], %[[C0]], %arg0, %arg1, %arg1, %[[C1]]
 
 // -----
 
-// CHECK-LABEL: @resolve_alloca_static
-func.func @resolve_alloca_static() -> (!util.buffer, index, index, index, index, index) {
-  // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
-  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  //     CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
-  //     CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+func.func @resolve_alloca_static_memref() -> (memref<f32>, index, index, index, index, index) {
   %0 = memref.alloca() : memref<512x384xf32>
-  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
-  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<512x384xf32> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
 }
+// CHECK-LABEL: func @resolve_alloca_static_memref(
+//   CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+//   CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast
+//       CHECK:   return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
 
 // -----
 
-// CHECK-LABEL: @resolve_alloca_dynamic
-func.func @resolve_alloca_dynamic(%arg0 : index) -> (!util.buffer, index, index, index, index, index) {
-  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  //     CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
-  //     CHECK: return %[[CAST]], %[[C0]], %arg0, %[[C384]], %[[C384]], %[[C1]]
+func.func @resolve_alloca_dynamic_memref(%arg0 : index) -> (memref<f32>, index, index, index, index, index) {
   %0 = memref.alloca(%arg0) : memref<?x384xf32>
-  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<?x384xf32> -> !util.buffer, index, index, index, index, index
-  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<?x384xf32> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
 }
+// CHECK-LABEL: func @resolve_alloca_dynamic_memref(
+//   CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast
+//       CHECK:   return %[[CAST]], %[[C0]], %arg0, %[[C384]], %[[C384]], %[[C1]]
 
 // -----
 
-// CHECK-LABEL: @resolve_global
 memref.global "private" constant @__constant_2xi32 : memref<512x384xf32> = dense<0.0>
 
-func.func @resolve_global() -> (!util.buffer, index, index, index, index, index) {
-  // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
-  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  //     CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast
-  //     CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+func.func @resolve_global_memref() -> (memref<f32>, index, index, index, index, index) {
   %0 = memref.get_global @__constant_2xi32 : memref<512x384xf32>
-  %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index
-  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<512x384xf32> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
 }
+// CHECK-LABEL: func @resolve_global_memref()
+//   CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+//   CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast
+//       CHECK:   return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]