Convert xla_hlo.broadcast_in_dim to the VMLA dialect (vmla.broadcast for scalars, vmla.tile otherwise).

PiperOrigin-RevId: 294360884
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 1b19138..2843a87 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -65,6 +65,43 @@
   }
 };
 
+// Converts a broadcast_in_dim op to either a VMLA broadcast (rank-0 source)
+// or a VMLA tile (rank >= 1 source) depending on the input shape.
+//
+// NOTE(review): the op's broadcast_dimensions attribute is never inspected
+// here; this lowering presumably assumes the source dims map in order onto
+// the innermost result dims so that a plain tile reproduces the broadcast —
+// confirm behavior for permuted or non-trailing broadcast_dimensions.
+struct BroadcastInDimOpConversion
+    : public OpConversionPattern<xla_hlo::BroadcastInDimOp> {
+  BroadcastInDimOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::BroadcastInDimOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Materialize ranked-shape values for the source and result tensors and
+    // allocate the output buffer the VMLA op will write into.
+    auto srcShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+
+    // Rank decides the lowering: a rank-0 tensor is a splat of one element
+    // (vmla.broadcast); anything else repeats the whole source (vmla.tile).
+    auto tensorType = srcOp.operand().getType().cast<TensorType>();
+    if (tensorType.getRank() == 0) {
+      // Broadcast of a scalar value.
+      rewriter.create<IREE::VMLA::BroadcastOp>(
+          srcOp.getLoc(), operands[0], srcShape, dst, dstShape,
+          TypeAttr::get(tensorType.getElementType()));
+    } else {
+      // Tiling a non-scalar value.
+      rewriter.create<IREE::VMLA::TileOp>(
+          srcOp.getLoc(), operands[0], srcShape, dst, dstShape,
+          TypeAttr::get(tensorType.getElementType()));
+    }
+
+    // Replace the tensor result with the raw destination buffer.
+    rewriter.replaceOp(srcOp, {dst});
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 }  // namespace
 
 void populateHLOToVMLAPatterns(MLIRContext *context,
@@ -155,6 +192,10 @@
   patterns.insert<IdentityOpConversion<xla_hlo::BitcastConvertOp>>(context);
   patterns.insert<IdentityOpConversion<xla_hlo::ReshapeOp>>(context);
 
+  // Conversions that don't have a 1:1 mapping, mostly involving buffer views
+  // or transfers.
+  patterns.insert<BroadcastInDimOpConversion>(context, typeConverter);
+
   // TODO(benvanik): add missing ops:
   // - ConvOp
 }
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
new file mode 100644
index 0000000..fe5271c
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// Non-scalar (rank-2 -> rank-3) case: expects the conversion to emit
+// vmla.tile, repeating the whole 2x4 source along the new leading dim.
+// CHECK-LABEL: @broadcast_in_dim_2D_3D
+func @broadcast_in_dim_2D_3D() -> tensor<3x2x4xi32> {
+  // CHECK-DAG: [[SRC:%.+]] = "vmla.constant"
+  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[2,4],i32>
+  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[3,2,4],i32>
+  // CHECK-DAG: [[DST_SIZE:%.+]] = constant 96 : i32
+  %input = constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi32>
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"([[DST_SIZE]])
+  // CHECK-NEXT: "vmla.tile"([[SRC]], [[SRC_SHAPE]], [[DST]], [[DST_SHAPE]]) {element_type = i32}
+  %0 = "xla_hlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
+  // CHECK-NEXT: return [[DST]] : !iree.ref<!vmla.buffer>
+  return %0 : tensor<3x2x4xi32>
+}
+
+// -----
+
+// Scalar (rank-0 -> rank-3) case: expects the conversion to emit
+// vmla.broadcast, splatting the single element across the 3x2x4 result.
+// NOTE(review): no broadcast_dimensions attribute is supplied here —
+// presumably it is optional/empty for a rank-0 operand; verify against the
+// xla_hlo op definition.
+// CHECK-LABEL: @broadcast_in_dim_3D_scalar
+func @broadcast_in_dim_3D_scalar() -> tensor<3x2x4xi32> {
+  // CHECK-DAG: [[SRC:%.+]] = "vmla.constant"
+  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[],i32>
+  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[3,2,4],i32>
+  // CHECK-DAG: [[DST_SIZE:%.+]] = constant 96 : i32
+  %input = constant dense<42> : tensor<i32>
+  // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"([[DST_SIZE]])
+  // CHECK-NEXT: "vmla.broadcast"([[SRC]], [[SRC_SHAPE]], [[DST]], [[DST_SHAPE]]) {element_type = i32}
+  %0 = "xla_hlo.broadcast_in_dim"(%input) : (tensor<i32>) -> tensor<3x2x4xi32>
+  // CHECK-NEXT: return [[DST]] : !iree.ref<!vmla.buffer>
+  return %0 : tensor<3x2x4xi32>
+}