Handle memref<index> in VM conversion. (#9897)

* Introduces a util.sizeof op that allows the actual concrete size of a type to remain unknown until final conversion to the target.
* Adds a VM conversion for `util.sizeof index` to convert to the actual index size of the target.
* Adapts util.buffer->vm.buffer lowerings to properly size load/store/fill of index buffers based on the index size of the target.
* Enables the remainder of the tosa_ops tests (gather, table) for microkernels.

A follow-on change will break out the memref->util conversion so that it is done as a final stage of vmvx lowering.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
index e74bf2f8..51b9504 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
@@ -52,8 +52,8 @@
   // may elect to represent all i8 registers as i32, but this does not mean
   // that all memrefs are widened from i8 to i32).
   auto elementType = memrefType.getElementType();
-  auto elementSize = rewriter.createOrFold<arith::ConstantIndexOp>(
-      loc, IREE::Util::getRoundedElementByteWidth(elementType));
+  auto elementSize =
+      rewriter.createOrFold<IREE::Util::SizeOfOp>(loc, elementType);
 
   // Rank 1 memrefs are just offset by their element width by the offset.
   auto elementCount = indices.front();
@@ -155,11 +155,12 @@
       return rewriter.notifyMatchFailure(
           allocaOp, "unable to create buffers for dynamic shapes");
     }
-    int64_t memRefLength =
-        type.getNumElements() *
-        IREE::Util::getRoundedElementByteWidth(type.getElementType());
-    Value allocationSize = rewriter.create<arith::ConstantIndexOp>(
-        allocaOp.getLoc(), memRefLength);
+    auto numElements = rewriter.create<arith::ConstantIndexOp>(
+        allocaOp.getLoc(), type.getNumElements());
+    auto elementSize = rewriter.createOrFold<IREE::Util::SizeOfOp>(
+        allocaOp.getLoc(), type.getElementType());
+    auto allocationSize = rewriter.createOrFold<arith::MulIOp>(
+        allocaOp.getLoc(), numElements, elementSize);
     rewriter.replaceOpWithNewOp<IREE::Util::BufferAllocOp>(
         allocaOp, rewriter.getType<IREE::Util::BufferType>(), allocationSize);
     return success();
@@ -177,8 +178,8 @@
     }
     auto elementType =
         dimOp.getSource().getType().cast<MemRefType>().getElementType();
-    Value elementSize = rewriter.create<arith::ConstantIndexOp>(
-        dimOp.getLoc(), IREE::Util::getRoundedElementByteWidth(elementType));
+    Value elementSize = rewriter.createOrFold<IREE::Util::SizeOfOp>(
+        dimOp.getLoc(), elementType);
     Value bufferSize = rewriter.create<IREE::Util::BufferSizeOp>(
         dimOp.getLoc(), rewriter.getIndexType(), adaptor.getSource());
     rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(dimOp, bufferSize,
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
index fa7a099..2d94678 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
@@ -60,6 +60,19 @@
 }
 
 // -----
+// CHECK-LABEL: @alloc_index
+// CHECK-SAME: (%[[IDX0:.+]]: index) -> !util.buffer {
+func.func @alloc_index(%idx0: index) -> memref<4xindex> {
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[SIZEOF:.*]] = util.sizeof index
+  // CHECK: %[[SZ:.*]] = arith.muli %[[SIZEOF]], %[[C4]]
+  // CHECK: %[[BUFFER:.*]] = util.buffer.alloc uninitialized : !util.buffer{%[[SZ]]}
+  %0 = memref.alloca() : memref<4xindex>
+  // CHECK: return %[[BUFFER]]
+  return %0 : memref<4xindex>
+}
+
+// -----
 // CHECK-LABEL: @load_store_f32
 // CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index, %[[IDX1:.+]]: index) -> f32 {
 func.func @load_store_f32(%buffer: memref<?xf32>, %idx0: index, %idx1: index) -> f32 {
@@ -113,6 +126,22 @@
 }
 
 // -----
