Lowering HLO slice and dynamic-slice ops to VMLA copies.
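
The new SliceOpConversion and DynamicSliceOpConversion patterns turn
unit-stride slices into an output buffer allocation plus a strided
vmla.copy; slices with non-unit strides are not handled yet (the pattern
emits a warning and fails to match). As a rough sketch (simplified from
the slice.mlir test added below; the SSA value names here are
illustrative), a static slice such as

  %result = "xla_hlo.slice"(%input) {
    start_indices = dense<[1, 0]> : tensor<2xi64>,
    limit_indices = dense<[2, 4]> : tensor<2xi64>,
    strides = dense<1> : tensor<2xi64>
  } : (tensor<3x4xi32>) -> tensor<1x4xi32>

now lowers to approximately

  %dst = "vmla.buffer.alloc"(%c16_i32)
  "vmla.copy"(%src, %src_shape, %c1_i32, %c0_i32,
              %dst, %dst_shape, %c0_i32, %c0_i32,
              %c1_i32, %c4_i32) {element_type = i32}

For xla_hlo.dynamic-slice the per-dimension start indices are loaded at
runtime from the start-indices buffer via the new vmla.buffer.load.i32
import rather than being folded into constants.
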
PiperOrigin-RevId: 294363411
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 2843a87..37e0eb3 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -102,6 +102,113 @@
TypeConverter &typeConverter;
};
+// Converts a static slice op to a copy (if the source must be preserved).
+struct SliceOpConversion : public OpConversionPattern<xla_hlo::SliceOp> {
+ SliceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ PatternMatchResult matchAndRewrite(
+ xla_hlo::SliceOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto isNotOne = [](APInt stride) { return stride != 1; };
+ if (llvm::any_of(srcOp.strides(), isNotOne)) {
+      srcOp.emitWarning()
+          << "could not lower slice op with non-unit strides";
+ return matchFailure();
+ }
+
+ // TODO(benvanik): if the source is only used by this op then replace with
+ // a vmla.buffer.view op.
+
+ auto srcShape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
+ auto dstShape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+
+ int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
+ auto indexType = rewriter.getIntegerType(32);
+ SmallVector<Value, 4> srcIndices(rank);
+ SmallVector<Value, 4> dstIndices(rank);
+ SmallVector<Value, 4> lengths(rank);
+ Value zero = rewriter.createOrFold<mlir::ConstantOp>(
+ srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+ for (int i = 0; i < rank; ++i) {
+ uint64_t ui = static_cast<uint64_t>(i);
+ srcIndices[i] = rewriter.createOrFold<mlir::ConstantOp>(
+ srcOp.getLoc(), indexType,
+ rewriter.getI32IntegerAttr(
+ srcOp.start_indices().getValue<int64_t>({ui})));
+ dstIndices[i] = zero;
+ lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
+ srcOp.getLoc(), indexType,
+ rewriter.getI32IntegerAttr(
+ srcOp.limit_indices().getValue<int64_t>({ui}) -
+ srcOp.start_indices().getValue<int64_t>({ui})));
+ }
+
+ auto dst = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+ rewriter.create<IREE::VMLA::CopyOp>(
+ srcOp.getLoc(), operands[0], srcShape, srcIndices, dst, dstShape,
+ dstIndices, lengths,
+ TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
+ rewriter.replaceOp(srcOp, {dst});
+ return matchSuccess();
+ }
+
+ TypeConverter &typeConverter;
+};
+
+// Converts a dynamic slice op to a copy (if the source must be preserved).
+struct DynamicSliceOpConversion
+ : public OpConversionPattern<xla_hlo::DynamicSliceOp> {
+ DynamicSliceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ PatternMatchResult matchAndRewrite(
+ xla_hlo::DynamicSliceOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // TODO(benvanik): if the source is only used by this op then replace with
+ // a vmla.buffer.view op.
+
+ auto srcShape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
+ auto dstShape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.result(), typeConverter, rewriter);
+
+ int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
+ auto indexType = rewriter.getIntegerType(32);
+ SmallVector<Value, 4> srcIndices(rank);
+ SmallVector<Value, 4> dstIndices(rank);
+ SmallVector<Value, 4> lengths(rank);
+ Value zero = rewriter.createOrFold<mlir::ConstantOp>(
+ srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+ for (int i = 0; i < rank; ++i) {
+ srcIndices[i] = rewriter.createOrFold<IREE::VMLA::BufferLoadI32Op>(
+ srcOp.getLoc(), rewriter.getIntegerType(32), operands[1],
+ rewriter.createOrFold<mlir::ConstantOp>(
+ srcOp.getLoc(), indexType,
+ rewriter.getI32IntegerAttr(i * sizeof(int32_t))));
+ dstIndices[i] = zero;
+ lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
+ srcOp.getLoc(), indexType,
+ rewriter.getI32IntegerAttr(srcOp.slice_sizes().getValue<int64_t>(
+ {static_cast<uint64_t>(i)})));
+ }
+
+ auto dst = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+ rewriter.create<IREE::VMLA::CopyOp>(
+ srcOp.getLoc(), operands[0], srcShape, srcIndices, dst, dstShape,
+ dstIndices, lengths,
+ TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
+ rewriter.replaceOp(srcOp, {dst});
+ return matchSuccess();
+ }
+
+ TypeConverter &typeConverter;
+};
+
} // namespace
void populateHLOToVMLAPatterns(MLIRContext *context,
@@ -195,6 +302,8 @@
// Conversions that don't have a 1:1 mapping, mostly involving buffer views
// or transfers.
patterns.insert<BroadcastInDimOpConversion>(context, typeConverter);
+ patterns.insert<SliceOpConversion>(context, typeConverter);
+ patterns.insert<DynamicSliceOpConversion>(context, typeConverter);
// TODO(benvanik): add missing ops:
// - ConvOp
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir
new file mode 100644
index 0000000..c9e32bc
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir
@@ -0,0 +1,111 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @slice_whole_buffer
+// CHECK-SAME: [[SRC_INDICES:%.+]]: !iree.ref<!vmla.buffer>
+func @slice_whole_buffer(%src_indices : tensor<2xi64>) -> tensor<3x4xi32> {
+ // CHECK: [[SRC:%.+]] = "vmla.constant"()
+ %input = constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]
+ ]> : tensor<3x4xi32>
+ // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[SRC_INDEX_0:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c0_i32)
+ // CHECK-NEXT: [[SRC_INDEX_1:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c4_i32)
+ // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c48_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[SRC_INDEX_0]], [[SRC_INDEX_1]],
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c3_i32, %c4_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %result = "xla_hlo.dynamic-slice"(%input, %src_indices) {
+ slice_sizes = dense<[3, 4]> : tensor<2xi64>
+ } : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<3x4xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %result : tensor<3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_whole_stride
+// CHECK-SAME: [[SRC_INDICES:%.+]]: !iree.ref<!vmla.buffer>
+func @slice_whole_stride(%src_indices : tensor<2xi64>) -> tensor<1x4xi32> {
+ // CHECK: [[SRC:%.+]] = "vmla.constant"()
+ %input = constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]
+ ]> : tensor<3x4xi32>
+ // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[SRC_INDEX_0:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c0_i32)
+ // CHECK-NEXT: [[SRC_INDEX_1:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c4_i32)
+ // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c16_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[SRC_INDEX_0]], [[SRC_INDEX_1]],
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c1_i32, %c4_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %result = "xla_hlo.dynamic-slice"(%input, %src_indices) {
+ slice_sizes = dense<[1, 4]> : tensor<2xi64>
+ } : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %result : tensor<1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_stride_part
+// CHECK-SAME: [[SRC_INDICES:%.+]]: !iree.ref<!vmla.buffer>
+func @slice_stride_part(%src_indices : tensor<2xi64>) -> tensor<1x2xi32> {
+ // CHECK: [[SRC:%.+]] = "vmla.constant"()
+ %input = constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]
+ ]> : tensor<3x4xi32>
+ // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[SRC_INDEX_0:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c0_i32)
+ // CHECK-NEXT: [[SRC_INDEX_1:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c4_i32)
+ // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c8_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[SRC_INDEX_0]], [[SRC_INDEX_1]],
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c1_i32, %c2_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %result = "xla_hlo.dynamic-slice"(%input, %src_indices) {
+ slice_sizes = dense<[1, 2]> : tensor<2xi64>
+ } : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x2xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %result : tensor<1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_multi_stride
+// CHECK-SAME: [[SRC_INDICES:%.+]]: !iree.ref<!vmla.buffer>
+func @slice_multi_stride(%src_indices : tensor<2xi64>) -> tensor<2x4xi32> {
+ // CHECK: [[SRC:%.+]] = "vmla.constant"()
+ %input = constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]
+ ]> : tensor<3x4xi32>
+ // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[SRC_INDEX_0:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c0_i32)
+ // CHECK-NEXT: [[SRC_INDEX_1:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c4_i32)
+ // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c32_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[SRC_INDEX_0]], [[SRC_INDEX_1]],
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c2_i32, %c4_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %result = "xla_hlo.dynamic-slice"(%input, %src_indices) {
+ slice_sizes = dense<[2, 4]> : tensor<2xi64>
+ } : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<2x4xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %result : tensor<2x4xi32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir
new file mode 100644
index 0000000..a5becc7
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir
@@ -0,0 +1,107 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @slice_whole_buffer
+func @slice_whole_buffer() -> tensor<3x4xi32> {
+ // CHECK: [[SRC:%.+]] = "vmla.constant"()
+ %input = constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]
+ ]> : tensor<3x4xi32>
+ // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c48_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c3_i32, %c4_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %result = "xla_hlo.slice"(%input) {
+ start_indices = dense<[0, 0]> : tensor<2xi64>,
+ limit_indices = dense<[3, 4]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<3x4xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %result : tensor<3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_whole_stride
+func @slice_whole_stride() -> tensor<1x4xi32> {
+ // CHECK: [[SRC:%.+]] = "vmla.constant"()
+ %input = constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]
+ ]> : tensor<3x4xi32>
+ // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c16_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], %c1_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c1_i32, %c4_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %result = "xla_hlo.slice"(%input) {
+ start_indices = dense<[1, 0]> : tensor<2xi64>,
+ limit_indices = dense<[2, 4]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<1x4xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %result : tensor<1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_stride_part
+func @slice_stride_part() -> tensor<1x2xi32> {
+ // CHECK: [[SRC:%.+]] = "vmla.constant"()
+ %input = constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]
+ ]> : tensor<3x4xi32>
+ // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c8_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], %c1_i32, %c1_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c1_i32, %c2_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %result = "xla_hlo.slice"(%input) {
+ start_indices = dense<[1, 1]> : tensor<2xi64>,
+ limit_indices = dense<[2, 3]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %result : tensor<1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_multi_stride
+func @slice_multi_stride() -> tensor<2x4xi32> {
+ // CHECK: [[SRC:%.+]] = "vmla.constant"()
+ %input = constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]
+ ]> : tensor<3x4xi32>
+ // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c32_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], %c1_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c2_i32, %c4_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %result = "xla_hlo.slice"(%input) {
+ start_indices = dense<[1, 0]> : tensor<2xi64>,
+ limit_indices = dense<[3, 4]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<2x4xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %result : tensor<2x4xi32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 2bbeb9c..266acf6 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -154,10 +154,12 @@
VMLA_IMPORT_OP(IREE::VMLA::BufferViewOp, "vmla.buffer.view");
VMLA_IMPORT_OP(IREE::VMLA::BufferCopyOp, "vmla.buffer.copy");
VMLA_IMPORT_OP(IREE::VMLA::BufferFillOp, "vmla.buffer.fill");
+ VMLA_IMPORT_OP(IREE::VMLA::BufferLoadI32Op, "vmla.buffer.load.i32");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::CmpOp, "vmla.cmp");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::SelectOp, "vmla.select");
+ VMLA_SIZED_IMPORT_OP(IREE::VMLA::CopyOp, "vmla.copy");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::TransposeOp, "vmla.transpose");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::ReverseOp, "vmla.reverse");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::PadOp, "vmla.pad");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLABase.td b/iree/compiler/Dialect/VMLA/IR/VMLABase.td
index bf325ae..74b84ea 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLABase.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLABase.td
@@ -72,6 +72,8 @@
def VMLA_HostSize : TypeAlias<I32>;
def VMLA_HostSizeAttr : IntegerAttrBase<I32, "size_t">;
+def VMLA_Index : TypeAlias<I32>;
+
def VMLA_Shape : TypeAlias<Shape_RankedShape>;
def VMLA_HostBufferRef : AnyTypeOf<[
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index cfaf98c..183a995 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -103,6 +103,16 @@
);
}
+def VMLA_BufferLoadI32Op : VMLA_PureOp<"buffer.load.i32"> {
+ let arguments = (ins
+ VMLA_BufferRef:$src,
+ VMLA_DeviceSize:$byte_offset
+ );
+ let results = (outs
+ I32:$result
+ );
+}
+
//===----------------------------------------------------------------------===//
// VMLA Ops: comparison
//===----------------------------------------------------------------------===//
@@ -131,6 +141,22 @@
// VMLA Ops: shape/structure
//===----------------------------------------------------------------------===//
+def VMLA_CopyOp : VMLA_ElementTypeOp<"copy", [
+ VMLA_IncludeShapes,
+ SameVariadicOperandSize,
+ ]> {
+ let arguments = (ins
+ VMLA_BufferRef:$src,
+ VMLA_Shape:$src_shape,
+ Variadic<VMLA_Index>:$src_indices,
+ VMLA_BufferRef:$dst,
+ VMLA_Shape:$dst_shape,
+ Variadic<VMLA_Index>:$dst_indices,
+ Variadic<VMLA_Index>:$lengths,
+ VMLA_AnyTypeAttr:$element_type
+ );
+}
+
def VMLA_TransposeOp : VMLA_ElementTypeOp<"transpose", [VMLA_IncludeShapes]> {
let arguments = (ins
VMLA_BufferRef:$src,
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index d699ff5..cd06d1d 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -50,6 +50,12 @@
%dst : !iree.ref<!vmla.buffer>
)
+vm.import @buffer.load.i32(
+ %src : !iree.ref<!vmla.buffer>,
+ %byte_offset : i32
+) -> i32
+attributes {nosideeffects}
+
vm.import @cmp.i8(%predicate : i32, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
vm.import @cmp.i16(%predicate : i32, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
vm.import @cmp.i32(%predicate : i32, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
@@ -59,6 +65,24 @@
vm.import @select.x16(%cond : !iree.ref<!vmla.buffer>, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
vm.import @select.x32(%cond : !iree.ref<!vmla.buffer>, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
+// TODO(benvanik): do the copies with buffer.copy instead and leave the offset
+// calculations in the IR for the compiler to simplify.
+vm.import @copy.x8(
+ %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
+ %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
+ %lengths : i32 ...
+)
+vm.import @copy.x16(
+ %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
+ %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
+ %lengths : i32 ...
+)
+vm.import @copy.x32(
+ %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
+ %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
+ %lengths : i32 ...
+)
+
vm.import @transpose.x8(
%src : !iree.ref<!vmla.buffer>, %src_shape : i32 ...,
%dimensions : i32 ...,
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index b2deeda..04d2126 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -304,6 +304,15 @@
return OkStatus();
}
+ StatusOr<int32_t> BufferLoadI32(vm::ref<iree_vmla_buffer_t>& src,
+ iree_vmla_size_t byte_offset) {
+ IREE_TRACE_SCOPE0("VMLAModuleState::BufferLoadI32");
+ IREE_RETURN_IF_NULL(src);
+ ASSIGN_OR_RETURN(auto data,
+ src->RangeAs<int32_t>(byte_offset, sizeof(int32_t)));
+ return data[0];
+ }
+
//===--------------------------------------------------------------------===//
// Common helpers for defining ops
//===--------------------------------------------------------------------===//
@@ -397,6 +406,21 @@
// VMLA Ops: shape/structure
//===--------------------------------------------------------------------===//
+#define IREE_VMLA_COPY_OP(name, size) \
+ Status name(vm::ref<iree_vmla_buffer_t>& src, iree_vmla_shape_t src_shape, \
+ absl::Span<const int32_t> src_indices, \
+ vm::ref<iree_vmla_buffer_t>& dst, iree_vmla_shape_t dst_shape, \
+ absl::Span<const int32_t> dst_indices, \
+ absl::Span<const int32_t> lengths) { \
+ IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
+ return kernels::Copy::Execute<size>( \
+ src->As<uint8_t>(), Shape(src_shape), src_indices, dst->As<uint8_t>(), \
+ Shape(dst_shape), dst_indices, lengths); \
+ }
+ IREE_VMLA_COPY_OP(CopyX8, sizeof(uint8_t));
+ IREE_VMLA_COPY_OP(CopyX16, sizeof(uint16_t));
+ IREE_VMLA_COPY_OP(CopyX32, sizeof(uint32_t));
+
#define IREE_VMLA_TRANSPOSE_OP(name, type) \
Status name(vm::ref<iree_vmla_buffer_t>& src, iree_vmla_shape_t src_shape, \
absl::Span<const int32_t> dims, \
@@ -644,6 +668,7 @@
vm::MakeNativeFunction("buffer.view", &VMLAModuleState::BufferView),
vm::MakeNativeFunction("buffer.copy", &VMLAModuleState::BufferCopy),
vm::MakeNativeFunction("buffer.fill", &VMLAModuleState::BufferFill),
+ vm::MakeNativeFunction("buffer.load.i32", &VMLAModuleState::BufferLoadI32),
vm::MakeNativeFunction("cmp.i8", &VMLAModuleState::CmpI8),
vm::MakeNativeFunction("cmp.i16", &VMLAModuleState::CmpI16),
@@ -653,9 +678,12 @@
vm::MakeNativeFunction("select.x16", &VMLAModuleState::SelectX16),
vm::MakeNativeFunction("select.x32", &VMLAModuleState::SelectX32),
- vm::MakeNativeFunction("reverse.x8", &VMLAModuleState::ReverseX8),
- vm::MakeNativeFunction("reverse.x16", &VMLAModuleState::ReverseX16),
- vm::MakeNativeFunction("reverse.x32", &VMLAModuleState::ReverseX32),
+ vm::MakeNativeFunction("copy.x8", &VMLAModuleState::CopyX8),
+ vm::MakeNativeFunction("copy.x16", &VMLAModuleState::CopyX16),
+ vm::MakeNativeFunction("copy.x32", &VMLAModuleState::CopyX32),
+ vm::MakeNativeFunction("transpose.x8", &VMLAModuleState::TransposeX8),
+ vm::MakeNativeFunction("transpose.x16", &VMLAModuleState::TransposeX16),
+ vm::MakeNativeFunction("transpose.x32", &VMLAModuleState::TransposeX32),
vm::MakeNativeFunction("reverse.x8", &VMLAModuleState::ReverseX8),
vm::MakeNativeFunction("reverse.x16", &VMLAModuleState::ReverseX16),
vm::MakeNativeFunction("reverse.x32", &VMLAModuleState::ReverseX32),