Minor fixes found while enabling micro-kernel usage e2e. (#12354)

Rename vmvx.*.i8.i8.i32 -> vmvx.*.i8i8i32.
Rename vmvx.*.f32.f32.f32 -> vmvx.*.f32f32f32.
Fix insertion point for hal.interface.binding.subspan in ResolveBufferDescriptors pass.
Set the vm.import.module attribute to vmvx in ukernel function declaration.
Remove offset from newly created hal.interface.binding.subspan in ResolveBufferDescriptors.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir b/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
index 345b420..d790a21 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
@@ -55,10 +55,10 @@
   %dim_1 = memref.dim %arg0, %c0 : memref<?x?xf32>
   %dim_2 = memref.dim %arg1, %c1 : memref<?x?xf32>
   %dim_3 = memref.dim %arg2, %c1 : memref<?x?xf32>
-  iree_codegen.ukernel.generic "vmvx.matmul.f32.f32.f32" ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>) (%dim_1, %dim_2, %dim_3, %c0_i32 : index, index, index, i32)
+  iree_codegen.ukernel.generic "vmvx.matmul.f32f32f32" ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>) (%dim_1, %dim_2, %dim_3, %c0_i32 : index, index, index, i32)
   return
 }
-// CHECK-LABEL: func.func private @vmvx.matmul.f32.f32.f32
+// CHECK-LABEL: func.func private @vmvx.matmul.f32f32f32
 //  CHECK-SAME:     (memref<f32>, index, index, memref<f32>, index, index,
 //  CHECK-SAME:     memref<f32>, index, index, index, index, index, i32)
 // CHECK-LABEL: func.func @generic_ukernel
@@ -74,7 +74,7 @@
 //       CHECK:   %[[BASE0:.+]], %[[OFFSET0:.+]], %[[SIZE0:.+]]:2, %[[STRIDES0:.+]]:2 = memref.extract_strided_metadata %[[ARG0]]
 //       CHECK:   %[[BASE1:.+]], %[[OFFSET1:.+]], %[[SIZE1:.+]]:2, %[[STRIDES1:.+]]:2 = memref.extract_strided_metadata %[[ARG1]]
 //       CHECK:   %[[BASE2:.+]], %[[OFFSET2:.+]], %[[SIZE2:.+]]:2, %[[STRIDES2:.+]]:2 = memref.extract_strided_metadata %[[ARG2]]
-//       CHECK:   call @vmvx.matmul.f32.f32.f32(%[[BASE0]], %[[OFFSET0]], %[[STRIDES0]]#0
+//       CHECK:   call @vmvx.matmul.f32f32f32(%[[BASE0]], %[[OFFSET0]], %[[STRIDES0]]#0
 //  CHECK-SAME:       %[[BASE1]], %[[OFFSET1]], %[[STRIDES1]]#0
 //  CHECK-SAME:       %[[BASE2]], %[[OFFSET2]], %[[STRIDES2]]#0
 //  CHECK-SAME:       %[[D0]], %[[D1]], %[[D2]], %[[C0_I32]])
@@ -87,7 +87,7 @@
       outs(%arg2 : memref<?x?x?x?xf32>) accumulate(false)
   return
 }
-// CHECK-LABEL: func.func private @vmvx.mmt4d.f32.f32.f32
+// CHECK-LABEL: func.func private @vmvx.mmt4d.f32f32f32
 //  CHECK-SAME:     (memref<f32>, index, index, memref<f32>, index, index,
 //  CHECK-SAME:     memref<f32>, index, index, index, index, index, i32, i32, i32, i32)
 // CHECK-LABEL: func.func @mmt4d_ukernel(
@@ -111,7 +111,7 @@
 //   CHECK-DAG:   %[[D5:.+]] = memref.dim %[[ARG0]], %[[C3]]
 //   CHECK-DAG:   %[[D5_I32:.+]] = arith.index_cast %[[D5]]
 //   CHECK-DAG:   %[[C0_I32:.+]] = arith.constant 0 : i32
