Add pass/ops to expand functions to accept/return dynamic dimensions.

* This pass runs early and sets up each function for the later materialization passes.
* Some limited canonicalization that will elide get_ranked_shape ops when trivially resolvable.
* Skeleton of a doc outlining where this is going.

PiperOrigin-RevId: 291452226
diff --git a/docs/dynamic_shapes.md b/docs/dynamic_shapes.md
new file mode 100644
index 0000000..cc59dc4
--- /dev/null
+++ b/docs/dynamic_shapes.md
@@ -0,0 +1,166 @@
+# Dynamic Shapes
+
+NOTE: Effort is being made to make this facility generic so that it can be
+eventually upstreamed to MLIR in some fashion. However, because MLIR lacks a set
+of frontend ops and generally does not currently have any frontend oriented
+infrastructure, it is being prototyped within IREE in order to find a working
+set of ops and algorithms.
+
+## Levels of dynamism
+
+In general, there are three levels of shape information that can be present in
+the input IR (or trivially derived by applying some form of shape inference
+algorithm). Each additional level imposes more work on the compiler and
+runtime, so generally, the implementation progresses by addressing each level
+only once the previous one is well established:
+
+1.  Fully static shapes: No tensors have dynamic dimensions. All tensors are
+    ranked.
+2.  Ranked dynamism: All tensors have ranks, but some dimensions may be
+    unspecified.
+3.  Unranked dynamism: Some tensors have indeterminate ranks (all three levels
+    are illustrated below).
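+
+For illustration, here is how the three levels appear as MLIR tensor types (the
+function names are purely for exposition):
+
+```mlir
+// 1. Fully static: rank and all dimensions are known.
+func @fully_static(%arg0 : tensor<4x16xf32>) { return }
+// 2. Ranked dynamic: rank is known, but some dimensions are not.
+func @ranked_dynamic(%arg0 : tensor<?x16xf32>) { return }
+// 3. Unranked dynamic: even the rank is unknown.
+func @unranked_dynamic(%arg0 : tensor<*xf32>) { return }
+```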
+
+At this stage, *Dynamic Shapes* in IREE refers to supporting ranked dynamic
+tensors, where some dimensions are left unspecified at public function
+boundaries. It is expected that once this is solid, some support can be
+considered for unranked dynamism, and it is further expected that this will
+entail new ops, algorithms and runtime support, beyond what is needed for
+ranked dynamism.
+
+Within the category of ranked dynamism, it is well known that some dynamic
+dimensions are easier to deal with than others: in common DNN use, dynamic
+outer dimensions are much easier and more common, with respect to code
+generation and kernel fanout, than dynamic inner dimensions.
+
+While the shape handling machinery is relatively generic, we expect that real
+backends will be limited in how completely they support all combinations of
+dynamic dimensions. Eventually, IREE intends to solve this by having a
+relatively robust CPU fallback for fully dynamic cases and actionable warnings
+that pinpoint where more specificity could increase performance.
+
+## Compiler Frontend
+
+In general, the IREE compiler frontend should accept modules containing
+functions with operands/results that have dynamic dimensions. Such functions
+may also have runtime-dependent shapes in the form of `GetShape`-style ops,
+which extract a shape from an arbitrary tensor, perform some arithmetic on it,
+and use the results elsewhere.
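+
+A minimal sketch of such an input, assuming a hypothetical `frontend.get_shape`
+op (real frontends supply their own equivalents):
+
+```mlir
+func @main(%arg0 : tensor<?x128xf32>) -> tensor<2xi32> {
+  // Extracts the runtime shape of %arg0 as a 1D tensor for later arithmetic.
+  %shape = "frontend.get_shape"(%arg0) : (tensor<?x128xf32>) -> tensor<2xi32>
+  return %shape : tensor<2xi32>
+}
+```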
+
+### Shape dialect and lowering
+
+IREE is introducing a `shape` dialect with a handful of ops and transformations
+that are useful for materializing dynamic shape computations in terms of high
+level ops on tensors.
+
+#### Types:
+
+*   `ranked_shape`: This value type represents the dynamic dimensions of a
+    partially known, ranked shape. It is used early in the compiler to represent
+    anywhere that dynamic dimensions need to be passed (e.g. function
+    args/results). At lower levels of the compiler, it will generally be
+    disaggregated into loose SSA values. This type also carries the datatype
+    used to represent the dimensions. This is currently fixed to i32 but may
+    eventually be leveraged to use smaller integer types where that is known to
+    be legal.
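+
+    A sketch of the type syntax, as exercised by the tests in this change:
+
+    ```mlir
+    // A rank-4 shape where dimensions 1 and 3 are dynamic, carried as i32.
+    !shape.ranked_shape<1x?x2x?xi32>
+    ```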
+
+#### Ops:
+
+*   `get_ranked_shape`: Takes a tensor SSA value and returns a corresponding
+    `ranked_shape`. Early in the compilation flow, anything that needs a ranked
+    shape should add such ops so that the compiler can later determine which
+    shape computations to materialize. Getting the `ranked_shape` of a static
+    tensor should yield a constant.
+*   `tie_shape`: Takes tensor and ranked_shape SSA values and returns the
+    tensor. This is used as a junction point by the shape materialization passes
+    to know at various points precisely what the shape is (both ops are shown
+    together below).
+*   ... TODO: need `get_shape_dim` and conversions to/from 1D tensors and loose
+    SSA values.
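+
+Both ops together, following the parse/print tests in this change:
+
+```mlir
+func @example(%arg0 : tensor<2x?x4xf32>, %arg1 : !shape.ranked_shape<2x?x4xi32>) {
+  // Associate the tensor with its dynamic dims for later materialization.
+  %0 = shape.tie_shape %arg0, %arg1 : tensor<2x?x4xf32>, !shape.ranked_shape<2x?x4xi32>
+  // Anchor a request for the (possibly dynamic) shape.
+  %1 = shape.get_ranked_shape %0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2x?x4xi32>
+  return
+}
+```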
+
+### Materialization
+
+#### Function signature expansion
+
+Early in the process, all functions should have their arguments and results
+expanded so that any dynamic tensors in their signature gain a new
+argument/result for the corresponding `ranked_shape`. This is done by expanding
+the signatures and, for arguments, inserting placeholder `tie_shape` ops which
+preserve the association for later materialization. For results,
+`get_ranked_shape` ops are inserted.
+
+This is carried out by the `iree-shape-expand-function-dynamic-dims` pass, which
+uses the conversion framework under the hood to perform type expansion.
+
+This pass is typically run early in the compiler flow.
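+
+For example (modeled on the lit test accompanying this pass), a function such
+as:
+
+```mlir
+func @f(%arg0 : tensor<1x?x2x?xf32>) -> tensor<1x?x2x?xf32> {
+  return %arg0 : tensor<1x?x2x?xf32>
+}
+```
+
+is rewritten so that the signature carries the shapes explicitly:
+
+```mlir
+func @f(%arg0 : tensor<1x?x2x?xf32>, %arg1 : !shape.ranked_shape<1x?x2x?xi32>)
+    -> (tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>) {
+  %0 = shape.tie_shape %arg0, %arg1 : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
+  %1 = shape.get_ranked_shape %0 : tensor<1x?x2x?xf32> -> !shape.ranked_shape<1x?x2x?xi32>
+  return %0, %1 : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
+}
+```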
+
+#### Shape dependent codegen
+
+A lot of scheduling logic will need to access shapes (e.g. allocation,
+workgroup size calculation, etc). In general, this should all be done based on
+a `get_ranked_shape` op and corresponding `get_shape_dim` ops. For fully static
+cases, these should reduce down to constants. For dynamic dimensions, the
+`get_ranked_shape` ops serve as anchors where later parts of the compiler know
+they need to materialize shape values.
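+
+A sketch of the intended pattern; note that `get_shape_dim` is still TODO
+above, so its syntax here is purely illustrative:
+
+```mlir
+%shape = shape.get_ranked_shape %arg0 : tensor<?x128xf32> -> !shape.ranked_shape<?x128xi32>
+// Hypothetical: would fold to a constant when the requested dim is static.
+%dim0 = shape.get_shape_dim %shape[0] : !shape.ranked_shape<?x128xi32> -> i32
+```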
+
+#### Materializing shape computations
+
+TODO: We have a sketch of this but are still proving it out.
+
+Generally, it should be possible, for any `get_ranked_shape` op, to trace up the
+use-def chain and materialize shape manipulation arithmetic. Once materialized,
+a `tie_shape` op should be inserted to memorialize the junction. Eventually,
+every `get_ranked_shape` op should directly follow a `tie_shape` op, and the
+canonicalization rules will elide the `get_ranked_shape`. There is complexity
+around blocks, control flow, etc., but this basic algorithm should be workable.
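+
+The canonicalization pattern added in this change implements that final elision
+step (mirroring the canonicalize.mlir test):
+
+```mlir
+%0 = shape.tie_shape %t, %shape : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
+// Canonicalization replaces all uses of %1 with %shape and elides this op.
+%1 = shape.get_ranked_shape %0 : tensor<1x?x2x?xf32> -> !shape.ranked_shape<1x?x2x?xi32>
+```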
+
+Work is ongoing upstream to provide a facility to register shape functions with
+ops, which would provide a dynamic, dialect independent way to know what
+arithmetic to materialize. However, in most cases this is not necessary. The
+built-in traits around types and sizes will allow most propagation to happen
+without shape functions. We intend to start with a static set of cases for the
+rest in order to prove the concept.
+
+#### Scalarization
+
+TODO: We have a sketch of this but are still proving it out.
+
+It is quite common in real-world DNN usage to get the 1D tensor representing a
+shape and perform arbitrary tensor ops on it (usually basic arithmetic,
+slicing, concatenating, tiling, etc). While this is perfectly acceptable from a
+correctness standpoint, it is usually not performant: shapes are typically very
+small one-dimensional vectors, and computations on them are usually trivial to
+reduce to small sequences of scalar machine code of a form that CPUs are very
+good at executing. Further, we often want to run these calculations eagerly
+when dispatching functions, etc (e.g. to pre-allocate buffers), and having them
+isolated (versus treating them as arbitrary dense computations) can be quite
+valuable.
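+
+As a concrete (hypothetical) instance of the kind of sequence meant here, a
+frontend might compute a padded size from a dynamic shape with ordinary tensor
+arithmetic on a 1D shape tensor; all ops below are illustrative stand-ins for
+frontend-dialect equivalents:
+
+```mlir
+%shape = "frontend.get_shape"(%arg0) : (tensor<?x128xf32>) -> tensor<2xi32>
+%pad = "frontend.constant"() {value = dense<[3, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+%padded = "frontend.add"(%shape, %pad) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+```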
+
+We expect that the type bracketing that happens with `ranked_shape` and the
+corresponding ops will make it easy to write some simple DRR patterns to
+identify such shape manipulation sequences and lower them directly to regions of
+`vm` ops operating on scalars. Such regions can be retained and directly emitted
+when lowering to the `vm` dialect and/or CPU code generation and would run with
+low overhead along with any other scheduling code.
+
+While an optimization, we suspect this is an important one.
+
+### Shape inference
+
+TODO: This is mostly placeholder.
+
+There is work happening upstream to implement MLIR-integrated shape inference.
+In the meantime, IREE expects that the input to the compiler has already had
+some shape inference performed on it. In practice, for TensorFlow, there is a
+pass which applies TensorFlow's pre-MLIR shape inference mechanisms to derive
+such things. This has limitations but is a reasonable starting point.
+
+## Compiler Backends
+
+TODO: This is mostly placeholder.
+
+Much of the newer structured-ops-based codegen is capable of working (within
+bounds) with ranked dynamic shapes with little additional effort. Given the
+lack of an e2e story, much of this has been done "by way of code review" and
+there are certainly issues to be resolved.
+
+In addition, there are several ABI issues and negotiations with the backend that
+still need to be fleshed out.
diff --git a/iree/compiler/Dialect/Shape/CMakeLists.txt b/iree/compiler/Dialect/Shape/CMakeLists.txt
index cf24777..eb6bb11 100644
--- a/iree/compiler/Dialect/Shape/CMakeLists.txt
+++ b/iree/compiler/Dialect/Shape/CMakeLists.txt
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
index 970ae9c..2c8f24a 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
@@ -35,6 +36,77 @@
 namespace Shape {
 
 //===----------------------------------------------------------------------===//
+// Canonicalization
+//===----------------------------------------------------------------------===//
+
+class ElideTiedGetRankedShapePattern
+    : public OpRewritePattern<GetRankedShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(GetRankedShapeOp op,
+                                     PatternRewriter &rewriter) const override {
+    // If the operand is produced by a TieShapeOp, then this op can be
+    // erased in favor of the tie op's shape operand.
+    if (!matchPattern(op.operand(), m_Op<TieShapeOp>())) {
+      return matchFailure();
+    }
+
+    auto tieOp = cast<TieShapeOp>(op.operand().getDefiningOp());
+    rewriter.replaceOp(op, tieOp.shape(), op.operand());
+
+    return matchSuccess();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// iree.tie_shape
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTieShapeOp(OpAsmParser &parser, OperationState &state) {
+  SmallVector<OpAsmParser::OperandType, 2> operands;
+  SmallVector<Type, 2> operandTypes;
+  if (parser.parseOperandList(operands) ||
+      parser.parseColonTypeList(operandTypes) ||
+      parser.parseOptionalAttrDict(state.attributes) ||
+      parser.resolveOperands(operands, operandTypes, parser.getNameLoc(),
+                             state.operands)) {
+    return failure();
+  }
+
+  // The result type is the same as the first operand.
+  if (state.operands.empty()) return failure();
+  state.types.push_back(state.operands.front().getType());
+  return success();
+}
+
+static void printTieShapeOp(OpAsmPrinter &p, TieShapeOp op) {
+  p << op.getOperationName() << " ";
+  p.printOperands(op.getOperands());
+  p << " : ";
+  interleaveComma(op.getOperandTypes(), p);
+  p.printOptionalAttrDict(op.getOperation()->getAttrs());
+}
+
+static LogicalResult verifyTieShapeOp(TieShapeOp op) {
+  if (op.operand().getType() != op.result().getType()) {
+    return op.emitOpError("operand and result must be the same type");
+  }
+
+  // tie_shape currently only supports ranked tensors.
+  auto rankedTensorType = op.operand().getType().dyn_cast<RankedTensorType>();
+  auto rsType = op.shape().getType().dyn_cast<RankedShapeType>();
+  if (!rankedTensorType || !rsType) {
+    return op.emitOpError("currently only ranked tensors are supported");
+  }
+
+  SmallVector<int64_t, 4> rsDims;
+  rsType.getAllDims(rsDims);
+  if (!rankedTensorType.getShape().equals(rsDims)) {
+    return op.emitOpError("dims must match between tensor and shape");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // iree.get_ranked_shape
 //===----------------------------------------------------------------------===//
 
@@ -49,7 +121,7 @@
 }
 
 static void printGetRankedShapeOp(OpAsmPrinter &p, GetRankedShapeOp op) {
-  p << "shape.get_ranked_shape ";
+  p << op.getOperationName() << " ";
   p.printOperand(op.operand());
   p << " : ";
   p.printType(op.operand().getType());
@@ -72,6 +144,11 @@
   return success();
 }
 
+void GetRankedShapeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<ElideTiedGetRankedShapePattern>(context);
+}
+
 #define GET_OP_CLASSES
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.cpp.inc"
 
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.td b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
index 20f5ea3..1400cd6 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.td
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
@@ -34,6 +34,23 @@
 // Dynamic shape support
 //===----------------------------------------------------------------------===//
 
+def Shape_TieShapeOp : Shape_PureOp<"tie_shape"> {
+  let summary = "Ties a tensor and a shape together.";
+  let description = [{
+    Ties a specific tensor and its shape together in the IR, allowing further
+    conversions to re-associate the two. This has no runtime implication and
+    will be removed late in conversion.
+
+    Usage:
+      %0 = shape.tie_shape %1, %2 : tensor<...>, shape.ranked_shape<...>
+  }];
+
+  let arguments = (ins AnyTensor:$operand, Shape_RankedShape:$shape);
+  let results = (outs AnyTensor:$result);
+
+  let verifier = [{ return verify$cppClass(*this); }];
+}
+
 def Shape_GetRankedShapeOp : Shape_PureOp<"get_ranked_shape"> {
   let summary = "Gets the RankedShape associated with the given Tensor.";
   let description = [{
@@ -43,12 +60,23 @@
 
     Getting the RankedShape of a statically shaped tensor will canonicalize
     to a static_ranked_shape op and will never cause a further SSA dependency.
+
+    Usage:
+      %0 = shape.get_ranked_shape %arg0 : tensor<2x?xf32> ->
+          !shape.ranked_shape<2x?xf32>
+
+    Canonicalization: This op includes a canonicalization pattern such that
+    if its operand is supplied by a tie_shape op, then it will replace itself
+    with the tie_shape's shape() operand. In this way, a function with all
+    shapes materialized and tied to intermediate tensors should canonicalize
+    to contain no get_ranked_shape ops.
   }];
 
   let arguments = (ins AnyTensor:$operand);
   let results = (outs Shape_RankedShape:$shape);
 
   let verifier = [{ return verify$cppClass(*this); }];
+  let hasCanonicalizer = 1;
 }
 
 #endif  // IREE_DIALECT_SHAPE_OPS
diff --git a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
new file mode 100644
index 0000000..6a9f60c
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
@@ -0,0 +1,17 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -canonicalize %s | IreeFileCheck %s
+
+
+// CHECK-LABEL: @elideTiedGetRankedShape
+// CHECK-SAME: %[[T:[^:[:space:]]+]]: tensor<1x?x2x?xf32>
+// CHECK-SAME: %[[SHAPE:[^:[:space:]]+]]: !shape.ranked_shape<1x?x2x?xi32>
+func @elideTiedGetRankedShape(%arg0: tensor<1x?x2x?xf32>, %arg1: !shape.ranked_shape<1x?x2x?xi32>) -> (tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>) {
+  // Note that canonicalization does *not* remove tie_shape. That must be
+  // removed manually once all shape materialization is complete (otherwise,
+  // information needed to materialize would be lost).
+  // CHECK: %[[TIE_T:.+]] = shape.tie_shape %[[T]], %[[SHAPE]]
+  %0 = shape.tie_shape %arg0, %arg1 : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
+  // CHECK-NOT: shape.get_ranked_shape
+  %1 = shape.get_ranked_shape %0 : tensor<1x?x2x?xf32> -> !shape.ranked_shape<1x?x2x?xi32>
+  // CHECK-DAG: return %[[TIE_T]], %[[SHAPE]]
+  return %0, %1 : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
+}
diff --git a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
new file mode 100644
index 0000000..c0d3496
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
@@ -0,0 +1,22 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics %s
+
+// -----
+func @tie_shape_mismatch_type(%arg0 : tensor<2x?x4xf32>, %arg1 : !shape.ranked_shape<1xi32>) {
+  // expected-error @+1 {{dims must match between tensor and shape}}
+  %0 = shape.tie_shape %arg0, %arg1 : tensor<2x?x4xf32>, !shape.ranked_shape<1xi32>
+  return
+}
+
+// -----
+func @get_ranked_shape_same_rank(%arg0 : tensor<2x?x4xf32>) {
+  // expected-error @+1 {{op operand and result must be of same rank}}
+  %0 = shape.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2xi32>
+  return
+}
+
+// -----
+func @get_ranked_shape_not_equal_dims(%arg0 : tensor<2x?x4xf32>) {
+  // expected-error @+1 {{op operand tensor and result shape must be equal}}
+  %0 = shape.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2x2x4xi32>
+  return
+}
diff --git a/iree/compiler/Dialect/Shape/IR/test/parse_print.mlir b/iree/compiler/Dialect/Shape/IR/test/parse_print.mlir
index 9f1163e..94601db 100644
--- a/iree/compiler/Dialect/Shape/IR/test/parse_print.mlir
+++ b/iree/compiler/Dialect/Shape/IR/test/parse_print.mlir
@@ -1,6 +1,14 @@
 // RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
 
 // -----
+// CHECK-LABEL: @parse_print_tie_shape
+func @parse_print_tie_shape(%arg0 : tensor<2x?x4xf32>, %arg1 : !shape.ranked_shape<2x?x4xi32>) {
+  %0 = shape.tie_shape %arg0, %arg1 : tensor<2x?x4xf32>, !shape.ranked_shape<2x?x4xi32>
+  return
+}
+
+
+// -----
 // CHECK-LABEL: @parse_print_get_ranked_shape
 func @parse_print_get_ranked_shape(%arg0 : tensor<2x?x4xi32>) {
   // CHECK: shape.get_ranked_shape %arg0 : tensor<2x?x4xi32> -> !shape.ranked_shape<2x?x4xi32>
diff --git a/iree/compiler/Dialect/Shape/IR/test/ranked_shape_type.mlir b/iree/compiler/Dialect/Shape/IR/test/ranked_shape_type.mlir
index 25a9515..8d039bd 100644
--- a/iree/compiler/Dialect/Shape/IR/test/ranked_shape_type.mlir
+++ b/iree/compiler/Dialect/Shape/IR/test/ranked_shape_type.mlir
@@ -25,17 +25,3 @@
 func @error(%arg0 : !shape.ranked_shape<1x?xf32>) {
   return
 }
-
-// -----
-func @get_ranked_shape_same_rank(%arg0 : tensor<2x?x4xf32>) {
-  // expected-error @+1 {{op operand and result must be of same rank}}
-  %0 = shape.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2xi32>
-  return
-}
-
-// -----
-func @get_ranked_shape_not_equal_dims(%arg0 : tensor<2x?x4xf32>) {
-  // expected-error @+1 {{op operand tensor and result shape must be equal}}
-  %0 = shape.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2x2x4xi32>
-  return
-}
diff --git a/iree/compiler/Dialect/Shape/Transforms/BUILD b/iree/compiler/Dialect/Shape/Transforms/BUILD
new file mode 100644
index 0000000..59b7444
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/BUILD
@@ -0,0 +1,35 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "Transforms",
+    srcs = [
+        "ExpandFunctionDynamicDims.cpp",
+    ],
+    hdrs = [
+        "Passes.h",
+    ],
+    deps = [
+        "//iree/compiler/Dialect/Shape/IR",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
diff --git a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..a875697
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_cc_library(
+  NAME
+    Transforms
+  HDRS
+    "Passes.h"
+  SRCS
+    "ExpandFunctionDynamicDims.cpp"
+  DEPS
+    iree::compiler::Dialect::Shape::IR
+    LLVMSupport
+    MLIRIR
+    MLIRPass
+    MLIRSupport
+    MLIRTransformUtils
+    MLIRTransforms
+  ALWAYSLINK
+  PUBLIC
+)
diff --git a/iree/compiler/Dialect/Shape/Transforms/ExpandFunctionDynamicDims.cpp b/iree/compiler/Dialect/Shape/Transforms/ExpandFunctionDynamicDims.cpp
new file mode 100644
index 0000000..fa9fb76
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/ExpandFunctionDynamicDims.cpp
@@ -0,0 +1,215 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+class DynamicDimsTypeConverter : public TypeConverter {
+ public:
+  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
+    auto tensorType = t.dyn_cast<RankedTensorType>();
+    if (!tensorType || tensorType.getNumDynamicDims() == 0) {
+      // No conversion - not ranked tensor or static.
+      results.push_back(t);
+      return success();
+    }
+
+    // Dimension is hard-coded to 32bits currently but better decisions are
+    // possible in some situations.
+    auto dimType = IntegerType::get(32, t.getContext());
+    auto shapeType =
+        Shape::RankedShapeType::get(tensorType.getShape(), dimType);
+    // Expand a dynamically shaped tensor into (tensor<...>, ranked_shape<...xi32>).
+    results.push_back(t);
+    results.push_back(shapeType);
+    return success();
+  }
+
+  Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
+                                   ArrayRef<Value> inputs,
+                                   Location loc) override {
+    // Adds a conversion from (%0 = tensor<...>, %1 = ranked_shape<...>) inputs
+    // to: shape.tie_shape %0, %1
+    assert(inputs.size() == 2);
+    return rewriter.create<Shape::TieShapeOp>(loc, resultType, inputs[0],
+                                              inputs[1]);
+  }
+};
+
+class FuncOpConversion : public OpConversionPattern<FuncOp> {
+ public:
+  FuncOpConversion(DynamicDimsTypeConverter &typeConverter,
+                   MLIRContext *context)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      FuncOp fnOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto fnType = fnOp.getType();
+
+    // TODO(laurenzo): Need to handle all terminators so conservatively
+    // limiting to single block functions until implemented.
+    if (fnOp.getBody().getBlocks().size() != 1) {
+      fnOp.emitWarning()
+          << "dynamic shape conversion only supported for single block "
+          << "functions (currently)";
+      return matchFailure();
+    }
+
+    // Convert function arguments.
+    TypeConverter::SignatureConversion signatureConverter(
+        fnType.getNumInputs());
+    for (unsigned i = 0, e = fnType.getNumInputs(); i < e; ++i) {
+      if (failed(typeConverter.convertSignatureArg(i, fnType.getInput(i),
+                                                   signatureConverter))) {
+        return matchFailure();
+      }
+    }
+
+    // Convert function results.
+    SmallVector<Type, 1> convertedResultTypes;
+    if (failed(typeConverter.convertTypes(fnType.getResults(),
+                                          convertedResultTypes))) {
+      return matchFailure();
+    }
+
+    // Replace function.
+    auto newFnOp = rewriter.cloneWithoutRegions(fnOp);
+    rewriter.inlineRegionBefore(fnOp.getBody(), newFnOp.getBody(),
+                                newFnOp.end());
+    newFnOp.setType(rewriter.getFunctionType(
+        signatureConverter.getConvertedTypes(), convertedResultTypes));
+    rewriter.applySignatureConversion(&newFnOp.getBody(), signatureConverter);
+    rewriter.eraseOp(fnOp);
+
+    // Rewrite the terminator to match the result type conversion that was
+    // performed.
+    auto terminator = newFnOp.getBody().front().getTerminator();
+    auto ip = rewriter.saveInsertionPoint();
+    rewriter.setInsertionPoint(terminator);
+    SmallVector<Value, 4> newTerminatorOperands;
+    for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i) {
+      auto operand = terminator->getOperand(i);
+      SmallVector<Type, 2> expandedTypes;
+      if (failed(typeConverter.convertType(operand.getType(), expandedTypes))) {
+        continue;
+      }
+
+      // Non-conversion
+      if (expandedTypes.size() == 1) {
+        newTerminatorOperands.push_back(operand);
+        continue;
+      }
+      assert(expandedTypes.size() == 2 &&
+             "type converter should expand 1 -> 2");
+
+      // Expand (tensor<...>) to (tensor<...>, ranked_shape<...>)
+      auto shape = rewriter.create<Shape::GetRankedShapeOp>(
+          terminator->getLoc(), expandedTypes[1], operand);
+      newTerminatorOperands.push_back(operand);
+      newTerminatorOperands.push_back(shape);
+    }
+
+    // Clone the terminator (assumed to be 'return'-like) with modified
+    // operands.
+    OperationState terminatorState(terminator->getLoc(), terminator->getName());
+    terminatorState.addOperands(newTerminatorOperands);
+    rewriter.createOperation(terminatorState);
+    rewriter.eraseOp(terminator);
+
+    rewriter.restoreInsertionPoint(ip);
+    return matchSuccess();
+  }
+
+ private:
+  DynamicDimsTypeConverter &typeConverter;
+};
+
+bool isLegallyShapedSignatureType(Type thisType, Type nextType) {
+  if (!thisType.isa<TensorType>()) return true;  // Legal: Don't care.
+  auto rankedType = thisType.dyn_cast<RankedTensorType>();
+  if (!rankedType) return false;  // Illegal: Non-ranked tensor
+  if (rankedType.getNumDynamicDims() == 0) return true;  // Legal: Static shape
+
+  // At this point, the type is ranked and has dynamic dims. Validate.
+  auto rankedShapeType = nextType.dyn_cast_or_null<Shape::RankedShapeType>();
+  if (!rankedShapeType) return false;  // Illegal: No following shape.
+
+  // Are dims equal.
+  auto thisDims = rankedType.getShape();
+  SmallVector<int64_t, 7> shapeDims;
+  rankedShapeType.getAllDims(shapeDims);
+  if (!thisDims.equals(shapeDims)) return false;  // Illegal: Mismatched shape.
+  return true;  // Legal: dynamic tensor followed by matching shape.
+}
+
+// Determines whether a function is "legally shaped", which means that its
+// shaped inputs/results are either a) statically shaped or b) followed by
+// an appropriate (ranked_shape) argument/result with corresponding
+// dims.
+bool isLegallyShapedFunction(FuncOp fnOp) {
+  auto fnType = fnOp.getType();
+  // Validate arguments.
+  for (unsigned i = 0, e = fnType.getNumInputs(); i < e; ++i) {
+    Type type = fnType.getInput(i);
+    Type nextType = (i + 1 < e) ? fnType.getInput(i + 1) : nullptr;
+    if (!isLegallyShapedSignatureType(type, nextType)) return false;
+  }
+  // TODO: Validate results in the same way as arguments.
+  return true;
+}
+
+class ExpandFunctionDynamicDimsPass
+    : public ModulePass<ExpandFunctionDynamicDimsPass> {
+  void runOnModule() override {
+    ConversionTarget target(getContext());
+    target.addDynamicallyLegalOp<FuncOp>(isLegallyShapedFunction);
+    target.markOpRecursivelyLegal<FuncOp>();
+
+    OwningRewritePatternList patterns;
+    DynamicDimsTypeConverter typeConverter;
+    patterns.insert<FuncOpConversion>(typeConverter, &getContext());
+
+    if (failed(applyPartialConversion(getModule(), target, patterns,
+                                      &typeConverter))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+// For any function which contains dynamic dims in its inputs or results,
+// rewrites it so that the dynamic dims are passed in/out.
+std::unique_ptr<OpPassBase<ModuleOp>> createExpandFunctionDynamicDimsPass() {
+  return std::make_unique<ExpandFunctionDynamicDimsPass>();
+}
+
+static PassRegistration<ExpandFunctionDynamicDimsPass> pass(
+    "iree-shape-expand-function-dynamic-dims",
+    "Expands dynamic dimensions in function signatures.");
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Transforms/Passes.h b/iree/compiler/Dialect/Shape/Transforms/Passes.h
new file mode 100644
index 0000000..ed22bed
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/Passes.h
@@ -0,0 +1,32 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
+#define IREE_COMPILER_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
+
+#include <memory>
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// For any function which contains dynamic dims in its inputs or results,
+// rewrites it so that the dynamic dims are passed in/out.
+std::unique_ptr<OpPassBase<ModuleOp>> createExpandFunctionDynamicDimsPass();
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/BUILD b/iree/compiler/Dialect/Shape/Transforms/test/BUILD
new file mode 100644
index 0000000..14281d1
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = glob(["*.mlir"]),
+    data = [
+        "//iree/tools:IreeFileCheck",
+        "//iree/tools:iree-opt",
+    ],
+)
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/expand_function_dynamic_dims.mlir b/iree/compiler/Dialect/Shape/Transforms/test/expand_function_dynamic_dims.mlir
new file mode 100644
index 0000000..3358b57
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/test/expand_function_dynamic_dims.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-shape-expand-function-dynamic-dims %s | IreeFileCheck %s
+
+// CHECK-LABEL: @staticFunctionArgs
+// CHECK-NOT: ranked_shape
+func @staticFunctionArgs(%arg0 : tensor<1x2xf32>) {
+  return
+}
+
+// -----
+// CHECK-LABEL: @dynamicFunctionArgs
+// Should insert function shape argument and result.
+// CHECK-SAME: %[[T:[^:[:space:]]+]]: tensor<1x?x2x?xf32>
+// CHECK-SAME: %[[SHAPE:[^:[:space:]]+]]: !shape.ranked_shape<1x?x2x?xi32>
+// CHECK-SAME: -> (tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>)
+func @dynamicFunctionArgs(%arg0 : tensor<1x?x2x?xf32>) -> tensor<1x?x2x?xf32> {
+  // Should insert tie on arguments and get_shape on result shape.
+  // CHECK-DAG: %[[TIE_T:.+]] = shape.tie_shape %[[T]], %[[SHAPE]]
+  // CHECK-DAG: %[[GET_SHAPE:.+]] = shape.get_ranked_shape %[[TIE_T]]
+  // CHECK-DAG: return %[[TIE_T]], %[[GET_SHAPE]] : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
+  return %arg0 : tensor<1x?x2x?xf32>
+}
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 6d5bd60..2bc0daa 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -57,6 +57,7 @@
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/HAL/Transforms",
         "//iree/compiler/Dialect/Shape/IR",
+        "//iree/compiler/Dialect/Shape/Transforms",
         "//iree/compiler/Dialect/VM/Analysis",
         "//iree/compiler/Dialect/VM/Conversion/StandardToVM",
         "//iree/compiler/Dialect/VM/IR",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 2eaf648..aa8087e 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -67,6 +67,7 @@
       iree::compiler::Dialect::HAL::IR
       iree::compiler::Dialect::HAL::Transforms
       iree::compiler::Dialect::Shape::IR
+      iree::compiler::Dialect::Shape::Transforms
       iree::compiler::Dialect::VM::Analysis
       iree::compiler::Dialect::VM::Conversion
       iree::compiler::Dialect::VM::Conversion::StandardToVM