Add a utility to find or build a ranked shape given an arbitrary value.
(note that this is part of a larger chain, being broken up for submission and tests are included in subsequent changes)
PiperOrigin-RevId: 299399690
diff --git a/iree/compiler/Dialect/Shape/IR/Builders.cpp b/iree/compiler/Dialect/Shape/IR/Builders.cpp
index a91fed2..9ce8bab 100644
--- a/iree/compiler/Dialect/Shape/IR/Builders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Builders.cpp
@@ -16,11 +16,34 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "mlir/IR/Diagnostics.h"
namespace mlir {
namespace iree_compiler {
namespace Shape {
+namespace {
+
+Value getRankedShapeFromOp(Operation *op) {
+ auto tieOp = llvm::dyn_cast_or_null<TieShapeOp>(op);
+ if (!tieOp) return nullptr;
+ auto shape = tieOp.shape();
+ if (!shape.getType().isa<RankedShapeType>()) return nullptr;
+ return shape;
+}
+
+Value findRankedShapeFromUse(Value value) {
+ Value rs = getRankedShapeFromOp(value.getDefiningOp());
+ if (rs) return rs;
+ for (auto &use : value.getUses()) {
+ rs = getRankedShapeFromOp(use.getOwner());
+ if (rs) return rs;
+ }
+ return nullptr;
+}
+
+} // namespace
+
Value buildCastInputsToResultShape(Location loc,
RankedShapeType resultShapeType,
ArrayRef<Value> inputs, OpBuilder &builder) {
@@ -98,6 +121,49 @@
}
}
+Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType,
+ OpBuilder &builder) {
+ if (!dimType) dimType = builder.getIndexType();
+ auto valueSt = value.getType().dyn_cast<ShapedType>();
+ if (!valueSt) {
+ builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
+ << "cannot construct shape for non shaped value: " << value.getType();
+ return nullptr;
+ }
+ if (valueSt.hasStaticShape()) {
+ auto rsType = RankedShapeType::get(valueSt.getShape(), dimType);
+ return builder.createOrFold<ConstRankedShapeOp>(loc, rsType);
+ }
+
+ // Dynamic - walk the uses to find a tie_shape op (either this op or an
+ // immediate use).
+ Value rs = findRankedShapeFromUse(value);
+ if (!rs) {
+ builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
+ << "dynamically shaped value is missing a shape association via "
+ << "tie_shape";
+ return nullptr;
+ }
+
+ auto rsType = rs.getType().dyn_cast<RankedShapeType>();
+ if (!rsType) {
+ builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
+ << "dynamically shaped value is not ranked (which is not yet "
+ << "supported)";
+ return nullptr;
+ }
+
+ if (rsType.getDimType() != dimType) {
+ // TODO(laurenzo): Emit a cast.
+ builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
+ << "dynamically shaped shape has the wrong dimension type: "
+ << rsType.getDimType();
+ return nullptr;
+ }
+
+ return rs;
+}
+
} // namespace Shape
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/IR/Builders.h b/iree/compiler/Dialect/Shape/IR/Builders.h
index b8e9375..cc12575 100644
--- a/iree/compiler/Dialect/Shape/IR/Builders.h
+++ b/iree/compiler/Dialect/Shape/IR/Builders.h
@@ -49,6 +49,16 @@
Value srcShape, int dstRank, SmallVectorImpl<int64_t> &broadcastDims,
OpBuilder &builder);
+// Given a value representing a ShapedType (i.e. tensor or otherwise), attempts
+// to locate a computed RankedShape for it by examining uses for a corresponding
+// tie_shape op, returning the associated RankedShape.
+// In the case of a static shape, a const_ranked_shape will be created and
+// returned. If dimType is provided, then any returned shape will have the
+// given dimType (defaults to IndexType), returning nullptr if this is not
+// possible.
+Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType,
+ OpBuilder &builder);
+
} // namespace Shape
} // namespace iree_compiler
} // namespace mlir