Convert xla_hlo.broadcast_in_dim to VMLA (vmla.broadcast for scalar inputs, vmla.tile otherwise).
PiperOrigin-RevId: 294360884
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 1b19138..2843a87 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -65,6 +65,61 @@
   }
 };
 
+// Converts a broadcast_in_dim op to either a broadcast or a tile depending on
+// the input shape.
+//
+// vmla.tile can only replicate the source across the leading result
+// dimensions (it cannot permute them), so the tile path is only taken when
+// broadcast_dimensions identity-maps the operand dims onto the innermost
+// result dims; other mappings are left unconverted.
+struct BroadcastInDimOpConversion
+    : public OpConversionPattern<xla_hlo::BroadcastInDimOp> {
+  BroadcastInDimOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::BroadcastInDimOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto tensorType = srcOp.operand().getType().cast<TensorType>();
+    if (tensorType.getRank() > 0) {
+      // Guard the tile lowering: require broadcast_dimensions == [dstRank -
+      // srcRank, ..., dstRank - 1] (e.g. [2x4] -> [3x2x4] with dims [1, 2]).
+      // Anything else would need a transpose and would miscompile as a tile.
+      auto dstRank = srcOp.getResult().getType().cast<ShapedType>().getRank();
+      auto broadcastDims = srcOp.broadcast_dimensions();
+      if (!broadcastDims) return matchFailure();
+      int64_t expectedDim = dstRank - tensorType.getRank();
+      for (auto dim : broadcastDims.getValues<int64_t>()) {
+        if (dim != expectedDim++) return matchFailure();
+      }
+    }
+
+    auto srcShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+
+    if (tensorType.getRank() == 0) {
+      // Broadcast of a scalar value into every element of the result.
+      rewriter.create<IREE::VMLA::BroadcastOp>(
+          srcOp.getLoc(), operands[0], srcShape, dst, dstShape,
+          TypeAttr::get(tensorType.getElementType()));
+    } else {
+      // Tiling a non-scalar value along the leading result dimensions.
+      rewriter.create<IREE::VMLA::TileOp>(
+          srcOp.getLoc(), operands[0], srcShape, dst, dstShape,
+          TypeAttr::get(tensorType.getElementType()));
+    }
+
+    rewriter.replaceOp(srcOp, {dst});
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 }  // namespace
 
 void populateHLOToVMLAPatterns(MLIRContext *context,
@@ -155,6 +210,10 @@
   patterns.insert<IdentityOpConversion<xla_hlo::BitcastConvertOp>>(context);
   patterns.insert<IdentityOpConversion<xla_hlo::ReshapeOp>>(context);
 
+  // Conversions that don't have a 1:1 mapping, mostly involving buffer views
+  // or transfers.
+  patterns.insert<BroadcastInDimOpConversion>(context, typeConverter);
+
   // TODO(benvanik): add missing ops:
   //   - ConvOp
 }
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
new file mode 100644
index 0000000..fe5271c
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @broadcast_in_dim_2D_3D
+func @broadcast_in_dim_2D_3D() -> tensor<3x2x4xi32> {
+  // CHECK-DAG: [[SRC:%.+]] = "vmla.constant"
+  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[2,4],i32>
+  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[3,2,4],i32>
+  // CHECK-DAG: [[DST_SIZE:%.+]] = constant 96 : i32
+  %input = constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi32>
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"([[DST_SIZE]])
+  // CHECK-NEXT: "vmla.tile"([[SRC]], [[SRC_SHAPE]], [[DST]], [[DST_SHAPE]]) {element_type = i32}
+  %0 = "xla_hlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
+  // CHECK-NEXT: return [[DST]] : !iree.ref<!vmla.buffer>
+  return %0 : tensor<3x2x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_in_dim_3D_scalar
+func @broadcast_in_dim_3D_scalar() -> tensor<3x2x4xi32> {
+  // CHECK-DAG: [[SRC:%.+]] = "vmla.constant"
+  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[],i32>
+  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[3,2,4],i32>
+  // CHECK-DAG: [[DST_SIZE:%.+]] = constant 96 : i32
+  %input = constant dense<42> : tensor<i32>
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"([[DST_SIZE]])
+  // CHECK-NEXT: "vmla.broadcast"([[SRC]], [[SRC_SHAPE]], [[DST]], [[DST_SHAPE]]) {element_type = i32}
+  %0 = "xla_hlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<3x2x4xi32>
+  // CHECK-NEXT: return [[DST]] : !iree.ref<!vmla.buffer>
+  return %0 : tensor<3x2x4xi32>
+}