Minor fixes found while enabling micro-kernel usage e2e. (#12354)
Rename vmvx.*.i8.i8.i32 -> vmvx.*.i8i8i32.
Rename vmvx.*.f32.f32.f32 -> vmvx.*.f32f32f32.
Fix insertion point for hal.interface.binding.subspan in ResolveBufferDescriptors pass.
Set the vm.import.module attribute to vmvx in ukernel function declaration.
Remove offset from newly created hal.interface.binding.subspan in ResolveBufferDescriptors.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir b/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
index 345b420..d790a21 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
@@ -55,10 +55,10 @@
%dim_1 = memref.dim %arg0, %c0 : memref<?x?xf32>
%dim_2 = memref.dim %arg1, %c1 : memref<?x?xf32>
%dim_3 = memref.dim %arg2, %c1 : memref<?x?xf32>
- iree_codegen.ukernel.generic "vmvx.matmul.f32.f32.f32" ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>) (%dim_1, %dim_2, %dim_3, %c0_i32 : index, index, index, i32)
+ iree_codegen.ukernel.generic "vmvx.matmul.f32f32f32" ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>) (%dim_1, %dim_2, %dim_3, %c0_i32 : index, index, index, i32)
return
}
-// CHECK-LABEL: func.func private @vmvx.matmul.f32.f32.f32
+// CHECK-LABEL: func.func private @vmvx.matmul.f32f32f32
// CHECK-SAME: (memref<f32>, index, index, memref<f32>, index, index,
// CHECK-SAME: memref<f32>, index, index, index, index, index, i32)
// CHECK-LABEL: func.func @generic_ukernel
@@ -74,7 +74,7 @@
// CHECK: %[[BASE0:.+]], %[[OFFSET0:.+]], %[[SIZE0:.+]]:2, %[[STRIDES0:.+]]:2 = memref.extract_strided_metadata %[[ARG0]]
// CHECK: %[[BASE1:.+]], %[[OFFSET1:.+]], %[[SIZE1:.+]]:2, %[[STRIDES1:.+]]:2 = memref.extract_strided_metadata %[[ARG1]]
// CHECK: %[[BASE2:.+]], %[[OFFSET2:.+]], %[[SIZE2:.+]]:2, %[[STRIDES2:.+]]:2 = memref.extract_strided_metadata %[[ARG2]]
-// CHECK: call @vmvx.matmul.f32.f32.f32(%[[BASE0]], %[[OFFSET0]], %[[STRIDES0]]#0
+// CHECK: call @vmvx.matmul.f32f32f32(%[[BASE0]], %[[OFFSET0]], %[[STRIDES0]]#0
// CHECK-SAME: %[[BASE1]], %[[OFFSET1]], %[[STRIDES1]]#0
// CHECK-SAME: %[[BASE2]], %[[OFFSET2]], %[[STRIDES2]]#0
// CHECK-SAME: %[[D0]], %[[D1]], %[[D2]], %[[C0_I32]])
@@ -87,7 +87,7 @@
outs(%arg2 : memref<?x?x?x?xf32>) accumulate(false)
return
}
-// CHECK-LABEL: func.func private @vmvx.mmt4d.f32.f32.f32
+// CHECK-LABEL: func.func private @vmvx.mmt4d.f32f32f32
// CHECK-SAME: (memref<f32>, index, index, memref<f32>, index, index,
// CHECK-SAME: memref<f32>, index, index, index, index, index, i32, i32, i32, i32)
// CHECK-LABEL: func.func @mmt4d_ukernel(
@@ -111,7 +111,7 @@
// CHECK-DAG: %[[D5:.+]] = memref.dim %[[ARG0]], %[[C3]]
// CHECK-DAG: %[[D5_I32:.+]] = arith.index_cast %[[D5]]
// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
-// CHECK: call @vmvx.mmt4d.f32.f32.f32(%[[BASE0]], %[[OFFSET0]], %[[STRIDES0]]#0
+// CHECK: call @vmvx.mmt4d.f32f32f32(%[[BASE0]], %[[OFFSET0]], %[[STRIDES0]]#0
// CHECK-SAME: %[[BASE1]], %[[OFFSET1]], %[[STRIDES1]]#0
// CHECK-SAME: %[[BASE2]], %[[OFFSET2]], %[[STRIDES2]]#0
// CHECK-SAME: %[[D0]], %[[D1]], %[[D2]],
@@ -126,4 +126,4 @@
return
}
// CHECK-LABEL: func @mmt4d_ukernel_i8i8i32(
-// CHECK: call @vmvx.mmt4d.i8.i8.i32
+// CHECK: call @vmvx.mmt4d.i8i8i32
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
index a7288cb..9bde8a8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
@@ -50,6 +50,10 @@
rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
fnDecl = rewriter.create<func::FuncOp>(loc, fnName, functionType);
SymbolTable::setSymbolVisibility(fnDecl, SymbolTable::Visibility::Private);
+ // TODO(#12327): Based on description in the issue, add an attribute
+ // `vm.import.module` and set it to `vmvx`. This only works on `vmvx`
+ // backend (obviously), but is enough to unblock while the proper fix lands.
+ fnDecl->setAttr("vm.import.module", rewriter.getStringAttr("vmvx"));
} else if (fnDecl.getFunctionType() != functionType) {
return rewriter.notifyMatchFailure(
op, llvm::formatv("mismatch in function type computed during lowering "
@@ -288,10 +292,10 @@
std::string fnName = "vmvx.mmt4d.";
switch (matmulType.value()) {
case MatmulType::I8I8I32:
- fnName.append("i8.i8.i32");
+ fnName.append("i8i8i32");
break;
case MatmulType::F32F32F32:
- fnName.append("f32.f32.f32");
+ fnName.append("f32f32f32");
break;
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
index cacbdf6..de030ae 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
@@ -66,10 +66,10 @@
std::string fnName = "";
switch (matmulType.value()) {
case MatmulType::I8I8I32:
- fnName = "vmvx.matmul.i8.i8.i32";
+ fnName = "vmvx.matmul.i8i8i32";
break;
case MatmulType::F32F32F32:
- fnName = "vmvx.matmul.f32.f32.f32";
+ fnName = "vmvx.matmul.f32f32f32";
break;
}
@@ -108,10 +108,10 @@
Type outElemType = outType.getElementType();
if (lhsElemType.isSignlessInteger(8) && rhsElemType.isSignlessInteger(8) &&
outElemType.isSignlessInteger(32)) {
- fnName = "vmvx.matmul.i8.i8.i32";
+ fnName = "vmvx.matmul.i8i8i32";
} else if (lhsElemType.isF32() && rhsElemType.isF32() &&
outElemType.isF32()) {
- fnName = "vmvx.matmul.f32.f32.f32";
+ fnName = "vmvx.matmul.f32f32f32";
}
if (fnName.empty()) {
return rewriter.notifyMatchFailure(op,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
index c28372d..ff9aa0c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
@@ -16,7 +16,7 @@
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.f32.f32.f32"
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.f32f32f32"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK-SAME: (%[[D0]], %[[D1]], %[[D2]], %[[FLAGS]] :
@@ -46,7 +46,7 @@
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.f32.f32.f32"
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.f32f32f32"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK-SAME: (%[[D0]], %[[D1]], %[[D2]], %[[FLAGS]] :
@@ -70,7 +70,7 @@
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.i8.i8.i32"
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.matmul.i8i8i32"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK-SAME: (%[[D0]], %[[D1]], %[[D2]], %[[FLAGS]] :
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
index a75a6e4..1304f2d 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
@@ -89,6 +89,50 @@
return strides;
}
+static FailureOr<DescriptorInfo> resolveBufferDescriptorForInterfaceBinding(
+ IREE::HAL::InterfaceBindingSubspanOp binding, RewriterBase &rewriter,
+ Location loc) {
+ auto memRefType = binding.getResult().getType().template cast<MemRefType>();
+ int rank = memRefType.getRank();
+ DescriptorInfo resultDescriptor;
+
+ // Compute sizes.
+ auto dynamicDimIt = binding.getDynamicDims().begin();
+ for (int i = 0; i < rank; ++i) {
+ if (memRefType.isDynamicDim(i)) {
+ resultDescriptor.sizes.push_back(*dynamicDimIt);
+ dynamicDimIt++;
+ } else {
+ resultDescriptor.sizes.push_back(
+ rewriter.getIndexAttr(memRefType.getDimSize(i)));
+ }
+ }
+ // Strides.
+ resultDescriptor.strides =
+ getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
+
+ // Offset.
+ Type elementType = memRefType.getElementType();
+ OpFoldResult elementWidth =
+ TypeSwitch<Type, OpFoldResult>(elementType)
+ .Case<ComplexType, IntegerType, FloatType>(
+ [&](auto type) -> OpFoldResult {
+ return rewriter.getIndexAttr(
+ IREE::Util::getRoundedElementByteWidth(
+ memRefType.getElementType()));
+ })
+ .Default([&](Type t) -> OpFoldResult {
+ return rewriter.create<IREE::Util::SizeOfOp>(loc, elementType)
+ .getResult();
+ });
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ resultDescriptor.offset = makeComposedFoldedAffineApply(
+ rewriter, loc, s0.floorDiv(s1),
+ ArrayRef<OpFoldResult>{binding.getByteOffset(), elementWidth});
+ return resultDescriptor;
+}
+
static FailureOr<DescriptorInfo> resolveBufferDescriptorForAllocation(
memref::AllocaOp alloca, RewriterBase &rewriter, Location loc) {
DescriptorInfo resultDescriptor;
@@ -260,34 +304,14 @@
if (!binding) return failure();
auto loc = op.getLoc();
-
- auto memRefType = binding.getResult().getType().cast<MemRefType>();
- int rank = memRefType.getRank();
- DescriptorInfo resultDescriptor;
-
- // Compute sizes.
- auto dynamicDimIt = binding.getDynamicDims().begin();
- for (int i = 0; i < rank; ++i) {
- if (memRefType.isDynamicDim(i)) {
- resultDescriptor.sizes.push_back(*dynamicDimIt);
- dynamicDimIt++;
- } else {
- resultDescriptor.sizes.push_back(
- rewriter.getIndexAttr(memRefType.getDimSize(i)));
- }
+ FailureOr<DescriptorInfo> resultDescriptor =
+ resolveBufferDescriptorForInterfaceBinding(binding, rewriter, loc);
+ if (failed(resultDescriptor)) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to resolve descriptor with source being binding op");
}
- // Strides.
- resultDescriptor.strides =
- getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
-
- // Offset.
- auto elementSize =
- rewriter.create<IREE::Util::SizeOfOp>(loc, memRefType.getElementType());
- resultDescriptor.offset = rewriter.createOrFold<arith::DivUIOp>(
- loc, binding.getByteOffset(), elementSize);
-
- replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor);
+ replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
// Base buffer.
rewriter.replaceAllUsesWith(
@@ -316,32 +340,19 @@
if (memRefType.getRank() < 1) return failure();
auto loc = op.getLoc();
- int rank = memRefType.getRank();
- DescriptorInfo resultDescriptor;
-
- // Compute sizes.
- auto dynamicDimIt = binding.getDynamicDims().begin();
- for (int i = 0; i < rank; ++i) {
- if (memRefType.isDynamicDim(i)) {
- resultDescriptor.sizes.push_back(*dynamicDimIt);
- dynamicDimIt++;
- } else {
- resultDescriptor.sizes.push_back(
- rewriter.getIndexAttr(memRefType.getDimSize(i)));
- }
+ FailureOr<DescriptorInfo> resultDescriptor =
+ resolveBufferDescriptorForInterfaceBinding(binding, rewriter, loc);
+ if (failed(resultDescriptor)) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to resolve descriptor with source being binding op");
}
- // Strides.
- resultDescriptor.strides =
- getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
- resultDescriptor.offset = rewriter.getIndexAttr(0);
-
- replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor);
+ replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
// Base buffer. Use a 1D memref for hal.interface.binding.subspan.
AffineMap mulMap = getMulMap(rewriter.getContext());
OpFoldResult linearizedMemrefSize = rewriter.getIndexAttr(1);
- for (auto size : resultDescriptor.sizes) {
+ for (auto size : resultDescriptor->sizes) {
linearizedMemrefSize = makeComposedFoldedAffineApply(
rewriter, loc, mulMap, {linearizedMemrefSize, size});
}
@@ -349,13 +360,16 @@
SmallVector<Value> dynamicLinearShape;
dispatchIndexOpFoldResult(linearizedMemrefSize, dynamicLinearShape,
staticLinearShape);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(binding);
+
+ Value newOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value linearInterfaceBinding =
rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
loc, op.getBaseBuffer().getType(), binding.getSetAttr(),
binding.getBindingAttr(), binding.getDescriptorTypeAttr(),
- binding.getByteOffset(),
- /*dynamicDims =*/ValueRange{}, binding.getAlignmentAttr(),
- binding.getDescriptorFlagsAttr());
+ newOffset, /*dynamicDims =*/ValueRange{},
+ binding.getAlignmentAttr(), binding.getDescriptorFlagsAttr());
rewriter.replaceAllUsesWith(op.getBaseBuffer(), linearInterfaceBinding);
rewriter.eraseOp(op);
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
index 8dc786e..7505745 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
@@ -81,12 +81,13 @@
%base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xindex> -> !util.buffer, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 floordiv s1)>
// CHECK: func @resolve_binding_subspan_offset_index(
// CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index
// CHECK-DAG: %[[C384:.+]] = arith.constant 384 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[INDEX_SIZE:.+]] = util.sizeof index
-// CHECK-DAG: %[[OFFSET:.+]] = arith.divui %arg0, %[[INDEX_SIZE]] : index
+// CHECK-DAG: %[[OFFSET:.+]] = affine.apply #map()[%arg0, %[[INDEX_SIZE]]]
// CHECK: %[[CAST:.+]] = vmvx.get_raw_interface_binding_buffer set(0) binding(0)
// CHECK: return %[[CAST]], %[[OFFSET]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
@@ -214,7 +215,6 @@
// CHECK: %[[SUB_OFFSET:.+]] = affine.apply #[[MAP]]()[%arg1, %arg2]
// CHECK: return %[[BASE_BUFFER]], %[[SUB_OFFSET]], %[[C6]], %[[C3]], %[[C4]], %[[C1]]
-
// -----
func.func @resolve_binding_subspan_zero_offset_memref() -> (memref<f32>, index, index, index, index, index) {
@@ -238,13 +238,16 @@
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<512x384xindex, strided<[384, 1], offset:?>> -> memref<index>, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<index>, index, index, index, index, index
}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 floordiv s1)>
// CHECK: func @resolve_binding_subspan_offset_index_memref(
// CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index
// CHECK-DAG: %[[C384:.+]] = arith.constant 384 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%arg0) : memref<index>
-// CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+// CHECK: %[[BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%[[C0]]) : memref<index>
+// CHECK: %[[SIZEOF:.+]] = util.sizeof index
+// CHECK: %[[OFFSET:.+]] = affine.apply #[[MAP]]()[%arg0, %[[SIZEOF]]]
+// CHECK: return %[[BINDING]], %[[OFFSET]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
// -----