Add a utility to find or build a ranked shape given an arbitrary value.

(note that this is part of a larger chain, being broken up for submission and tests are included in subsequent changes)

PiperOrigin-RevId: 299399690
diff --git a/iree/compiler/Dialect/Shape/IR/Builders.cpp b/iree/compiler/Dialect/Shape/IR/Builders.cpp
index a91fed2..9ce8bab 100644
--- a/iree/compiler/Dialect/Shape/IR/Builders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Builders.cpp
@@ -16,11 +16,34 @@
 
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "mlir/IR/Diagnostics.h"
 
 namespace mlir {
 namespace iree_compiler {
 namespace Shape {
 
+namespace {
+
+Value getRankedShapeFromOp(Operation *op) {
+  auto tieOp = llvm::dyn_cast_or_null<TieShapeOp>(op);
+  if (!tieOp) return nullptr;
+  auto shape = tieOp.shape();
+  if (!shape.getType().isa<RankedShapeType>()) return nullptr;
+  return shape;
+}
+
+Value findRankedShapeFromUse(Value value) {
+  Value rs = getRankedShapeFromOp(value.getDefiningOp());
+  if (rs) return rs;
+  for (auto &use : value.getUses()) {
+    rs = getRankedShapeFromOp(use.getOwner());
+    if (rs) return rs;
+  }
+  return nullptr;
+}
+
+}  // namespace
+
 Value buildCastInputsToResultShape(Location loc,
                                    RankedShapeType resultShapeType,
                                    ArrayRef<Value> inputs, OpBuilder &builder) {
@@ -98,6 +121,49 @@
   }
 }
 
+Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType,
+                                     OpBuilder &builder) {
+  if (!dimType) dimType = builder.getIndexType();
+  auto valueSt = value.getType().dyn_cast<ShapedType>();
+  if (!valueSt) {
+    builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
+        << "cannot construct shape for non shaped value: " << value.getType();
+    return nullptr;
+  }
+  if (valueSt.hasStaticShape()) {
+    auto rsType = RankedShapeType::get(valueSt.getShape(), dimType);
+    return builder.createOrFold<ConstRankedShapeOp>(loc, rsType);
+  }
+
+  // Dynamic - walk the uses to find a tie_shape op (either this op or an
+  // immediate use).
+  Value rs = findRankedShapeFromUse(value);
+  if (!rs) {
+    builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
+        << "dynamically shaped value is missing a shape association via "
+        << "tie_shape";
+    return nullptr;
+  }
+
+  auto rsType = rs.getType().dyn_cast<RankedShapeType>();
+  if (!rsType) {
+    builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
+        << "dynamically shaped value is not ranked (which is not yet "
+        << "supported)";
+    return nullptr;
+  }
+
+  if (rsType.getDimType() != dimType) {
+    // TODO(laurenzo): Emit a cast.
+    builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
+        << "dynamically shaped shape has the wrong dimension type: "
+        << rsType.getDimType();
+    return nullptr;
+  }
+
+  return rs;
+}
+
 }  // namespace Shape
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/IR/Builders.h b/iree/compiler/Dialect/Shape/IR/Builders.h
index b8e9375..cc12575 100644
--- a/iree/compiler/Dialect/Shape/IR/Builders.h
+++ b/iree/compiler/Dialect/Shape/IR/Builders.h
@@ -49,6 +49,16 @@
     Value srcShape, int dstRank, SmallVectorImpl<int64_t> &broadcastDims,
     OpBuilder &builder);
 
+// Given a value representing a ShapedType (i.e. tensor or otherwise), attempts
+// to locate a computed RankedShape for it by examining uses for a corresponding
+// tie_shape op, returning the associated RankedShape.
+// In the case of a static shape, a const_ranked_shape will be created and
+// returned. If dimType is provided, then any returned shape will have the
+// given dimType (defaults to IndexType), returning nullptr if this is not
+// possible.
+Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType,
+                                     OpBuilder &builder);
+
 }  // namespace Shape
 }  // namespace iree_compiler
 }  // namespace mlir