-//       CHECK:   call @vmvx.mmt4d.f32.f32.f32(%[[BASE0]], %[[OFFSET0]], %[[STRIDES0]]#0
+//       CHECK:   call @vmvx.mmt4d.f32f32f32(%[[BASE0]], %[[OFFSET0]], %[[STRIDES0]]#0
 //  CHECK-SAME:       %[[BASE1]], %[[OFFSET1]], %[[STRIDES1]]#0
 //  CHECK-SAME:       %[[BASE2]], %[[OFFSET2]], %[[STRIDES2]]#0
 //  CHECK-SAME:       %[[D0]], %[[D1]], %[[D2]],
@@ -126,4 +126,4 @@
   return
 }
 // CHECK-LABEL: func @mmt4d_ukernel_i8i8i32(
-//       CHECK:   call @vmvx.mmt4d.i8.i8.i32
+//       CHECK:   call @vmvx.mmt4d.i8i8i32
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
index a7288cb..9bde8a8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
@@ -50,6 +50,10 @@
     rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
     fnDecl = rewriter.create<func::FuncOp>(loc, fnName, functionType);
     SymbolTable::setSymbolVisibility(fnDecl, SymbolTable::Visibility::Private);
+    // TODO(#12327): Based on description in the issue, add an attribute
+    // `vm.import.module` and set it to `vmvx`. This only works on `vmvx`
+    // backend (obviously), but is enough to unblock while the proper fix lands.
+    fnDecl->setAttr("vm.import.module", rewriter.getStringAttr("vmvx"));
   } else if (fnDecl.getFunctionType() != functionType) {
     return rewriter.notifyMatchFailure(
         op, llvm::formatv("mismatch in function type computed during lowering "
@@ -288,10 +292,10 @@
   std::string fnName = "vmvx.mmt4d.";
   switch (matmulType.value()) {
     case MatmulType::I8I8I32:
-      fnName.append("i8.i8.i32");
+      fnName.append("i8i8i32");
       break;
     case MatmulType::F32F32F32:
-      fnName.append("f32.f32.f32");
+      fnName.append("f32f32f32");
       break;
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
index cacbdf6..de030ae 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
@@ -66,10 +66,10 @@
   std::string fnName = "";
   switch (matmulType.value()) {
     case MatmulType::I8I8I32:
-      fnName = "vmvx.matmul.i8.i8.i32";
+      fnName = "vmvx.matmul.i8i8i32";
       break;
     case MatmulType::F32F32F32:
-      fnName = "vmvx.matmul.f32.f32.f32";
+      fnName = "vmvx.matmul.f32f32f32";
       break;
   }
 
@@ -108,10 +108,10 @@
   Type outElemType = outType.getElementType();
   if (lhsElemType.isSignlessInteger(8) && rhsElemType.isSignlessInteger(8) &&
       outElemType.isSignlessInteger(32)) {
-    fnName = "vmvx.matmul.i8.i8.i32";
+    fnName = "vmvx.matmul.i8i8i32";
   } else if (lhsElemType.isF32() && rhsElemType.isF32() &&
              outElemType.isF32()) {
-    fnName = "vmvx.matmul.f32.f32.f32";
+    fnName = "vmvx.matmul.f32f32f32";
   }
   if (fnName.empty()) {
     return rewriter.notifyMatchFailure(op,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
index c28372d..ff9aa0c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
@@ -16,7 +16,7 @@
 //  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.f32.f32.f32"
+//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.f32f32f32"
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[ARG2]] :
 // CHECK-SAME:       (%[[D0]], %[[D1]], %[[D2]], %[[FLAGS]] :
@@ -46,7 +46,7 @@
 //  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
-//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.f32.f32.f32"
+//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.f32f32f32"
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[EMPTY]] :
 // CHECK-SAME:       (%[[D0]], %[[D1]], %[[D2]], %[[FLAGS]] :
@@ -70,7 +70,7 @@
 //  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.i8.i8.i32"
+//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.i8i8i32"
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[ARG2]] :
 // CHECK-SAME:       (%[[D0]], %[[D1]], %[[D2]], %[[FLAGS]] :
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
index a75a6e4..1304f2d 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
@@ -89,6 +89,50 @@
   return strides;
 }
 
+static FailureOr<DescriptorInfo> resolveBufferDescriptorForInterfaceBinding(
+    IREE::HAL::InterfaceBindingSubspanOp binding, RewriterBase &rewriter,
+    Location loc) {
+  auto memRefType = binding.getResult().getType().template cast<MemRefType>();
+  int rank = memRefType.getRank();
+  DescriptorInfo resultDescriptor;
+
+  // Sizes: the binding's dynamic dim values, or the static size from the type.
+  auto dynamicDimIt = binding.getDynamicDims().begin();
+  for (int i = 0; i < rank; ++i) {
+    if (memRefType.isDynamicDim(i)) {
+      resultDescriptor.sizes.push_back(*dynamicDimIt);
+      dynamicDimIt++;
+    } else {
+      resultDescriptor.sizes.push_back(
+          rewriter.getIndexAttr(memRefType.getDimSize(i)));
+    }
+  }
+  // Strides, derived from the sizes collected above.
+  resultDescriptor.strides =
+      getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
+
+  // Offset: the binding's byte offset floordiv the element byte width.
+  Type elementType = memRefType.getElementType();
+  OpFoldResult elementWidth =
+      TypeSwitch<Type, OpFoldResult>(elementType)
+          .Case<ComplexType, IntegerType, FloatType>(
+              [&](auto type) -> OpFoldResult {
+                return rewriter.getIndexAttr(
+                    IREE::Util::getRoundedElementByteWidth(
+                        memRefType.getElementType()));
+              })
+          .Default([&](Type t) -> OpFoldResult {
+            return rewriter.create<IREE::Util::SizeOfOp>(loc, elementType)
+                .getResult();
+          });
+  AffineExpr s0, s1;
+  bindSymbols(rewriter.getContext(), s0, s1);
+  resultDescriptor.offset = makeComposedFoldedAffineApply(
+      rewriter, loc, s0.floorDiv(s1),
+      ArrayRef<OpFoldResult>{binding.getByteOffset(), elementWidth});
+  return resultDescriptor;
+}
+
 static FailureOr<DescriptorInfo> resolveBufferDescriptorForAllocation(
     memref::AllocaOp alloca, RewriterBase &rewriter, Location loc) {
   DescriptorInfo resultDescriptor;
@@ -260,34 +304,14 @@
     if (!binding) return failure();
 
     auto loc = op.getLoc();
-
-    auto memRefType = binding.getResult().getType().cast<MemRefType>();
-    int rank = memRefType.getRank();
-    DescriptorInfo resultDescriptor;
-
-    // Compute sizes.
-    auto dynamicDimIt = binding.getDynamicDims().begin();
-    for (int i = 0; i < rank; ++i) {
-      if (memRefType.isDynamicDim(i)) {
-        resultDescriptor.sizes.push_back(*dynamicDimIt);
-        dynamicDimIt++;
-      } else {
-        resultDescriptor.sizes.push_back(
-            rewriter.getIndexAttr(memRefType.getDimSize(i)));
-      }
+    FailureOr<DescriptorInfo> resultDescriptor =
+        resolveBufferDescriptorForInterfaceBinding(binding, rewriter, loc);
+    if (failed(resultDescriptor)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve descriptor with source being binding op");
     }
 
-    // Strides.
-    resultDescriptor.strides =
-        getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
-
-    // Offset.
-    auto elementSize =
-        rewriter.create<IREE::Util::SizeOfOp>(loc, memRefType.getElementType());
-    resultDescriptor.offset = rewriter.createOrFold<arith::DivUIOp>(
-        loc, binding.getByteOffset(), elementSize);
-
-    replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor);
+    replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
 
     // Base buffer.
     rewriter.replaceAllUsesWith(
@@ -316,32 +340,19 @@
     if (memRefType.getRank() < 1) return failure();
 
     auto loc = op.getLoc();
-    int rank = memRefType.getRank();
-    DescriptorInfo resultDescriptor;
-
-    // Compute sizes.
-    auto dynamicDimIt = binding.getDynamicDims().begin();
-    for (int i = 0; i < rank; ++i) {
-      if (memRefType.isDynamicDim(i)) {
-        resultDescriptor.sizes.push_back(*dynamicDimIt);
-        dynamicDimIt++;
-      } else {
-        resultDescriptor.sizes.push_back(
-            rewriter.getIndexAttr(memRefType.getDimSize(i)));
-      }
+    FailureOr<DescriptorInfo> resultDescriptor =
+        resolveBufferDescriptorForInterfaceBinding(binding, rewriter, loc);
+    if (failed(resultDescriptor)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve descriptor with source being binding op");
     }
 
-    // Strides.
-    resultDescriptor.strides =
-        getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
-    resultDescriptor.offset = rewriter.getIndexAttr(0);
-
-    replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor);
+    replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
 
     // Base buffer. Use a 1D memref for hal.interface.binding.subspan.
     AffineMap mulMap = getMulMap(rewriter.getContext());
     OpFoldResult linearizedMemrefSize = rewriter.getIndexAttr(1);
-    for (auto size : resultDescriptor.sizes) {
+    for (auto size : resultDescriptor->sizes) {
       linearizedMemrefSize = makeComposedFoldedAffineApply(
           rewriter, loc, mulMap, {linearizedMemrefSize, size});
     }
@@ -349,13 +360,16 @@
     SmallVector<Value> dynamicLinearShape;
     dispatchIndexOpFoldResult(linearizedMemrefSize, dynamicLinearShape,
                               staticLinearShape);
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(binding);
+
+    Value newOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     Value linearInterfaceBinding =
         rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
             loc, op.getBaseBuffer().getType(), binding.getSetAttr(),
             binding.getBindingAttr(), binding.getDescriptorTypeAttr(),
-            binding.getByteOffset(),
-            /*dynamicDims =*/ValueRange{}, binding.getAlignmentAttr(),
-            binding.getDescriptorFlagsAttr());
+            newOffset, /*dynamicDims =*/ValueRange{},
+            binding.getAlignmentAttr(), binding.getDescriptorFlagsAttr());
     rewriter.replaceAllUsesWith(op.getBaseBuffer(), linearInterfaceBinding);
 
     rewriter.eraseOp(op);
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
index 8dc786e..7505745 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
@@ -81,12 +81,13 @@
   %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xindex> -> !util.buffer, index, index, index, index, index
   return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
 }
+//     CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 floordiv s1)>
 //     CHECK: func @resolve_binding_subspan_offset_index(
 // CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
 // CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
 // CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 // CHECK-DAG:   %[[INDEX_SIZE:.+]] = util.sizeof index
-// CHECK-DAG:   %[[OFFSET:.+]] = arith.divui %arg0, %[[INDEX_SIZE]] : index
+// CHECK-DAG:   %[[OFFSET:.+]] = affine.apply #[[MAP]]()[%arg0, %[[INDEX_SIZE]]]
 //     CHECK:   %[[CAST:.+]] = vmvx.get_raw_interface_binding_buffer set(0) binding(0)
 //     CHECK:   return %[[CAST]], %[[OFFSET]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
 
@@ -214,7 +215,6 @@
 //     CHECK:   %[[SUB_OFFSET:.+]] = affine.apply #[[MAP]]()[%arg1, %arg2]
 //     CHECK:   return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C6]], %[[C3]], %[[C4]], %[[C1]]
 
-
 // -----
 
 func.func @resolve_binding_subspan_zero_offset_memref() -> (memref<f32>, index, index, index, index, index) {
@@ -238,13 +238,16 @@
   %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<512x384xindex, strided<[384, 1], offset:?>> -> memref<index>, index, index, index, index, index
   return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<index>, index, index, index, index, index
 }
+//     CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 floordiv s1)>
 //     CHECK: func @resolve_binding_subspan_offset_index_memref(
 // CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
 // CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
 // CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//     CHECK:   %[[BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%arg0) : memref<index>
-//     CHECK:   return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+//     CHECK:   %[[BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%[[C0]]) : memref<index>
+//     CHECK:   %[[SIZEOF:.+]] = util.sizeof index
+//     CHECK:   %[[OFFSET:.+]] = affine.apply #[[MAP]]()[%arg0, %[[SIZEOF]]]
+//     CHECK:   return %[[BINDING]], %[[OFFSET]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
 
 // -----