Lower the HLO ConcatenateOp to a sequence of VMLA copies into the destination buffer.

PiperOrigin-RevId: 294364292
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 37e0eb3..e03e99f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -102,6 +102,58 @@
   TypeConverter &typeConverter;
 };
 
+// Converts a concat into a set of copies into the destination buffer.
+struct ConcatenateOpConversion
+    : public OpConversionPattern<xla_hlo::ConcatenateOp> {
+  ConcatenateOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::ConcatenateOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto indexType = rewriter.getIntegerType(32);
+    auto zero = rewriter.createOrFold<mlir::ConstantOp>(
+        srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+
+    auto finalType = srcOp.getResult().getType().cast<TensorType>();
+    int rank = finalType.getRank();
+    llvm::SmallVector<Value, 4> srcIndices(rank, zero);
+    llvm::SmallVector<Value, 4> dstIndices(rank, zero);
+    auto concatDimension = srcOp.dimension().getZExtValue();
+    for (auto srcDstOperand : llvm::zip(srcOp.val(), operands)) {
+      Value tensorOperand, bufferOperand;
+      std::tie(tensorOperand, bufferOperand) = srcDstOperand;
+
+      auto srcShape = VMLAConversionTarget::getTensorShape(
+          srcOp.getLoc(), tensorOperand, typeConverter, rewriter);
+      SmallVector<Value, 4> lengths(rank);
+      for (int i = 0; i < rank; ++i) {
+        lengths[i] = rewriter.createOrFold<Shape::RankedDimOp>(srcOp.getLoc(),
+                                                               srcShape, i);
+      }
+
+      rewriter.create<IREE::VMLA::CopyOp>(
+          srcOp.getLoc(), bufferOperand, srcShape, srcIndices, dst, dstShape,
+          dstIndices, lengths,
+          TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
+
+      dstIndices[concatDimension] = rewriter.createOrFold<mlir::AddIOp>(
+          srcOp.getLoc(), dstIndices[concatDimension],
+          lengths[concatDimension]);
+    }
+
+    rewriter.replaceOp(srcOp, {dst});
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 // Converts a static slice op to a copy (if the source must be preserved).
 struct SliceOpConversion : public OpConversionPattern<xla_hlo::SliceOp> {
   SliceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
@@ -302,6 +354,7 @@
   // Conversions that don't have a 1:1 mapping, mostly involving buffer views
   // or transfers.
   patterns.insert<BroadcastInDimOpConversion>(context, typeConverter);
+  patterns.insert<ConcatenateOpConversion>(context, typeConverter);
   patterns.insert<SliceOpConversion>(context, typeConverter);
   patterns.insert<DynamicSliceOpConversion>(context, typeConverter);
 
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir
new file mode 100644
index 0000000..6d386cf
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir
@@ -0,0 +1,115 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @concatenate_0
+func @concatenate_0() -> (tensor<2x5xi32>) {
+  // CHECK-DAG: [[ARG0:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+  %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK-DAG: [[ARG1:%.+]] = "vmla.constant"() {{.+}} tensor<2x3xi32>
+  %c1 = constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
+  // CHECK: [[DST:%.+]] = "vmla.buffer.alloc"(%c40_i32)
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[ARG0_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG0]], [[ARG0_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c2_i32, %c2_i32
+  // CHECK-SAME: ) {element_type = i32}
+  // CHECK-NEXT: [[ARG1_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG1]], [[ARG1_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c2_i32,
+  // CHECK-SAME: %c2_i32, %c3_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %0 = "xla_hlo.concatenate"(%c0, %c1) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %0: tensor<2x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concatenate_1
+func @concatenate_1() -> (tensor<2x5xi32>) {
+  // CHECK-DAG: [[ARG0:%.+]] = "vmla.constant"() {{.+}} tensor<2x3xi32>
+  %c1 = constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
+  // CHECK-DAG: [[ARG1:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+  %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[DST:%.+]] = "vmla.buffer.alloc"(%c40_i32)
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[ARG0_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG0]], [[ARG0_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c2_i32, %c3_i32
+  // CHECK-SAME: ) {element_type = i32}
+  // CHECK-NEXT: [[ARG1_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG1]], [[ARG1_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c3_i32,
+  // CHECK-SAME: %c2_i32, %c2_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %0 = "xla_hlo.concatenate"(%c1, %c0) {dimension = 1} : (tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x5xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %0: tensor<2x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concatenate_2
+func @concatenate_2() -> (tensor<2x7xi32>) {
+  // CHECK-DAG: [[ARG0:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+  %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK-DAG: [[ARG1:%.+]] = "vmla.constant"() {{.+}} tensor<2x3xi32>
+  %c1 = constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
+  // CHECK-DAG: [[ARG2:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+  %c2 = constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
+  // CHECK: [[DST:%.+]] = "vmla.buffer.alloc"(%c56_i32)
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[ARG0_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG0]], [[ARG0_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c2_i32, %c2_i32
+  // CHECK-SAME: ) {element_type = i32}
+  // CHECK-NEXT: [[ARG1_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG1]], [[ARG1_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c2_i32,
+  // CHECK-SAME: %c2_i32, %c3_i32
+  // CHECK-SAME: ) {element_type = i32}
+  // CHECK-NEXT: [[ARG2_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG2]], [[ARG2_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c5_i32,
+  // CHECK-SAME: %c2_i32, %c2_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %0 = "xla_hlo.concatenate"(%c0, %c1, %c2) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x7xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %0: tensor<2x7xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concatenate_3
+func @concatenate_3() -> (tensor<4x2xi32>) {
+  // CHECK-DAG: [[ARG0:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+  %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK-DAG: [[ARG1:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+  %c2 = constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
+  // CHECK: [[DST:%.+]] = "vmla.buffer.alloc"(%c32_i32)
+  // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: [[ARG0_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG0]], [[ARG0_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: %c2_i32, %c2_i32
+  // CHECK-SAME: ) {element_type = i32}
+  // CHECK-NEXT: [[ARG1_SHAPE:%.+]] = shapex.const_ranked_shape
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[ARG1]], [[ARG1_SHAPE]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c2_i32, %c0_i32,
+  // CHECK-SAME: %c2_i32, %c2_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %0 = "xla_hlo.concatenate"(%c0, %c2) {dimension = 0} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<4x2xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %0: tensor<4x2xi32>
+}