Handle memref<index> in VM conversion. (#9897)
* Introduces a util.sizeof op that allows the actual concrete size of a type to remain unknown until final conversion to the target.
* Adds a VM conversion for `util.sizeof index` to convert to the actual index size of the target.
* Adapts util.buffer->vm.buffer lowerings to properly size load/store/fill of index buffers based on the index size of the target.
* Enables the remainder of the tosa_ops tests for microkernels.
A follow-on change will break out the memref->util conversion so that it is done as a final stage of VMVX lowering.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
index e74bf2f8..51b9504 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
@@ -52,8 +52,8 @@
// may elect to represent all i8 registers as i32, but this does not mean
// that all memrefs are widened from i8 to i32).
auto elementType = memrefType.getElementType();
- auto elementSize = rewriter.createOrFold<arith::ConstantIndexOp>(
- loc, IREE::Util::getRoundedElementByteWidth(elementType));
+ auto elementSize =
+ rewriter.createOrFold<IREE::Util::SizeOfOp>(loc, elementType);
// Rank 1 memrefs are just offset by their element width by the offset.
auto elementCount = indices.front();
@@ -155,11 +155,12 @@
return rewriter.notifyMatchFailure(
allocaOp, "unable to create buffers for dynamic shapes");
}
- int64_t memRefLength =
- type.getNumElements() *
- IREE::Util::getRoundedElementByteWidth(type.getElementType());
- Value allocationSize = rewriter.create<arith::ConstantIndexOp>(
- allocaOp.getLoc(), memRefLength);
+ auto numElements = rewriter.create<arith::ConstantIndexOp>(
+ allocaOp.getLoc(), type.getNumElements());
+ auto elementSize = rewriter.createOrFold<IREE::Util::SizeOfOp>(
+ allocaOp.getLoc(), type.getElementType());
+ auto allocationSize = rewriter.createOrFold<arith::MulIOp>(
+ allocaOp.getLoc(), numElements, elementSize);
rewriter.replaceOpWithNewOp<IREE::Util::BufferAllocOp>(
allocaOp, rewriter.getType<IREE::Util::BufferType>(), allocationSize);
return success();
@@ -177,8 +178,8 @@
}
auto elementType =
dimOp.getSource().getType().cast<MemRefType>().getElementType();
- Value elementSize = rewriter.create<arith::ConstantIndexOp>(
- dimOp.getLoc(), IREE::Util::getRoundedElementByteWidth(elementType));
+ Value elementSize = rewriter.createOrFold<IREE::Util::SizeOfOp>(
+ dimOp.getLoc(), elementType);
Value bufferSize = rewriter.create<IREE::Util::BufferSizeOp>(
dimOp.getLoc(), rewriter.getIndexType(), adaptor.getSource());
rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(dimOp, bufferSize,
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
index fa7a099..2d94678 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
@@ -60,6 +60,19 @@
}
// -----
+// CHECK-LABEL: @alloc_index
+// CHECK-SAME: (%[[IDX0:.+]]: index) -> !util.buffer {
+func.func @alloc_index(%idx0: index) -> memref<4xindex> {
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[SIZEOF:.*]] = util.sizeof index
+ // CHECK: %[[SZ:.*]] = arith.muli %[[SIZEOF]], %[[C4]]
+ // CHECK: %[[BUFFER:.*]] = util.buffer.alloc uninitialized : !util.buffer{%[[SZ]]}
+ %0 = memref.alloca() : memref<4xindex>
+ // CHECK: return %[[BUFFER]]
+ return %0 : memref<4xindex>
+}
+
+// -----
// CHECK-LABEL: @load_store_f32
// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index, %[[IDX1:.+]]: index) -> f32 {
func.func @load_store_f32(%buffer: memref<?xf32>, %idx0: index, %idx1: index) -> f32 {
@@ -113,6 +126,22 @@
}
// -----
+// CHECK-LABEL: @load_store_index
+// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[VALUE:.+]]: index) -> index {
+func.func @load_store_index(%buffer: memref<?xindex>, %idx0: index, %idx1: index, %value: index) -> index {
+ // CHECK-DAG: %[[SIZEOF:.*]] = util.sizeof index
+ // CHECK-DAG: %[[SZ:.*]] = util.buffer.size %[[BUFFER]]
+ // CHECK-DAG: %[[OFS0:.*]] = arith.muli %[[SIZEOF]], %[[IDX0]] : index
+ // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[OFS0]]] : index -> !util.buffer{%[[SZ]]}
+ memref.store %value, %buffer[%idx0] : memref<?xindex>
+ // CHECK: %[[OFS1:.*]] = arith.muli %[[SIZEOF]], %[[IDX1]] : index
+ // CHECK: %[[LD:.*]] = util.buffer.load %[[BUFFER]][%[[OFS1]]] : !util.buffer{%[[SZ]]} -> index
+ %1 = memref.load %buffer[%idx1] : memref<?xindex>
+ // CHECK: return %[[LD]]
+ return %1 : index
+}
+
+// -----
// CHECK-LABEL: @dim_i16
// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index) -> index {
func.func @dim_i16(%buffer: memref<?xi16>, %idx0: index) -> index {
@@ -123,3 +152,15 @@
// CHECK: return %[[DV]]
return %0 : index
}
+
+// -----
+// CHECK-LABEL: @dim_index
+// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index) -> index {
+func.func @dim_index(%buffer: memref<?xindex>, %idx0: index) -> index {
+ // CHECK: %[[SIZEOF:.*]] = util.sizeof index
+ // CHECK: %[[SZ:.*]] = util.buffer.size %[[BUFFER]] : !util.buffer
+ // CHECK: %[[DV:.*]] = arith.floordivsi %[[SZ]], %[[SIZEOF]] : index
+ %0 = memref.dim %buffer, %idx0 : memref<?xindex>
+ // CHECK: return %[[DV]]
+ return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
index a425d03..d3bd131 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
@@ -37,7 +37,7 @@
def Util_Tensor : TensorOf<[Util_Element]>;
def Util_Primitive : AnyTypeOf<[Index, AnyInteger, AnyFloat]>;
-def Util_FillPattern : AnyTypeOf<[AnyInteger, AnyFloat]>;
+def Util_FillPattern : AnyTypeOf<[AnyInteger, AnyFloat, Index]>;
def Util_Offset : TypeAlias<Index>;
def Util_Size : TypeAlias<Index>;
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index ab7721c..59b2c8d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -391,6 +391,19 @@
}
//===----------------------------------------------------------------------===//
+// util.sizeof
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SizeOfOp::fold(ArrayRef<Attribute> operands) {
+ Type t = getSizedType();
+ if (t.isa<IntegerType>() || t.isa<FloatType>()) {
+ return IntegerAttr::get(IndexType::get(getContext()),
+ getRoundedElementByteWidth(t));
+ }
+ return {};
+}
+
+//===----------------------------------------------------------------------===//
// Compiler hints
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index 26815e2..14df90e 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -273,6 +273,33 @@
let hasFolder = 1;
}
+def Util_SizeOfOp : Util_PureOp<"sizeof"> {
+ let summary = [{returns the size in bytes of a datatype}];
+ let description = [{
+ Most datatypes have a static size at all layers of the compilation stack.
+ However, those that only have a size for certain lowering flows can be
+ challenging. This op represents such sizes in a way that can be specialized
+ later.
+
+ Returns the size in bytes, rounded up to the next whole byte of the
+ specified type. This op will fold to a constant index value for IntegerType
+ and FloatType. All others are not folded.
+ }];
+
+ let arguments = (ins
+ TypeAttr:$sizedType
+ );
+ let results = (outs
+ Index:$size
+ );
+
+ let assemblyFormat = [{
+ $sizedType attr-dict-with-keyword
+ }];
+
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// Compiler hints
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
index 0de88a4..23561af 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
@@ -91,3 +91,30 @@
// CHECK: return %[[SUM_ALIGNED]]
return %result : index
}
+
+// -----
+
+// CHECK-LABEL: @sizeofWholeInt
+func.func @sizeofWholeInt() -> index {
+ // CHECK: = arith.constant 4 : index
+ %0 = util.sizeof i32
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @sizeofSubByteInt
+func.func @sizeofSubByteInt() -> index {
+ // CHECK: = arith.constant 2 : index
+ %0 = util.sizeof i12
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @sizeofFloat
+func.func @sizeofFloat() -> index {
+ // CHECK: = arith.constant 4 : index
+ %0 = util.sizeof f32
+ return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_ops.mlir
index 9f68f6d..8bf52f8 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_ops.mlir
@@ -15,3 +15,12 @@
%result = util.align %arg0, %arg1 : i32
return
}
+
+// -----
+
+// CHECK-LABEL: @sizeofUnfoldable
+func.func @sizeofUnfoldable() -> index {
+ // CHECK: = util.sizeof index
+ %0 = util.sizeof index
+ return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertAlignmentOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertAlignmentOps.cpp
index b622fc8..d2b4c05 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertAlignmentOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertAlignmentOps.cpp
@@ -68,6 +68,34 @@
}
};
+//===----------------------------------------------------------------------===//
+// util.sizeof
+//===----------------------------------------------------------------------===//
+
+/// For a `sizeof index` operation, invokes the type converter to derive the
+/// concrete type for index and rewrites to that. This allows us to do late
+/// resolution of the size of the index type at the point of conversion to VM
+/// where it is known.
+struct FixateIndexSizeofConversion
+ : public OpConversionPattern<IREE::Util::SizeOfOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::SizeOfOp sizeofOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type sizedType = sizeofOp.getSizedType();
+ if (sizedType.isa<IndexType>()) {
+ Type converted = getTypeConverter()->convertType(sizedType);
+ if (converted) {
+ Value newSizeof = rewriter.createOrFold<IREE::Util::SizeOfOp>(
+ sizeofOp.getLoc(), converted);
+ rewriter.replaceOp(sizeofOp, newSizeof);
+ return success();
+ }
+ }
+ return failure();
+ }
+};
+
} // namespace
void populateUtilAlignmentToVMPatterns(MLIRContext *context,
@@ -75,8 +103,10 @@
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
conversionTarget.addIllegalOp<IREE::Util::AlignOp>();
+ conversionTarget.addIllegalOp<IREE::Util::SizeOfOp>();
- patterns.insert<AlignOpConversion>(typeConverter, context);
+ patterns.insert<AlignOpConversion, FixateIndexSizeofConversion>(typeConverter,
+ context);
}
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp
index 3c50f36..9f13d7d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp
@@ -154,6 +154,11 @@
IREE::Util::BufferFillOp fillOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldType = fillOp.getPattern().getType();
+ auto newType = adaptor.getPattern().getType();
+ if (oldType.isa<IndexType>()) {
+ // Use the actual converted type for IndexType.
+ oldType = newType;
+ }
auto byteOffset = castToI64(adaptor.getTargetOffset(), rewriter);
auto byteLength = castToI64(adaptor.getLength(), rewriter);
auto pattern = adaptor.getPattern();
@@ -196,6 +201,9 @@
ConversionPatternRewriter &rewriter) const override {
auto oldType = loadOp.getResult().getType();
auto newType = getTypeConverter()->convertType(oldType);
+ if (oldType.isa<IndexType>()) {
+ oldType = newType;
+ }
auto byteOffset = castToI64(adaptor.getSourceOffset(), rewriter);
if (auto integerType = oldType.dyn_cast<IntegerType>()) {
if (integerType.isInteger(1) || integerType.isInteger(8)) {
@@ -245,6 +253,10 @@
IREE::Util::BufferStoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldType = storeOp.getSource().getType();
+ auto newType = adaptor.getSource().getType();
+ if (oldType.isa<IndexType>()) {
+ oldType = newType;
+ }
auto byteOffset = castToI64(adaptor.getTargetOffset(), rewriter);
if (oldType.isInteger(1) || oldType.isInteger(8)) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI8Op>(
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir
index 43e86c0..d157967 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir
@@ -42,3 +42,13 @@
//CHECK-DAG: vm.return %3 : i64
return %result : i64
}
+
+// -----
+
+// CHECK-LABEL: @utilSizeOfIndex
+func.func @utilSizeOfIndex() -> (index) {
+ // CHECK: %[[SIZEOF:.*]] = vm.const.i32 4
+ %0 = util.sizeof index
+ // CHECK: vm.return %[[SIZEOF]]
+ return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
index bfd2936..a050de9 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
@@ -150,6 +150,18 @@
// -----
+// CHECK-LABEL: @buffer_fill_index
+func.func @buffer_fill_index(%arg0: !util.buffer, %arg1: index, %arg2: index) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-32: vm.buffer.fill.i32
+ // CHECK-64: vm.buffer.fill.i64
+ util.buffer.fill %arg2, %arg0[%c100 for %c200] : index -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
// CHECK-LABEL: @buffer_load_i1
func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i1 {
%c100 = arith.constant 100 : index
@@ -189,6 +201,17 @@
// -----
+// CHECK-LABEL: @buffer_load_index
+func.func @buffer_load_index(%arg0: !util.buffer, %arg1: index) -> index {
+ %c100 = arith.constant 100 : index
+ // CHECK-32: vm.buffer.load.i32
+ // CHECK-64: vm.buffer.load.i64
+ %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> index
+ return %0 : index
+}
+
+// -----
+
// CHECK-LABEL: @buffer_store_i1
func.func @buffer_store_i1(%arg0: !util.buffer, %arg1: index, %arg2: i1) {
%c100 = arith.constant 100 : index
@@ -222,3 +245,14 @@
util.buffer.store %arg2, %arg0[%c100] : i64 -> !util.buffer{%arg1}
return
}
+
+// -----
+
+// CHECK-LABEL: @buffer_store_index
+func.func @buffer_store_index(%arg0: !util.buffer, %arg1: index, %arg2: index) {
+ %c100 = arith.constant 100 : index
+ // CHECK-32: vm.buffer.store.i32
+ // CHECK-64: vm.buffer.store.i64
+ util.buffer.store %arg2, %arg0[%c100] : index -> !util.buffer{%arg1}
+ return
+}
diff --git a/tests/e2e/tosa_ops/BUILD b/tests/e2e/tosa_ops/BUILD
index 93fc1a3..c380539 100644
--- a/tests/e2e/tosa_ops/BUILD
+++ b/tests/e2e/tosa_ops/BUILD
@@ -147,6 +147,7 @@
"exp.mlir",
"floor.mlir",
"fully_connected.mlir",
+ "gather.mlir",
"greater.mlir",
"greater_equal.mlir",
"if.mlir",
@@ -168,16 +169,12 @@
"select.mlir",
"sigmoid.mlir",
"sub.mlir",
+ "table.mlir",
"tanh.mlir",
"transpose.mlir",
"while.mlir",
],
include = ["*.mlir"],
- exclude = [
- # Decompositions produce tensor<index> which is not handled properly.
- "gather.mlir",
- "table.mlir",
- ],
)
iree_check_single_backend_test_suite(
diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt
index 853fd13..11695d1 100644
--- a/tests/e2e/tosa_ops/CMakeLists.txt
+++ b/tests/e2e/tosa_ops/CMakeLists.txt
@@ -131,6 +131,7 @@
"exp.mlir"
"floor.mlir"
"fully_connected.mlir"
+ "gather.mlir"
"greater.mlir"
"greater_equal.mlir"
"if.mlir"
@@ -152,6 +153,7 @@
"select.mlir"
"sigmoid.mlir"
"sub.mlir"
+ "table.mlir"
"tanh.mlir"
"transpose.mlir"
"while.mlir"