+// CHECK-LABEL: @load_store_index
+// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[VALUE:.+]]: index) -> index {
+func.func @load_store_index(%buffer: memref<?xindex>, %idx0: index, %idx1: index, %value: index) -> index {
+  // CHECK-DAG: %[[SIZEOF:.*]] = util.sizeof index
+  // CHECK-DAG: %[[SZ:.*]] = util.buffer.size %[[BUFFER]]
+  // CHECK-DAG: %[[OFS0:.*]] = arith.muli %[[SIZEOF]], %[[IDX0]] : index
+  // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[OFS0]]] : index -> !util.buffer{%[[SZ]]}
+  memref.store %value, %buffer[%idx0] : memref<?xindex>
+  // CHECK: %[[OFS1:.*]] = arith.muli %[[SIZEOF]], %[[IDX1]] : index
+  // CHECK: %[[LD:.*]] = util.buffer.load %[[BUFFER]][%[[OFS1]]] : !util.buffer{%[[SZ]]} -> index
+  %1 = memref.load %buffer[%idx1] : memref<?xindex>
+  // CHECK: return %[[LD]]
+  return %1 : index
+}
+
+// -----
 // CHECK-LABEL: @dim_i16
 // CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index) -> index {
 func.func @dim_i16(%buffer: memref<?xi16>, %idx0: index) -> index {
@@ -123,3 +152,15 @@
   // CHECK: return %[[DV]]
   return %0 : index
 }
+
+// -----
+// CHECK-LABEL: @dim_index
+// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index) -> index {
+func.func @dim_index(%buffer: memref<?xindex>, %idx0: index) -> index {
+  // CHECK: %[[SIZEOF:.*]] = util.sizeof index
+  // CHECK: %[[SZ:.*]] = util.buffer.size %[[BUFFER]] : !util.buffer
+  // CHECK: %[[DV:.*]] = arith.floordivsi %[[SZ]], %[[SIZEOF]] : index
+  %0 = memref.dim %buffer, %idx0 : memref<?xindex>
+  // CHECK: return %[[DV]]
+  return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
index a425d03..d3bd131 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
@@ -37,7 +37,7 @@
 def Util_Tensor : TensorOf<[Util_Element]>;
 def Util_Primitive : AnyTypeOf<[Index, AnyInteger, AnyFloat]>;
 
-def Util_FillPattern : AnyTypeOf<[AnyInteger, AnyFloat]>;
+def Util_FillPattern : AnyTypeOf<[AnyInteger, AnyFloat, Index]>;
 
 def Util_Offset : TypeAlias<Index>;
 def Util_Size : TypeAlias<Index>;
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index ab7721c..59b2c8d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -391,6 +391,19 @@
 }
 
 //===----------------------------------------------------------------------===//
+// util.sizeof
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SizeOfOp::fold(ArrayRef<Attribute> operands) {
+  Type t = getSizedType();  // The type whose byte size is requested.
+  if (t.isa<IntegerType>() || t.isa<FloatType>()) {
+    return IntegerAttr::get(IndexType::get(getContext()),
+                            getRoundedElementByteWidth(t));
+  }
+  return {};  // Any other type (e.g. index) is left unfolded.
+}
+
+//===----------------------------------------------------------------------===//
 // Compiler hints
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index 26815e2..14df90e 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -273,6 +273,33 @@
   let hasFolder = 1;
 }
 
+def Util_SizeOfOp : Util_PureOp<"sizeof"> {
+  let summary = [{returns the size in bytes of a datatype}];
+  let description = [{
+    Most datatypes have a static size at all layers of the compilation stack.
+    However, those that only have a size for certain lowering flows can be
+    challenging. This op represents such sizes in a way that can be specialized
+    later.
+
+    Returns the size in bytes of the specified type, rounded up to the next
+    whole byte. This op will fold to a constant index value for IntegerType
+    and FloatType. All others are not folded.
+  }];
+
+  let arguments = (ins
+    TypeAttr:$sizedType
+  );
+  let results = (outs
+    Index:$size
+  );
+
+  let assemblyFormat = [{
+    $sizedType attr-dict-with-keyword
+  }];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Compiler hints
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
index 0de88a4..23561af 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
@@ -91,3 +91,30 @@
   // CHECK: return %[[SUM_ALIGNED]]
   return %result : index
 }
