Lowering HLO slice and dynamic-slice ops to VMLA copies.
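
For reference, a rough sketch of what the new patterns produce, mirroring the
@slice_whole_stride test added below (%src comes from vmla.constant and the
shapes from shapex.const_ranked_shape; result types elided):

  %result = "xla_hlo.slice"(%input) {
    start_indices = dense<[1, 0]> : tensor<2xi64>,
    limit_indices = dense<[2, 4]> : tensor<2xi64>,
    strides = dense<1> : tensor<2xi64>
  } : (tensor<3x4xi32>) -> tensor<1x4xi32>

lowers to a copy out of the source buffer into a newly allocated destination:

  %dst = "vmla.buffer.alloc"(%c16_i32)
  "vmla.copy"(%src, %src_shape, %c1_i32, %c0_i32,
              %dst, %dst_shape, %c0_i32, %c0_i32,
              %c1_i32, %c4_i32) {element_type = i32}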

PiperOrigin-RevId: 294363411
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 2843a87..37e0eb3 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -102,6 +102,113 @@
   TypeConverter &typeConverter;
 };
 
+// Converts a static slice op to a copy (if the source must be preserved).
+struct SliceOpConversion : public OpConversionPattern<xla_hlo::SliceOp> {
+  SliceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::SliceOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto isNotOne = [](APInt stride) { return stride != 1; };
+    if (llvm::any_of(srcOp.strides(), isNotOne)) {
+      srcOp.emitWarning()
+          << "could not lower slice op with non-unit strides";
+      return matchFailure();
+    }
+
+    // TODO(benvanik): if the source is only used by this op then replace with
+    // a vmla.buffer.view op.
+
+    auto srcShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+
+    int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
+    auto indexType = rewriter.getIntegerType(32);
+    SmallVector<Value, 4> srcIndices(rank);
+    SmallVector<Value, 4> dstIndices(rank);
+    SmallVector<Value, 4> lengths(rank);
+    Value zero = rewriter.createOrFold<mlir::ConstantOp>(
+        srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
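+    // Per dimension: read from start_indices[i] in the source, write at
+    // offset 0 in the destination, and copy limit_indices[i] -
+    // start_indices[i] elements (strides are all 1 here).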
+    for (int i = 0; i < rank; ++i) {
+      uint64_t ui = static_cast<uint64_t>(i);
+      srcIndices[i] = rewriter.createOrFold<mlir::ConstantOp>(
+          srcOp.getLoc(), indexType,
+          rewriter.getI32IntegerAttr(
+              srcOp.start_indices().getValue<int64_t>({ui})));
+      dstIndices[i] = zero;
+      lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
+          srcOp.getLoc(), indexType,
+          rewriter.getI32IntegerAttr(
+              srcOp.limit_indices().getValue<int64_t>({ui}) -
+              srcOp.start_indices().getValue<int64_t>({ui})));
+    }
+
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    rewriter.create<IREE::VMLA::CopyOp>(
+        srcOp.getLoc(), operands[0], srcShape, srcIndices, dst, dstShape,
+        dstIndices, lengths,
+        TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
+    rewriter.replaceOp(srcOp, {dst});
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
+// Converts a dynamic slice op to a copy (if the source must be preserved).
+struct DynamicSliceOpConversion
+    : public OpConversionPattern<xla_hlo::DynamicSliceOp> {
+  DynamicSliceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::DynamicSliceOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // TODO(benvanik): if the source is only used by this op then replace with
+    // a vmla.buffer.view op.
+
+    auto srcShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.result(), typeConverter, rewriter);
+
+    int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
+    auto indexType = rewriter.getIntegerType(32);
+    SmallVector<Value, 4> srcIndices(rank);
+    SmallVector<Value, 4> dstIndices(rank);
+    SmallVector<Value, 4> lengths(rank);
+    Value zero = rewriter.createOrFold<mlir::ConstantOp>(
+        srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
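+    // The start indices are only known at runtime: load each one from the
+    // start_indices buffer as an i32 at byte offset i * sizeof(int32_t).
+    // Copy lengths come from the static slice_sizes attribute.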
+    for (int i = 0; i < rank; ++i) {
+      srcIndices[i] = rewriter.createOrFold<IREE::VMLA::BufferLoadI32Op>(
+          srcOp.getLoc(), rewriter.getIntegerType(32), operands[1],
+          rewriter.createOrFold<mlir::ConstantOp>(
+              srcOp.getLoc(), indexType,
+              rewriter.getI32IntegerAttr(i * sizeof(int32_t))));
+      dstIndices[i] = zero;
+      lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
+          srcOp.getLoc(), indexType,
+          rewriter.getI32IntegerAttr(srcOp.slice_sizes().getValue<int64_t>(
+              {static_cast<uint64_t>(i)})));
+    }
+
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    rewriter.create<IREE::VMLA::CopyOp>(
+        srcOp.getLoc(), operands[0], srcShape, srcIndices, dst, dstShape,
+        dstIndices, lengths,
+        TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
+    rewriter.replaceOp(srcOp, {dst});
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 }  // namespace
 
 void populateHLOToVMLAPatterns(MLIRContext *context,
@@ -195,6 +302,8 @@
   // Conversions that don't have a 1:1 mapping, mostly involving buffer views
   // or transfers.
   patterns.insert<BroadcastInDimOpConversion>(context, typeConverter);
+  patterns.insert<SliceOpConversion>(context, typeConverter);
+  patterns.insert<DynamicSliceOpConversion>(context, typeConverter);
 
   // TODO(benvanik): add missing ops:
   // - ConvOp
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir
new file mode 100644
index 0000000..c9e32bc
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir
@@ -0,0 +1,111 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @slice_whole_buffer
+// CHECK-SAME: [[SRC_INDICES:%.+]]: !iree.ref<!vmla.buffer>
+func @slice_whole_buffer(%src_indices : tensor<2xi64>) -> tensor<3x4xi32> {
+  // CHECK: [[SRC:%.+]] = "vmla.constant"()
+  %input = constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]
+  ]> : tensor<3x4xi32>
+  // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[SRC_INDEX_0:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c0_i32)
+  // CHECK-NEXT: [[SRC_INDEX_1:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c4_i32)
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c48_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[SRC_INDEX_0]], [[SRC_INDEX_1]],
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c3_i32, %c4_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %result = "xla_hlo.dynamic-slice"(%input, %src_indices) {
+    slice_sizes = dense<[3, 4]> : tensor<2xi64>
+  } : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<3x4xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %result : tensor<3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_whole_stride
+// CHECK-SAME: [[SRC_INDICES:%.+]]: !iree.ref<!vmla.buffer>
+func @slice_whole_stride(%src_indices : tensor<2xi64>) -> tensor<1x4xi32> {
+  // CHECK: [[SRC:%.+]] = "vmla.constant"()
+  %input = constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]
+  ]> : tensor<3x4xi32>
+  // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[SRC_INDEX_0:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c0_i32)
+  // CHECK-NEXT: [[SRC_INDEX_1:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c4_i32)
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c16_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[SRC_INDEX_0]], [[SRC_INDEX_1]],
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c1_i32, %c4_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %result = "xla_hlo.dynamic-slice"(%input, %src_indices) {
+    slice_sizes = dense<[1, 4]> : tensor<2xi64>
+  } : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %result : tensor<1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_stride_part
+// CHECK-SAME: [[SRC_INDICES:%.+]]: !iree.ref<!vmla.buffer>
+func @slice_stride_part(%src_indices : tensor<2xi64>) -> tensor<1x2xi32> {
+  // CHECK: [[SRC:%.+]] = "vmla.constant"()
+  %input = constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]
+  ]> : tensor<3x4xi32>
+  // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[SRC_INDEX_0:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c0_i32)
+  // CHECK-NEXT: [[SRC_INDEX_1:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c4_i32)
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c8_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[SRC_INDEX_0]], [[SRC_INDEX_1]],
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c1_i32, %c2_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %result = "xla_hlo.dynamic-slice"(%input, %src_indices) {
+    slice_sizes = dense<[1, 2]> : tensor<2xi64>
+  } : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x2xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %result : tensor<1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_multi_stride
+// CHECK-SAME: [[SRC_INDICES:%.+]]: !iree.ref<!vmla.buffer>
+func @slice_multi_stride(%src_indices : tensor<2xi64>) -> tensor<2x4xi32> {
+  // CHECK: [[SRC:%.+]] = "vmla.constant"()
+  %input = constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]
+  ]> : tensor<3x4xi32>
+  // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[SRC_INDEX_0:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c0_i32)
+  // CHECK-NEXT: [[SRC_INDEX_1:%.+]] = "vmla.buffer.load.i32"([[SRC_INDICES]], %c4_i32)
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c32_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[SRC_INDEX_0]], [[SRC_INDEX_1]],
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c2_i32, %c4_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %result = "xla_hlo.dynamic-slice"(%input, %src_indices) {
+    slice_sizes = dense<[2, 4]> : tensor<2xi64>
+  } : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<2x4xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %result : tensor<2x4xi32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir
new file mode 100644
index 0000000..a5becc7
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir
@@ -0,0 +1,107 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @slice_whole_buffer
+func @slice_whole_buffer() -> tensor<3x4xi32> {
+  // CHECK: [[SRC:%.+]] = "vmla.constant"()
+  %input = constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]
+  ]> : tensor<3x4xi32>
+  // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c48_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c3_i32, %c4_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %result = "xla_hlo.slice"(%input) {
+    start_indices = dense<[0, 0]> : tensor<2xi64>,
+    limit_indices = dense<[3, 4]> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } : (tensor<3x4xi32>) -> tensor<3x4xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %result : tensor<3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_whole_stride
+func @slice_whole_stride() -> tensor<1x4xi32> {
+  // CHECK: [[SRC:%.+]] = "vmla.constant"()
+  %input = constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]
+  ]> : tensor<3x4xi32>
+  // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c16_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], %c1_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c1_i32, %c4_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %result = "xla_hlo.slice"(%input) {
+    start_indices = dense<[1, 0]> : tensor<2xi64>,
+    limit_indices = dense<[2, 4]> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } : (tensor<3x4xi32>) -> tensor<1x4xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %result : tensor<1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_stride_part
+func @slice_stride_part() -> tensor<1x2xi32> {
+  // CHECK: [[SRC:%.+]] = "vmla.constant"()
+  %input = constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]
+  ]> : tensor<3x4xi32>
+  // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c8_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], %c1_i32, %c1_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c1_i32, %c2_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %result = "xla_hlo.slice"(%input) {
+    start_indices = dense<[1, 1]> : tensor<2xi64>,
+    limit_indices = dense<[2, 3]> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } : (tensor<3x4xi32>) -> tensor<1x2xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %result : tensor<1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_multi_stride
+func @slice_multi_stride() -> tensor<2x4xi32> {
+  // CHECK: [[SRC:%.+]] = "vmla.constant"()
+  %input = constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]
+  ]> : tensor<3x4xi32>
+  // CHECK-NEXT: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"(%c32_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], %c1_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c2_i32, %c4_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %result = "xla_hlo.slice"(%input) {
+    start_indices = dense<[1, 0]> : tensor<2xi64>,
+    limit_indices = dense<[3, 4]> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } : (tensor<3x4xi32>) -> tensor<2x4xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %result : tensor<2x4xi32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 2bbeb9c..266acf6 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -154,10 +154,12 @@
   VMLA_IMPORT_OP(IREE::VMLA::BufferViewOp, "vmla.buffer.view");
   VMLA_IMPORT_OP(IREE::VMLA::BufferCopyOp, "vmla.buffer.copy");
   VMLA_IMPORT_OP(IREE::VMLA::BufferFillOp, "vmla.buffer.fill");
+  VMLA_IMPORT_OP(IREE::VMLA::BufferLoadI32Op, "vmla.buffer.load.i32");
 
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::CmpOp, "vmla.cmp");
   VMLA_SIZED_IMPORT_OP(IREE::VMLA::SelectOp, "vmla.select");
 
+  VMLA_SIZED_IMPORT_OP(IREE::VMLA::CopyOp, "vmla.copy");
   VMLA_SIZED_IMPORT_OP(IREE::VMLA::TransposeOp, "vmla.transpose");
   VMLA_SIZED_IMPORT_OP(IREE::VMLA::ReverseOp, "vmla.reverse");
   VMLA_SIZED_IMPORT_OP(IREE::VMLA::PadOp, "vmla.pad");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLABase.td b/iree/compiler/Dialect/VMLA/IR/VMLABase.td
index bf325ae..74b84ea 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLABase.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLABase.td
@@ -72,6 +72,8 @@
 def VMLA_HostSize : TypeAlias<I32>;
 def VMLA_HostSizeAttr : IntegerAttrBase<I32, "size_t">;
 
+def VMLA_Index : TypeAlias<I32>;
+
 def VMLA_Shape : TypeAlias<Shape_RankedShape>;
 
 def VMLA_HostBufferRef : AnyTypeOf<[
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index cfaf98c..183a995 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -103,6 +103,16 @@
   );
 }
 
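+// Loads a single i32 value from |src| at |byte_offset| bytes into the buffer.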
+def VMLA_BufferLoadI32Op : VMLA_PureOp<"buffer.load.i32"> {
+  let arguments = (ins
+    VMLA_BufferRef:$src,
+    VMLA_DeviceSize:$byte_offset
+  );
+  let results = (outs
+    I32:$result
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // VMLA Ops: comparison
 //===----------------------------------------------------------------------===//
@@ -131,6 +141,22 @@
 // VMLA Ops: shape/structure
 //===----------------------------------------------------------------------===//
 
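+// Copies a hyperrectangular region out of |src| into |dst|. Indices and
+// lengths are per-dimension and measured in elements of |element_type|.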
+def VMLA_CopyOp : VMLA_ElementTypeOp<"copy", [
+    VMLA_IncludeShapes,
+    SameVariadicOperandSize,
+  ]> {
+  let arguments = (ins
+    VMLA_BufferRef:$src,
+    VMLA_Shape:$src_shape,
+    Variadic<VMLA_Index>:$src_indices,
+    VMLA_BufferRef:$dst,
+    VMLA_Shape:$dst_shape,
+    Variadic<VMLA_Index>:$dst_indices,
+    Variadic<VMLA_Index>:$lengths,
+    VMLA_AnyTypeAttr:$element_type
+  );
+}
+
 def VMLA_TransposeOp : VMLA_ElementTypeOp<"transpose", [VMLA_IncludeShapes]> {
   let arguments = (ins
     VMLA_BufferRef:$src,
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index d699ff5..cd06d1d 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -50,6 +50,12 @@
   %dst : !iree.ref<!vmla.buffer>
 )
 
+vm.import @buffer.load.i32(
+  %src : !iree.ref<!vmla.buffer>,
+  %byte_offset : i32
+) -> i32
+attributes {nosideeffects}
+
 vm.import @cmp.i8(%predicate : i32, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
 vm.import @cmp.i16(%predicate : i32, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
 vm.import @cmp.i32(%predicate : i32, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
@@ -59,6 +65,24 @@
 vm.import @select.x16(%cond : !iree.ref<!vmla.buffer>, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
 vm.import @select.x32(%cond : !iree.ref<!vmla.buffer>, %lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
 
+// TODO(benvanik): do the copies with buffer.copy instead and leave the offset
+// calculations in the IR for the compiler to simplify.
+vm.import @copy.x8(
+  %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
+  %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
+  %lengths : i32 ...
+)
+vm.import @copy.x16(
+  %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
+  %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
+  %lengths : i32 ...
+)
+vm.import @copy.x32(
+  %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
+  %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
+  %lengths : i32 ...
+)
+
 vm.import @transpose.x8(
   %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ...,
   %dimensions : i32 ...,
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index b2deeda..04d2126 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -304,6 +304,15 @@
     return OkStatus();
   }
 
+  StatusOr<int32_t> BufferLoadI32(vm::ref<iree_vmla_buffer_t>& src,
+                                  iree_vmla_size_t byte_offset) {
+    IREE_TRACE_SCOPE0("VMLAModuleState::BufferLoadI32");
+    IREE_RETURN_IF_NULL(src);
+    ASSIGN_OR_RETURN(auto data,
+                     src->RangeAs<int32_t>(byte_offset, sizeof(int32_t)));
+    return data[0];
+  }
+
   //===--------------------------------------------------------------------===//
   // Common helpers for defining ops
   //===--------------------------------------------------------------------===//
@@ -397,6 +406,21 @@
   // VMLA Ops: shape/structure
   //===--------------------------------------------------------------------===//
 
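+// Stamps out a copy entry point for a fixed element byte width; the element
+// type itself is erased and only its size is forwarded to kernels::Copy.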
+#define IREE_VMLA_COPY_OP(name, size)                                          \
+  Status name(vm::ref<iree_vmla_buffer_t>& src, iree_vmla_shape_t src_shape,   \
+              absl::Span<const int32_t> src_indices,                           \
+              vm::ref<iree_vmla_buffer_t>& dst, iree_vmla_shape_t dst_shape,   \
+              absl::Span<const int32_t> dst_indices,                           \
+              absl::Span<const int32_t> lengths) {                             \
+    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                              \
+    return kernels::Copy::Execute<size>(                                       \
+        src->As<uint8_t>(), Shape(src_shape), src_indices, dst->As<uint8_t>(), \
+        Shape(dst_shape), dst_indices, lengths);                               \
+  }
+  IREE_VMLA_COPY_OP(CopyX8, sizeof(uint8_t));
+  IREE_VMLA_COPY_OP(CopyX16, sizeof(uint16_t));
+  IREE_VMLA_COPY_OP(CopyX32, sizeof(uint32_t));
+
 #define IREE_VMLA_TRANSPOSE_OP(name, type)                                     \
   Status name(vm::ref<iree_vmla_buffer_t>& src, iree_vmla_shape_t src_shape,   \
               absl::Span<const int32_t> dims,                                  \
@@ -644,6 +668,7 @@
     vm::MakeNativeFunction("buffer.view", &VMLAModuleState::BufferView),
     vm::MakeNativeFunction("buffer.copy", &VMLAModuleState::BufferCopy),
     vm::MakeNativeFunction("buffer.fill", &VMLAModuleState::BufferFill),
+    vm::MakeNativeFunction("buffer.load.i32", &VMLAModuleState::BufferLoadI32),
 
     vm::MakeNativeFunction("cmp.i8", &VMLAModuleState::CmpI8),
     vm::MakeNativeFunction("cmp.i16", &VMLAModuleState::CmpI16),
@@ -653,9 +678,12 @@
     vm::MakeNativeFunction("select.x16", &VMLAModuleState::SelectX16),
     vm::MakeNativeFunction("select.x32", &VMLAModuleState::SelectX32),
 
-    vm::MakeNativeFunction("reverse.x8", &VMLAModuleState::ReverseX8),
-    vm::MakeNativeFunction("reverse.x16", &VMLAModuleState::ReverseX16),
-    vm::MakeNativeFunction("reverse.x32", &VMLAModuleState::ReverseX32),
+    vm::MakeNativeFunction("copy.x8", &VMLAModuleState::CopyX8),
+    vm::MakeNativeFunction("copy.x16", &VMLAModuleState::CopyX16),
+    vm::MakeNativeFunction("copy.x32", &VMLAModuleState::CopyX32),
+    vm::MakeNativeFunction("transpose.x8", &VMLAModuleState::TransposeX8),
+    vm::MakeNativeFunction("transpose.x16", &VMLAModuleState::TransposeX16),
+    vm::MakeNativeFunction("transpose.x32", &VMLAModuleState::TransposeX32),
     vm::MakeNativeFunction("reverse.x8", &VMLAModuleState::ReverseX8),
     vm::MakeNativeFunction("reverse.x16", &VMLAModuleState::ReverseX16),
     vm::MakeNativeFunction("reverse.x32", &VMLAModuleState::ReverseX32),