iree-shape-materialize-calculations: add xla_hlo.dot_general

Adds a ranked shape transfer function for xla_hlo.dot_general. The result
shape is built from the batching dimensions, followed by the lhs free
dimensions, followed by the rhs free dimensions, taking each dynamic
extent from the corresponding operand shape.

PiperOrigin-RevId: 303202635
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
index 8a63bb5..df9d94d 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
@@ -17,7 +17,9 @@
 #include "iree/compiler/Dialect/Shape/IR/Builders.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/Optional.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Value.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 
@@ -196,6 +198,94 @@
   return builder.create<MakeRankedShapeOp>(loc, resultShape, dynamicDims);
 }
 
+// Returns a value of type `!shapex.ranked_shape` for the input value, or
+// nullptr if the value is not a ranked tensor.
+static Value getRankedShapeAsValue(Value v, OpBuilder &builder, Location loc) {
+  assert(v.getType().isa<TensorType>());
+  auto type = v.getType().dyn_cast<RankedTensorType>();
+  if (!type) {
+    return nullptr;
+  }
+  return builder.create<GetRankedShapeOp>(
+      loc, RankedShapeType::get(type.getShape(), builder.getContext()), v);
+}
+
+// Returns a value representing the extent of dimension `dim` of ranked shape `v`.
+static Value getExtent(Value v, int64_t dim, OpBuilder &builder, Location loc) {
+  return builder.create<RankedDimOp>(loc, v, dim);
+}
+
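+// Builds the result shape for xla_hlo.dot_general: the batching dimensions,
+// followed by the lhs free dimensions, followed by the rhs free dimensions.
+// "Free" dimensions are those that are neither batching nor contracting.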
+Value rewriteDotGeneral(RankedShapeType resultShape, DotGeneralOp op,
+                        OpBuilder &builder) {
+  Location loc = op.getLoc();
+  auto lhsShape = getRankedShapeAsValue(op.lhs(), builder, loc);
+  auto rhsShape = getRankedShapeAsValue(op.rhs(), builder, loc);
+  if (!lhsShape || !rhsShape) {
+    return nullptr;
+  }
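+  // Returns the dimensions of an operand that are neither batching nor
+  // contracting, i.e. the operand's free dimensions, in ascending order.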
+  auto getFreeDims = [&](ArrayRef<int64_t> batchDims,
+                         ArrayRef<int64_t> contractingDims, int64_t rank) {
+    llvm::BitVector freeDims(rank, true);
+    for (int64_t dim : batchDims) {
+      freeDims.reset(dim);
+    }
+    for (int64_t dim : contractingDims) {
+      freeDims.reset(dim);
+    }
+    SmallVector<int64_t, 4> result;
+    for (auto bitIndex : freeDims.set_bits()) {
+      result.push_back(bitIndex);
+    }
+    return result;
+  };
+  auto lhsRankedShape = lhsShape.getType().cast<RankedShapeType>();
+  auto rhsRankedShape = rhsShape.getType().cast<RankedShapeType>();
+  auto dotDimensions = op.dot_dimension_numbers();
+  auto lhsFreeDims = getFreeDims(
+      llvm::to_vector<4>(
+          dotDimensions.lhs_batching_dimensions().getValues<int64_t>()),
+      llvm::to_vector<4>(
+          dotDimensions.lhs_contracting_dimensions().getValues<int64_t>()),
+      lhsRankedShape.getRank());
+  auto rhsFreeDims = getFreeDims(
+      llvm::to_vector<4>(
+          dotDimensions.rhs_batching_dimensions().getValues<int64_t>()),
+      llvm::to_vector<4>(
+          dotDimensions.rhs_contracting_dimensions().getValues<int64_t>()),
+      rhsRankedShape.getRank());
+
+  SmallVector<Value, 6> outputExtents;
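+  // Collect extents for the dynamic result dimensions in result order:
+  // batch dims first (shared by both operands, so taken from the lhs),
+  // then lhs free dims, then rhs free dims.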
+  for (int64_t dim :
+       dotDimensions.lhs_batching_dimensions().getValues<int64_t>()) {
+    // TODO(silvasean): Add a version of MakeRankedShapeOp that takes
+    // all dimensions. Having to check whether each dim is dynamic at
+    // construction time is wasted effort, extra testing burden, etc.
+    // We can easily canonicalize to the more constrained form.
+    if (lhsRankedShape.isDimDynamic(dim)) {
+      outputExtents.push_back(getExtent(lhsShape, dim, builder, loc));
+    }
+  }
+  for (int64_t dim : lhsFreeDims) {
+    if (lhsRankedShape.isDimDynamic(dim)) {
+      outputExtents.push_back(getExtent(lhsShape, dim, builder, loc));
+    }
+  }
+  for (int64_t dim : rhsFreeDims) {
+    if (rhsRankedShape.isDimDynamic(dim)) {
+      outputExtents.push_back(getExtent(rhsShape, dim, builder, loc));
+    }
+  }
+  return builder.create<MakeRankedShapeOp>(loc, resultShape, outputExtents);
+}
+
 }  // namespace
 
 // Creates a custom op shape builder for XLA-HLO ops that are not otherwise
@@ -224,6 +314,7 @@
       rewriteShapexRankedBroadcastInDim);
   b.insertOpRankedShapeBuilder<ReduceOp>(rewriteReduce);
   b.insertOpRankedShapeBuilder<TransposeOp>(rewriteTranspose);
+  b.insertOpRankedShapeBuilder<DotGeneralOp>(rewriteDotGeneral);
 }
 
 }  // namespace xla_hlo
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir b/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
index ee146f4..3fbb25d 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
@@ -15,3 +15,39 @@
   // CHECK: return %[[RESULT]], %[[SHAPE]]
   return %0, %1 : tensor<7x?x10xf32>, !shapex.ranked_shape<[7,?,10]>
 }
+
+// CHECK-LABEL: func @dot_general
+func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> !shapex.ranked_shape<[?,?,?]> {
+  // Extents are taken directly from args.
+  // CHECK-DAG: %[[EXTENT0:.+]] = dim %arg0, 0
+  // CHECK-DAG: %[[EXTENT1:.+]] = dim %arg0, 1
+  // CHECK-DAG: %[[EXTENT2:.+]] = dim %arg1, 2
+  // CHECK-DAG: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[EXTENT0]], %[[EXTENT1]], %[[EXTENT2]]
+  // CHECK: return %[[SHAPE]]
+  %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = {
+    lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+    lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+    rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+    rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+  }} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %1 = shapex.get_ranked_shape %0 : tensor<?x?x?xf32> -> !shapex.ranked_shape<[?,?,?]>
+  return %1 : !shapex.ranked_shape<[?,?,?]>
+}
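+
+// Hypothetical additional case (a sketch): with a static batch dimension,
+// only the dynamic free dimensions should feed shapex.make_ranked_shape.
+// CHECK-LABEL: func @dot_general_partially_static
+func @dot_general_partially_static(%arg0: tensor<3x?x?xf32>, %arg1: tensor<3x?x?xf32>) -> !shapex.ranked_shape<[3,?,?]> {
+  // CHECK-DAG: %[[EXTENT1:.+]] = dim %arg0, 1
+  // CHECK-DAG: %[[EXTENT2:.+]] = dim %arg1, 2
+  // CHECK-DAG: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[EXTENT1]], %[[EXTENT2]]
+  // CHECK: return %[[SHAPE]]
+  %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = {
+    lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+    lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+    rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+    rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+  }} : (tensor<3x?x?xf32>, tensor<3x?x?xf32>) -> tensor<3x?x?xf32>
+  %1 = shapex.get_ranked_shape %0 : tensor<3x?x?xf32> -> !shapex.ranked_shape<[3,?,?]>
+  return %1 : !shapex.ranked_shape<[3,?,?]>
+}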