Lowering HLO concat to VMLA copies.
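
Adds a ConcatenateOpConversion pattern that lowers xla_hlo.concatenate
into a sequence of vmla.copy ops writing into a single allocated
destination buffer: each operand is copied whole, and the destination
offset along the concatenation dimension is advanced by that operand's
extent before the next copy. Also adds lowering tests covering two- and
three-operand concatenation along both dimensions of rank-2 tensors.

As a rough sketch (hypothetical function and value names; the new
concatenate.mlir test shows the exact expected IR), an input like

  func @concat_example(%a : tensor<2x2xi32>, %b : tensor<2x3xi32>)
      -> tensor<2x5xi32> {
    %0 = "xla_hlo.concatenate"(%a, %b) {dimension = 1}
        : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
    return %0 : tensor<2x5xi32>
  }

lowers to (roughly) a vmla.buffer.alloc of the 2x5xi32 result and one
vmla.copy per operand, with the second copy writing at column offset 2
of the destination.
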
PiperOrigin-RevId: 294364292
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 37e0eb3..e03e99f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -102,6 +102,60 @@
TypeConverter &typeConverter;
};
+// Converts a concat into a set of copies into the destination buffer.
+struct ConcatenateOpConversion
+    : public OpConversionPattern<xla_hlo::ConcatenateOp> {
+  ConcatenateOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::ConcatenateOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto indexType = rewriter.getIntegerType(32);
+    auto zero = rewriter.createOrFold<mlir::ConstantOp>(
+        srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+
+    auto finalType = srcOp.getResult().getType().cast<TensorType>();
+    int rank = finalType.getRank();
+    llvm::SmallVector<Value, 4> srcIndices(rank, zero);
+    llvm::SmallVector<Value, 4> dstIndices(rank, zero);
+    auto concatDimension = srcOp.dimension().getZExtValue();
+    for (auto srcDstOperand : llvm::zip(srcOp.val(), operands)) {
+      Value tensorOperand, bufferOperand;
+      std::tie(tensorOperand, bufferOperand) = srcDstOperand;
+
+      auto srcShape = VMLAConversionTarget::getTensorShape(
+          srcOp.getLoc(), tensorOperand, typeConverter, rewriter);
+      SmallVector<Value, 4> lengths(rank);
+      for (int i = 0; i < rank; ++i) {
+        lengths[i] = rewriter.createOrFold<Shape::RankedDimOp>(srcOp.getLoc(),
+                                                               srcShape, i);
+      }
+
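+      // Copy the whole operand into the destination at the current offsets.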
+      rewriter.create<IREE::VMLA::CopyOp>(
+          srcOp.getLoc(), bufferOperand, srcShape, srcIndices, dst, dstShape,
+          dstIndices, lengths,
+          TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
+
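+      // Step the destination offset along the concat dim past this operand.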
+      dstIndices[concatDimension] = rewriter.createOrFold<mlir::AddIOp>(
+          srcOp.getLoc(), dstIndices[concatDimension],
+          lengths[concatDimension]);
+    }
+
+    rewriter.replaceOp(srcOp, {dst});
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
// Converts a static slice op to a copy (if the source must be preserved).
struct SliceOpConversion : public OpConversionPattern<xla_hlo::SliceOp> {
SliceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
@@ -302,6 +354,7 @@
// Conversions that don't have a 1:1 mapping, mostly involving buffer views
// or transfers.
patterns.insert<BroadcastInDimOpConversion>(context, typeConverter);
+ patterns.insert<ConcatenateOpConversion>(context, typeConverter);
patterns.insert<SliceOpConversion>(context, typeConverter);
patterns.insert<DynamicSliceOpConversion>(context, typeConverter);
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir
new file mode 100644
index 0000000..6d386cf
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir
@@ -0,0 +1,115 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @concatenate_0
+func @concatenate_0() -> (tensor<2x5xi32>) {
+ // CHECK-DAG: [[ARG0:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+ %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // CHECK-DAG: [[ARG1:%.+]] = "vmla.constant"() {{.+}} tensor<2x3xi32>
+ %c1 = constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
+ // CHECK: [[DST:%.+]] = "vmla.buffer.alloc"(%c40_i32)
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[ARG0_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG0]], [[ARG0_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c2_i32, %c2_i32
+ // CHECK-SAME: ) {element_type = i32}
+ // CHECK-NEXT: [[ARG1_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG1]], [[ARG1_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c2_i32,
+ // CHECK-SAME: %c2_i32, %c3_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %0 = "xla_hlo.concatenate"(%c0, %c1) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %0: tensor<2x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concatenate_1
+func @concatenate_1() -> (tensor<2x5xi32>) {
+ // CHECK-DAG: [[ARG0:%.+]] = "vmla.constant"() {{.+}} tensor<2x3xi32>
+ %c1 = constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
+ // CHECK-DAG: [[ARG1:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+ %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // CHECK: [[DST:%.+]] = "vmla.buffer.alloc"(%c40_i32)
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[ARG0_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG0]], [[ARG0_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c2_i32, %c3_i32
+ // CHECK-SAME: ) {element_type = i32}
+ // CHECK-NEXT: [[ARG1_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG1]], [[ARG1_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c3_i32,
+ // CHECK-SAME: %c2_i32, %c2_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %0 = "xla_hlo.concatenate"(%c1, %c0) {dimension = 1} : (tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x5xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %0: tensor<2x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concatenate_2
+func @concatenate_2() -> (tensor<2x7xi32>) {
+ // CHECK-DAG: [[ARG0:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+ %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // CHECK-DAG: [[ARG1:%.+]] = "vmla.constant"() {{.+}} tensor<2x3xi32>
+ %c1 = constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
+ // CHECK-DAG: [[ARG2:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+ %c2 = constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
+ // CHECK: [[DST:%.+]] = "vmla.buffer.alloc"(%c56_i32)
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[ARG0_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG0]], [[ARG0_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c2_i32, %c2_i32
+ // CHECK-SAME: ) {element_type = i32}
+ // CHECK-NEXT: [[ARG1_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG1]], [[ARG1_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c2_i32,
+ // CHECK-SAME: %c2_i32, %c3_i32
+ // CHECK-SAME: ) {element_type = i32}
+ // CHECK-NEXT: [[ARG2_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG2]], [[ARG2_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c5_i32,
+ // CHECK-SAME: %c2_i32, %c2_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %0 = "xla_hlo.concatenate"(%c0, %c1, %c2) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x7xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %0: tensor<2x7xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concatenate_3
+func @concatenate_3() -> (tensor<4x2xi32>) {
+ // CHECK-DAG: [[ARG0:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+ %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // CHECK-DAG: [[ARG1:%.+]] = "vmla.constant"() {{.+}} tensor<2x2xi32>
+ %c2 = constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
+ // CHECK: [[DST:%.+]] = "vmla.buffer.alloc"(%c32_i32)
+ // CHECK-NEXT: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: [[ARG0_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG0]], [[ARG0_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: %c2_i32, %c2_i32
+ // CHECK-SAME: ) {element_type = i32}
+ // CHECK-NEXT: [[ARG1_SHAPE:%.+]] = shapex.const_ranked_shape
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[ARG1]], [[ARG1_SHAPE]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c2_i32, %c0_i32,
+ // CHECK-SAME: %c2_i32, %c2_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %0 = "xla_hlo.concatenate"(%c0, %c2) {dimension = 0} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<4x2xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %0: tensor<4x2xi32>
+}