iree-shape-materialize-calculations: add xla_hlo.dot_general
PiperOrigin-RevId: 303202635
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
index 8a63bb5..df9d94d 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
@@ -17,7 +17,9 @@
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Optional.h"
+#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
@@ -196,6 +198,85 @@
return builder.create<MakeRankedShapeOp>(loc, resultShape, dynamicDims);
}
+// Returns a `!shapex.ranked_shape` value for tensor `v`, or nullptr if `v` is not a RankedTensorType.
+static Value getRankedShapeAsValue(Value v, OpBuilder &builder, Location loc) {
+ assert(v.getType().isa<TensorType>());  // Precondition: `v` must be tensor-typed.
+ auto type = v.getType().dyn_cast<RankedTensorType>();
+ if (!type) {
+ return nullptr;  // Not a ranked tensor: callers must check for a null Value.
+ }
+ return builder.create<GetRankedShapeOp>(  // Creates a GetRankedShapeOp on `v` at `loc`.
+ loc, RankedShapeType::get(type.getShape(), builder.getContext()), v);  // Result type is built from the tensor's shape.
+}
+
+// Returns a value for the extent of dimension `dim` of `v` by creating a RankedDimOp at `loc`.
+static Value getExtent(Value v, int64_t dim, OpBuilder &builder, Location loc) {
+ return builder.create<RankedDimOp>(loc, v, dim);  // Callers in this file pass a ranked-shape value as `v`.
+}
+// Shape builder for DotGeneralOp; dynamic result extents are ordered: lhs batch, lhs free, rhs free dims.
+Value rewriteDotGeneral(RankedShapeType resultShape, DotGeneralOp op,
+ OpBuilder &builder) {
+ Location loc = op.getLoc();
+ auto lhsShape = getRankedShapeAsValue(op.lhs(), builder, loc);
+ auto rhsShape = getRankedShapeAsValue(op.rhs(), builder, loc);
+ if (!lhsShape || !rhsShape) {
+ return nullptr;  // Bail out if a ranked shape could not be obtained for either operand.
+ }
+ auto getFreeDims = [&](ArrayRef<int64_t> batchDims,  // Returns the dims in [0, rank) that are
+ ArrayRef<int64_t> contractingDims, int64_t rank) {  // neither batch nor contracting dims.
+ llvm::BitVector freeDims(rank, true);  // Start with every dim marked as free.
+ for (int64_t dim : batchDims) {
+ freeDims.reset(dim);
+ }
+ for (int64_t dim : contractingDims) {
+ freeDims.reset(dim);
+ }
+ SmallVector<int64_t, 4> result;
+ for (auto bitIndex : freeDims.set_bits()) {  // Collect the indices of the bits still set.
+ result.push_back(bitIndex);
+ }
+ return result;
+ };
+ auto lhsRankedShape = lhsShape.getType().cast<RankedShapeType>();
+ auto rhsRankedShape = rhsShape.getType().cast<RankedShapeType>();
+ auto dotDimensions = op.dot_dimension_numbers();
+ auto lhsFreeDims = getFreeDims(  // Lhs dims that are neither batching nor contracting.
+ llvm::to_vector<4>(
+ dotDimensions.lhs_batching_dimensions().getValues<int64_t>()),
+ llvm::to_vector<4>(
+ dotDimensions.lhs_contracting_dimensions().getValues<int64_t>()),
+ lhsRankedShape.getRank());
+ auto rhsFreeDims = getFreeDims(  // Rhs dims that are neither batching nor contracting.
+ llvm::to_vector<4>(
+ dotDimensions.rhs_batching_dimensions().getValues<int64_t>()),
+ llvm::to_vector<4>(
+ dotDimensions.rhs_contracting_dimensions().getValues<int64_t>()),
+ rhsRankedShape.getRank());
+ // Only dynamic dims contribute an extent operand below; static dims are skipped.
+ SmallVector<Value, 6> outputExtents;
+ for (int64_t dim :  // First: batch dims, with extents read from the lhs shape.
+ dotDimensions.lhs_batching_dimensions().getValues<int64_t>()) {
+ // TODO(silvasean): Add a version of MakeRankedShapeOp that takes
+ // all dimensions. Having to constantly check if a dim is dynamic
+ // upon construction is a waste of time, more testing burden, etc.
+ // We can easily canonicalize to the more constrained one.
+ if (lhsRankedShape.isDimDynamic(dim)) {
+ outputExtents.push_back(getExtent(lhsShape, dim, builder, loc));
+ }
+ }
+ for (int64_t dim : lhsFreeDims) {  // Second: free dims of the lhs.
+ if (lhsRankedShape.isDimDynamic(dim)) {
+ outputExtents.push_back(getExtent(lhsShape, dim, builder, loc));
+ }
+ }
+ for (int64_t dim : rhsFreeDims) {  // Third: free dims of the rhs.
+ if (rhsRankedShape.isDimDynamic(dim)) {
+ outputExtents.push_back(getExtent(rhsShape, dim, builder, loc));
+ }
+ }
+ return builder.create<MakeRankedShapeOp>(loc, resultShape, outputExtents);  // Assemble the result shape from the collected extents.
+}
+
} // namespace
// Creates a custom op shape builder for XLA-HLO ops that are not otherwise
@@ -224,6 +305,7 @@
rewriteShapexRankedBroadcastInDim);
b.insertOpRankedShapeBuilder<ReduceOp>(rewriteReduce);
b.insertOpRankedShapeBuilder<TransposeOp>(rewriteTranspose);
+ b.insertOpRankedShapeBuilder<xla_hlo::DotGeneralOp>(rewriteDotGeneral);
}
} // namespace xla_hlo
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir b/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
index ee146f4..3fbb25d 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
@@ -15,3 +15,21 @@
// CHECK: return %[[RESULT]], %[[SHAPE]]
return %0, %1 : tensor<7x?x10xf32>, !shapex.ranked_shape<[7,?,10]>
}
+
+// CHECK-LABEL: func @dot_general
+func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> !shapex.ranked_shape<[?,?,?]> {
+ // Expected extents: lhs batch dim 0, lhs free dim 1, rhs free dim 2; all are taken directly from args.
+ // CHECK-DAG: %[[EXTENT0:.+]] = dim %arg0, 0
+ // CHECK-DAG: %[[EXTENT1:.+]] = dim %arg0, 1
+ // CHECK-DAG: %[[EXTENT2:.+]] = dim %arg1, 2
+ // CHECK-DAG: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[EXTENT0]], %[[EXTENT1]], %[[EXTENT2]]
+ // CHECK-DAG: return %[[SHAPE]]
+ %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+ lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+ rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+ }} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %1 = shapex.get_ranked_shape %0 : tensor<?x?x?xf32> -> !shapex.ranked_shape<[?,?,?]>
+ return %1 : !shapex.ranked_shape<[?,?,?]>
+}