+
+// -----
+
+// CHECK-LABEL: @sizeofWholeInt
+func.func @sizeofWholeInt() -> index {
+  // CHECK: = arith.constant 4 : index
+  %0 = util.sizeof i32
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @sizeofSubByteInt
+func.func @sizeofSubByteInt() -> index {
+  // CHECK: = arith.constant 2 : index
+  %0 = util.sizeof i12
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @sizeofFloat
+func.func @sizeofFloat() -> index {
+  // CHECK: = arith.constant 4 : index
+  %0 = util.sizeof f32
+  return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_ops.mlir
index 9f68f6d..8bf52f8 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_ops.mlir
@@ -15,3 +15,12 @@
   %result = util.align %arg0, %arg1 : i32
   return
 }
+
+// -----
+
+// CHECK-LABEL: @sizeofUnfoldable
+func.func @sizeofUnfoldable() -> index {
+  // CHECK: = util.sizeof index
+  %0 = util.sizeof index
+  return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertAlignmentOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertAlignmentOps.cpp
index b622fc8..d2b4c05 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertAlignmentOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertAlignmentOps.cpp
@@ -68,6 +68,34 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// util.sizeof
+//===----------------------------------------------------------------------===//
+
+/// For a `sizeof index` operation, invokes the type converter to derive the
+/// concrete type for index and re-creates the sizeof op on that type (via
+/// createOrFold). This allows late resolution of the size of the index type
+/// at the point of conversion to VM, where it is known.
+struct FixateIndexSizeofConversion
+    : public OpConversionPattern<IREE::Util::SizeOfOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREE::Util::SizeOfOp sizeofOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type sizedType = sizeofOp.getSizedType();
+    if (sizedType.isa<IndexType>()) {
+      Type converted = getTypeConverter()->convertType(sizedType);
+      if (converted) {
+        Value newSizeof = rewriter.createOrFold<IREE::Util::SizeOfOp>(
+            sizeofOp.getLoc(), converted);
+        rewriter.replaceOp(sizeofOp, newSizeof);
+        return success();
+      }
+    }
+    return failure();  // Not `index`, or the converter produced no type.
+  }
+};
+
 }  // namespace
 
 void populateUtilAlignmentToVMPatterns(MLIRContext *context,
@@ -75,8 +103,10 @@
                                        TypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
   conversionTarget.addIllegalOp<IREE::Util::AlignOp>();
+  conversionTarget.addIllegalOp<IREE::Util::SizeOfOp>();
 
-  patterns.insert<AlignOpConversion>(typeConverter, context);
+  patterns.insert<AlignOpConversion, FixateIndexSizeofConversion>(typeConverter,
+                                                                  context);
 }
 
 }  // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp
index 3c50f36..9f13d7d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp
@@ -154,6 +154,11 @@
       IREE::Util::BufferFillOp fillOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     auto oldType = fillOp.getPattern().getType();
+    auto newType = adaptor.getPattern().getType();
+    if (oldType.isa<IndexType>()) {
+      // Use the actual converted type for IndexType.
+      oldType = newType;
+    }
     auto byteOffset = castToI64(adaptor.getTargetOffset(), rewriter);
     auto byteLength = castToI64(adaptor.getLength(), rewriter);
     auto pattern = adaptor.getPattern();
@@ -196,6 +201,9 @@
       ConversionPatternRewriter &rewriter) const override {
     auto oldType = loadOp.getResult().getType();
     auto newType = getTypeConverter()->convertType(oldType);
+    if (oldType.isa<IndexType>()) {
+      oldType = newType;
+    }
     auto byteOffset = castToI64(adaptor.getSourceOffset(), rewriter);
     if (auto integerType = oldType.dyn_cast<IntegerType>()) {
       if (integerType.isInteger(1) || integerType.isInteger(8)) {
@@ -245,6 +253,10 @@
       IREE::Util::BufferStoreOp storeOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     auto oldType = storeOp.getSource().getType();
+    auto newType = adaptor.getSource().getType();
+    if (oldType.isa<IndexType>()) {
+      oldType = newType;
+    }
     auto byteOffset = castToI64(adaptor.getTargetOffset(), rewriter);
     if (oldType.isInteger(1) || oldType.isInteger(8)) {
       rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI8Op>(
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir
index 43e86c0..d157967 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir
@@ -42,3 +42,13 @@
   //CHECK-DAG: vm.return %3 : i64
   return %result : i64
 }
+
+// -----
+
+// CHECK-LABEL: @utilSizeOfIndex
+func.func @utilSizeOfIndex() ->  (index) {
+  // CHECK: %[[SIZEOF:.*]] = vm.const.i32 4
+  %0 = util.sizeof index
+  // CHECK: vm.return %[[SIZEOF]]
+  return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
index bfd2936..a050de9 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
@@ -150,6 +150,18 @@
 
 // -----
 
+// CHECK-LABEL: @buffer_fill_index
+func.func @buffer_fill_index(%arg0: !util.buffer, %arg1: index, %arg2: index) {
+  %c100 = arith.constant 100 : index
+  %c200 = arith.constant 200 : index
+  // CHECK-32: vm.buffer.fill.i32
+  // CHECK-64: vm.buffer.fill.i64
+  util.buffer.fill %arg2, %arg0[%c100 for %c200] : index -> !util.buffer{%arg1}
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @buffer_load_i1
 func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i1 {
   %c100 = arith.constant 100 : index
@@ -189,6 +201,17 @@
 
 // -----
 
+// CHECK-LABEL: @buffer_load_index
+func.func @buffer_load_index(%arg0: !util.buffer, %arg1: index) -> index {
+  %c100 = arith.constant 100 : index
+  // CHECK-32: vm.buffer.load.i32
+  // CHECK-64: vm.buffer.load.i64
+  %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> index
+  return %0 : index
+}
+
+// -----
+
 // CHECK-LABEL: @buffer_store_i1
 func.func @buffer_store_i1(%arg0: !util.buffer, %arg1: index, %arg2: i1) {
   %c100 = arith.constant 100 : index
@@ -222,3 +245,14 @@
   util.buffer.store %arg2, %arg0[%c100] : i64 -> !util.buffer{%arg1}
   return
 }
+
+// -----
+
+// CHECK-LABEL: @buffer_store_index
+func.func @buffer_store_index(%arg0: !util.buffer, %arg1: index, %arg2: index) {
+  %c100 = arith.constant 100 : index
+  // CHECK-32: vm.buffer.store.i32
+  // CHECK-64: vm.buffer.store.i64
+  util.buffer.store %arg2, %arg0[%c100] : index -> !util.buffer{%arg1}
+  return
+}
diff --git a/tests/e2e/tosa_ops/BUILD b/tests/e2e/tosa_ops/BUILD
index 93fc1a3..c380539 100644
--- a/tests/e2e/tosa_ops/BUILD
+++ b/tests/e2e/tosa_ops/BUILD
@@ -147,6 +147,7 @@
         "exp.mlir",
         "floor.mlir",
         "fully_connected.mlir",
+        "gather.mlir",
         "greater.mlir",
         "greater_equal.mlir",
         "if.mlir",
@@ -168,16 +169,12 @@
         "select.mlir",
         "sigmoid.mlir",
         "sub.mlir",
+        "table.mlir",
         "tanh.mlir",
         "transpose.mlir",
         "while.mlir",
     ],
     include = ["*.mlir"],
-    exclude = [
-        # Decompositions produce tensor<index> which is not handled properly.
-        "gather.mlir",
-        "table.mlir",
-    ],
 )
 
 iree_check_single_backend_test_suite(
diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt
index 853fd13..11695d1 100644
--- a/tests/e2e/tosa_ops/CMakeLists.txt
+++ b/tests/e2e/tosa_ops/CMakeLists.txt
@@ -131,6 +131,7 @@
     "exp.mlir"
     "floor.mlir"
     "fully_connected.mlir"
+    "gather.mlir"
     "greater.mlir"
     "greater_equal.mlir"
     "if.mlir"
@@ -152,6 +153,7 @@
     "select.mlir"
     "sigmoid.mlir"
     "sub.mlir"
+    "table.mlir"
     "tanh.mlir"
     "transpose.mlir"
     "while.mlir"