Basic conversion of `shape` to `shapex` dialect.

- New pass ConvertShapeToShapex lowering the `shape` dialect to `shapex`.
  Mainly, this pass converts `!shape.shape` to `!shapex.ranked_shape`, but
  it will need to do more in the future. Currently, due to limitations in
  the conversion infra (detailed in a comment), this won't work in general
  beyond a single basic block. A follow-up is needed to either write custom
  conversion infra or improve MLIR core.
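  To illustrate (a sketch based on the new shape_to_shapex.mlir test below),
  a query like
    %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
  is rewritten to
    %0 = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>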

- Add ConvertShapeToShapex to iree-tf-import-pipeline.

- Add ops needed for lowering BatchMatMul:
  - `shapex.gather_extents`: a powerful shape shuffling op that subsumes
    concat, slicing, permuting, etc.
  - `shapex.to_extent_tensor` and `shapex.from_extent_tensor`: ops for
    bridging between the "tensor of extents" world and
    `!shapex.ranked_shape`
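
  A rough sketch of these ops in use (types follow the op definitions and
  tests below; printed forms may differ slightly):
    // Concatenate two rank-1 shapes into one rank-2 shape.
    %2 = "shapex.gather_extents"(%0, %1) {indices = dense<[0, 1]> : tensor<2xi64>} : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?,?]>
    // Bridge between ranked shapes and tensors of extents.
    %t = "shapex.to_extent_tensor"(%2) : (!shapex.ranked_shape<[?,?]>) -> tensor<2xindex>
    %s = "shapex.from_extent_tensor"(%t) : (tensor<2xindex>) -> !shapex.ranked_shape<[?,?]>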

- ConvertHLOToShapeDialect: add support for lowering
  `xla_hlo.dynamic_broadcast_in_dim` to `shapex` ops.
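  Roughly (mirroring the new convert-hlo-to-shape-dialect.mlir test):
    %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  becomes
    %shape = "shapex.from_extent_tensor"(%arg1) : (tensor<2xindex>) -> !shapex.ranked_shape<[?,?]>
    %0 = "shapex.ranked_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,?]>) -> tensor<?x?xf32>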

- Add a canonicalization pattern to erase unused `make_ranked_shape` ops,
  which is helpful during legalizations.

- Lower `shapex.gather_extents` in MaterializeShapeCalculations.
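  As a sketch (following the new materialize_shape_calculations.mlir tests),
  the lowering expands the op into per-extent shapex.ranked_dim values feeding
  a shapex.make_ranked_shape, which then fold away. For example,
    %0 = shapex.make_ranked_shape %arg0, %arg1 : (index, index) -> !shapex.ranked_shape<[?,?]>
    %1 = "shapex.gather_extents"(%0) {indices = dense<[1, 0]> : tensor<2xi64>} : (!shapex.ranked_shape<[?,?]>) -> !shapex.ranked_shape<[?,?]>
    %2:2 = shapex.ranked_dims %1 : !shapex.ranked_shape<[?,?]> -> index, index
  materializes into code where %2#0 and %2#1 are simply %arg1 and %arg0.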

PiperOrigin-RevId: 308738215
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index f532100..1c26f93 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -37,9 +37,12 @@
         "//iree/base:signature_mangle",
         "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/IREE/IR",
+        "//iree/compiler/Dialect/Shape/Conversion",
+        "//iree/compiler/Dialect/Shape/Transforms",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
diff --git a/integrations/tensorflow/compiler/Passes.cpp b/integrations/tensorflow/compiler/Passes.cpp
index 50931e2..aa19c09 100644
--- a/integrations/tensorflow/compiler/Passes.cpp
+++ b/integrations/tensorflow/compiler/Passes.cpp
@@ -16,6 +16,8 @@
 
 #include "integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.h"
 #include "integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.h"
+#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
+#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/Passes.h"
@@ -27,6 +29,14 @@
 // IREE core should go here.
 void createIreeTfImportPipeline(OpPassManager &pm) {
   ////////////////////////////////////////////////////////////////////////////
+  // Lowering shape-related constructs.
+  ////////////////////////////////////////////////////////////////////////////
+  pm.addPass(Shape::createConvertHLOToShapePass());
+  pm.addPass(createConvertShapeToShapexPass());
+  // Clean up trivial redundancies.
+  pm.addPass(createCanonicalizerPass());
+
+  ////////////////////////////////////////////////////////////////////////////
   // Lowering TensorList-related parts of tf dialect to tf_tensorlist dialect.
   ////////////////////////////////////////////////////////////////////////////
   pm.addPass(tf_tensorlist::createConvertTfToTfTensorList());
diff --git a/iree/compiler/Dialect/Shape/Conversion/BUILD b/iree/compiler/Dialect/Shape/Conversion/BUILD
new file mode 100644
index 0000000..133cc09
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/BUILD
@@ -0,0 +1,27 @@
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "ConvertShapeToShapex",
+    srcs = ["ConvertShapeToShapex.cpp"],
+    deps = [
+        "//iree/compiler/Dialect/Shape/IR",
+        "@llvm-project//mlir:Dialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "Conversion",
+    hdrs = ["Passes.h"],
+    deps = [
+        ":ConvertShapeToShapex",
+        "@llvm-project//mlir:Pass",
+    ],
+)
diff --git a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
new file mode 100644
index 0000000..2434478
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
@@ -0,0 +1,212 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace Shape {
+namespace {
+
+// This conversion is currently quite limited (in particular, it does not
+// handle multiple basic blocks in general), because it performs a type
+// conversion that the MLIR core conversion infra doesn't handle well.
+//
+// In particular, we convert `!shape.shape` to `!shapex.ranked_shape<...>`, but
+// the contents of the `...` are context-dependent. Thus, one could say that
+// this pass does a context-dependent type conversion.
+//
+// The current MLIR conversion infra doesn't handle context-dependent type
+// conversions.
+//
+// I can see two solutions:
+//
+// 1. Extend the MLIR conversion infra to better support context-dependent type
+// conversions. One way to do this would be for the conversion infra to convert
+// blocks in RPO and use the type of the converted successor operand in a
+// dominating predecessor as the type for the block argument when converting a
+// block. A similar thing could be done with an RPO traversal of the callgraph.
+// This algorithm wouldn't work in the presence of recursively dead cycles. And
+// of course linkage boundaries cannot have a context-dependent type conversion
+// (by definition).
+//
+// 2. Avoid needing to convert to !shapex.ranked_shape in the first place. This
+// could be accomplished by generalizing !shape.shape to be able to support the
+// use case of !shapex.ranked_shape. One important requirement here is that
+// !shapex.ranked_shape models a partially-specified shape (hardcoded for the
+// ranked case). !shape.shape could be extended to capture partially-specified
+// shapes in the type, such as allowing `!shape.shape<*>` to model an unranked
+// shape (which is the default; no information), `!shape.shape<?x?x5x?>` to
+// model a rank-4 shape with dimension 2 being of extent 5, etc.
+//
+// Once we have this, we could do this lowering from generic !shape.shape to
+// statically-known ranked shapes more progressively and treat it more like a
+// type refinement algorithm.
+//
+// The main risk is that we are trying to shove too much stuff into the
+// !shape.shape type. There's a risk that "progressive lowering" becomes "no
+// clear boundaries" and we end up with code deep in the compiler continuously
+// needing to double-check that the !shape.shape values at a given point are
+// in fact statically known to be ranked, or silently making that assumption
+// and triggering assertions on verifier-valid IR. Pipelines and legalization
+// targets could keep these assertions from firing in practice, but it would
+// be a maintenance burden.
+
+class ConvertShapeOfOp : public OpConversionPattern<shape::ShapeOfOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      shape::ShapeOfOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto tensorType = operands[0].getType().dyn_cast<RankedTensorType>();
+    if (!tensorType) {
+      return failure();
+    }
+    auto resultType =
+        RankedShapeType::get(tensorType.getShape(), rewriter.getContext());
+    rewriter.replaceOpWithNewOp<Shape::GetRankedShapeOp>(op, resultType,
+                                                         operands[0]);
+    return success();
+  }
+};
+
+class ConvertSplitAtOp : public OpConversionPattern<shape::SplitAtOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      shape::SplitAtOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    IntegerAttr indexAttr;
+    if (!matchPattern(op.index(), m_Constant(&indexAttr))) {
+      return rewriter.notifyMatchFailure(op, "requires constant `index`");
+    }
+    auto rank = operands[0].getType().cast<RankedShapeType>().getRank();
+    int64_t index = indexAttr.getInt();
+    if (index < 0) {
+      index += rank;
+    }
+    auto headIndices = llvm::to_vector<4>(llvm::seq<int64_t>(0, index));
+    auto tailIndices = llvm::to_vector<4>(llvm::seq<int64_t>(index, rank));
+    Value head = rewriter.create<GatherExtentsOp>(
+        op.getLoc(), operands[0], rewriter.getI64TensorAttr(headIndices));
+    Value tail = rewriter.create<GatherExtentsOp>(
+        op.getLoc(), operands[0], rewriter.getI64TensorAttr(tailIndices));
+    rewriter.replaceOp(op, {head, tail});
+    return success();
+  }
+};
+
+class ConvertBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      shape::BroadcastOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    Value lhs = operands[0];
+    Value rhs = operands[1];
+    auto lhsType = lhs.getType().cast<RankedShapeType>();
+    auto rhsType = rhs.getType().cast<RankedShapeType>();
+    // Establish invariant that rank(lhs) <= rank(rhs)
+    if (lhsType.getRank() > rhsType.getRank()) {
+      std::swap(lhsType, rhsType);
+      std::swap(lhs, rhs);
+    }
+    SmallVector<int64_t, 6> resultShape;
+    OpTrait::util::getBroadcastedShape(lhsType.getAllDims(),
+                                       rhsType.getAllDims(), resultShape);
+    auto resultType = RankedShapeType::get(resultShape, rewriter.getContext());
+    auto iota = llvm::to_vector<4>(llvm::seq<int64_t>(0, rhsType.getRank()));
+    rewriter.replaceOpWithNewOp<RankedBroadcastShapeOp>(
+        op, resultType, lhs, rhs,
+        /*lhs_broadcast_dimensions=*/
+        rewriter.getI64TensorAttr(makeArrayRef(iota).drop_front(
+            rhsType.getRank() - lhsType.getRank())),
+        /*rhs_broadcast_dimensions=*/
+        rewriter.getI64TensorAttr(iota));
+    return success();
+  }
+};
+
+class ConvertConcatOp : public OpConversionPattern<shape::ConcatOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      shape::ConcatOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto resultRank = operands[0].getType().cast<RankedShapeType>().getRank() +
+                      operands[1].getType().cast<RankedShapeType>().getRank();
+    auto indices = llvm::to_vector<4>(llvm::seq<int64_t>(0, resultRank));
+    rewriter.replaceOpWithNewOp<Shape::GatherExtentsOp>(
+        op, ValueRange({operands[0], operands[1]}),
+        rewriter.getI64TensorAttr(indices));
+    return success();
+  }
+};
+
+class ConvertToExtentTensorOp
+    : public OpConversionPattern<shape::ToExtentTensorOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<Shape::ToExtentTensorOp>(op, op.getType(),
+                                                         operands[0]);
+    return success();
+  }
+};
+
+class ConvertShapeToShapex
+    : public PassWrapper<ConvertShapeToShapex, OperationPass<ModuleOp>> {
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    MLIRContext *context = &getContext();
+
+    // Conversion target definition.
+    ConversionTarget conversionTarget(*context);
+    conversionTarget.addIllegalDialect<shape::ShapeDialect>();
+    conversionTarget.addLegalDialect<iree_compiler::ShapeDialect>();
+
+    // Patterns.
+    OwningRewritePatternList patterns;
+    patterns.insert<ConvertShapeOfOp>(context);
+    patterns.insert<ConvertSplitAtOp>(context);
+    patterns.insert<ConvertBroadcastOp>(context);
+    patterns.insert<ConvertConcatOp>(context);
+    patterns.insert<ConvertToExtentTensorOp>(context);
+
+    if (failed(applyPartialConversion(module, conversionTarget, patterns))) {
+      return signalPassFailure();
+    }
+  }
+};
+}  // namespace
+
+}  // namespace Shape
+
+std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToShapexPass() {
+  return std::make_unique<Shape::ConvertShapeToShapex>();
+}
+
+static PassRegistration<Shape::ConvertShapeToShapex> registration(
+    "convert-shape-to-shapex", "Convert `shape` dialect to `shapex` dialect");
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Conversion/Passes.h b/iree/compiler/Dialect/Shape/Conversion/Passes.h
new file mode 100644
index 0000000..750bb86
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/Passes.h
@@ -0,0 +1,29 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_SHAPE_CONVERSION_PASSES_H_
+#define IREE_COMPILER_DIALECT_SHAPE_CONVERSION_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Convert `shape` dialect to `shapex` dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToShapexPass();
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_SHAPE_CONVERSION_PASSES_H_
diff --git a/iree/compiler/Dialect/Shape/Conversion/test/BUILD b/iree/compiler/Dialect/Shape/Conversion/test/BUILD
new file mode 100644
index 0000000..14281d1
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = glob(["*.mlir"]),
+    data = [
+        "//iree/tools:IreeFileCheck",
+        "//iree/tools:iree-opt",
+    ],
+)
diff --git a/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir b/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
new file mode 100644
index 0000000..02c9e9c
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
@@ -0,0 +1,83 @@
+// RUN: iree-opt -convert-shape-to-shapex -split-input-file -verify-diagnostics -allow-unregistered-dialect <%s | IreeFileCheck %s
+
+// -----
+// shape.shape_of
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>) {
+  // CHECK: shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+  %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+  "foo.use"(%0) : (!shape.shape) -> ()
+  return
+}
+
+// -----
+// shape.split_at
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>) {
+  %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+  %index = constant 0 : i32
+  // CHECK: %[[RS:.+]] = shapex.get_ranked_shape %arg0
+  // CHECK: %[[HEAD:.+]] = "shapex.gather_extents"(%[[RS]]) {indices = dense<[]> : tensor<0xi64>} : (!shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[]>
+  // CHECK: %[[TAIL:.+]] = "shapex.gather_extents"(%[[RS]]) {indices = dense<0> : tensor<1xi64>} : (!shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+  // CHECK: "foo.use"(%[[HEAD]], %[[TAIL]])
+  %head, %tail = "shape.split_at"(%0, %index) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+  "foo.use"(%head, %tail) : (!shape.shape, !shape.shape) -> ()
+  return
+}
+
+// -----
+// No conversion -- index is dynamic.
+func @f(%arg0: tensor<?xf32>, %index: i32) {
+  %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+  // expected-error @+1 {{failed to legalize operation}}
+  %head, %tail = "shape.split_at"(%0, %index) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+  "foo.use"(%head, %tail) : (!shape.shape, !shape.shape) -> ()
+  return
+}
+
+// -----
+// shape.broadcast
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // CHECK: %[[LHSRS:.+]] = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+  // CHECK: %[[RHSRS:.+]] = shapex.get_ranked_shape %arg1 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+  %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+  %1 = "shape.shape_of"(%arg1) : (tensor<?xf32>) -> !shape.shape
+  // CHECK: %[[BROADCASTED:.+]] = "shapex.ranked_broadcast_shape"(%[[LHSRS]], %[[RHSRS]]) {
+  // CHECK-SAME: lhs_broadcast_dimensions = dense<0> : tensor<1xi64>,
+  // CHECK-SAME: rhs_broadcast_dimensions = dense<0> : tensor<1xi64>}
+  // CHECK-SAME: : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  // CHECK: "foo.use"(%[[BROADCASTED]])
+  "foo.use"(%2) : (!shape.shape) -> ()
+  return
+}
+
+// -----
+// shape.concat
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // CHECK: %[[LHSRS:.+]] = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+  // CHECK: %[[RHSRS:.+]] = shapex.get_ranked_shape %arg1 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+  %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+  %1 = "shape.shape_of"(%arg1) : (tensor<?xf32>) -> !shape.shape
+  // CHECK: %[[CONCATTED:.+]] = "shapex.gather_extents"(%[[LHSRS]], %[[RHSRS]]) {indices = dense<[0, 1]> : tensor<2xi64>}
+  %2 = "shape.concat"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  // CHECK: "foo.use"(%[[CONCATTED]])
+  "foo.use"(%2) : (!shape.shape) -> ()
+  return
+}
+
+// -----
+// shape.to_extent_tensor
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // CHECK: %[[RS:.+]] = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+  %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+  // CHECK: %[[EXTENTS:.+]] = "shapex.to_extent_tensor"(%[[RS]])
+  %1 = "shape.to_extent_tensor"(%0) : (!shape.shape) -> tensor<1xindex>
+  // CHECK: "foo.use"(%[[EXTENTS]])
+  "foo.use"(%1) : (tensor<1xindex>) -> ()
+  return
+}
+
diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD
index 8d0bb7d..c7e0595 100644
--- a/iree/compiler/Dialect/Shape/IR/BUILD
+++ b/iree/compiler/Dialect/Shape/IR/BUILD
@@ -52,6 +52,7 @@
         "//iree/compiler/Utils",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:SideEffects",
         "@llvm-project//mlir:StandardOps",
@@ -72,6 +73,7 @@
         ":td_files",
         "//iree/compiler/Dialect/IREE/IR:td_files",
         "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
         "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
         "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
     ],
@@ -88,6 +90,7 @@
         ":td_files",
         "//iree/compiler/Dialect/IREE/IR:td_files",
         "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
         "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
         "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
     ],
diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp
index b96665e..c8550f2 100644
--- a/iree/compiler/Dialect/Shape/IR/Folders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Folders.cpp
@@ -56,7 +56,9 @@
   // If the immediate predecessor is a TieShapeOp, then this op can be
   // erased in favor of the input to the tie op.
   auto tieOp = dyn_cast_or_null<TieShapeOp>(operands.operand().getDefiningOp());
-  if (!tieOp) return failure();
+  if (!tieOp) {
+    return rewriter.notifyMatchFailure(op, "no associated tie_shape op");
+  }
 
   rewriter.replaceOp(op, tieOp.shape());
   return success();
@@ -135,6 +137,26 @@
   return success();
 }
 
+// TODO(silvasean): Better handling of "erase unused ops for legality".
+// Currently, we legalize !shapex.ranked_shape into individual SSA values per
+// dimension by iteratively reducing other ops to
+// shapex.ranked_dim/shapex.ranked_dims and shapex.make_ranked_shape. Patterns
+// then resolve the shapex.ranked_dim/shapex.ranked_dims ops to scalar values
+// by looking through the shapex.make_ranked_shape ops, with the eventual goal
+// of leaving the shapex.make_ranked_shape op itself without uses, so that the
+// main computation flows through the individual SSA values instead. This
+// naturally produces a lot of unused shapex.make_ranked_shape ops, which we
+// need to delete for legality reasons. This pattern allows conversions to
+// erase those ops.
+LogicalResult eraseUnusedMakeRankedShapeOp(
+    MakeRankedShapeOp op, MakeRankedShapeOpOperandAdaptor operands,
+    PatternRewriter &rewriter) {
+  if (!op.getResult().use_empty())
+    return rewriter.notifyMatchFailure(op, "op has uses");
+  rewriter.eraseOp(op);
+  return success();
+}
+
 LogicalResult dynamicMakeRankedShapeDimPattern(
     RankedDimOp op, RankedDimOpOperandAdaptor operands,
     PatternRewriter &rewriter) {
@@ -268,6 +290,28 @@
 }
 
 //===----------------------------------------------------------------------===//
+// shapex.from_extent_tensor
+//===----------------------------------------------------------------------===//
+
+LogicalResult fromExtentTensorOfToExtentTensorIsIdentity(
+    FromExtentTensorOp op, FromExtentTensorOpOperandAdaptor operands,
+    PatternRewriter &rewriter) {
+  auto toOp =
+      dyn_cast_or_null<ToExtentTensorOp>(op.extent_tensor().getDefiningOp());
+  if (!toOp) {
+    return failure();
+  }
+  rewriter.replaceOp(op, toOp.shape());
+  return success();
+}
+
+void FromExtentTensorOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  insertGreedyPattern(patterns, context,
+                      fromExtentTensorOfToExtentTensorIsIdentity);
+}
+
+//===----------------------------------------------------------------------===//
 // Standard folding and canonicalization conversion patterns.
 //===----------------------------------------------------------------------===//
 
@@ -294,6 +338,7 @@
 void populateFoldConversionPatterns(MLIRContext *context,
                                     OwningRewritePatternList &patterns) {
   patterns.insert<TieShapeTypeConversionPattern>(context);
+  insertConversionPattern(patterns, context, eraseUnusedMakeRankedShapeOp);
   insertConversionPattern(patterns, context, dynamicMakeRankedShapeDimPattern);
   insertConversionPattern(patterns, context,
                           elideDuplicateGetRankedShapePattern);
@@ -303,6 +348,8 @@
   insertConversionPattern(patterns, context, identityMakeRankedShapePattern);
   insertConversionPattern(patterns, context, elideStaticGetRankedShapePattern);
   insertConversionPattern(patterns, context, safeCastCompatibleShapePattern);
+  insertConversionPattern(patterns, context,
+                          fromExtentTensorOfToExtentTensorIsIdentity);
 }
 
 }  // namespace Shape
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeBase.td b/iree/compiler/Dialect/Shape/IR/ShapeBase.td
index 09029f1..654e8cf 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeBase.td
+++ b/iree/compiler/Dialect/Shape/IR/ShapeBase.td
@@ -52,4 +52,9 @@
 // be a useful feature.
 def Shape_DimType : AnyTypeOf<[Index, AnySignlessInteger]>;
 
+def Shape_ExtentTensor : ShapedContainerType<
+    [Index, AnySignlessInteger],
+    And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
+    "a 1D tensor of extents">;
+
 #endif  // IREE_DIALECT_SHAPE_BASE
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
index f03bca3..672c6cf 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
@@ -240,6 +240,115 @@
   RankedDimsOp::build(builder, result, builder->getIndexType(), shape);
 }
 
+//===----------------------------------------------------------------------===//
+// shapex.gather_extents
+//===----------------------------------------------------------------------===//
+
+// Helper for accessing attributes in the inferReturnTypes callback.
+// That callback receives the attributes as an `ArrayRef<NamedAttribute>`,
+// which isn't the nicest form to work with.
+template <typename Attr>
+static Attr getRequiredAttr(ArrayRef<NamedAttribute> attributes,
+                            StringRef name) {
+  auto it = llvm::find_if(
+      attributes, [&](NamedAttribute attr) { return attr.first == name; });
+  assert(it != attributes.end());
+  return it->second.template cast<Attr>();
+}
+
+/*static*/ SmallVector<int64_t, 6> GatherExtentsOp::getConcatenatedExtents(
+    ValueRange values) {
+  SmallVector<int64_t, 6> ret;
+  for (auto type : values.getTypes()) {
+    auto rankedShape = type.cast<RankedShapeType>();
+    ret.append(rankedShape.getAllDims().begin(),
+               rankedShape.getAllDims().end());
+  }
+  return ret;
+}
+
+static LogicalResult verifyGatherExtentsOp(GatherExtentsOp op) {
+  int64_t totalExtents = 0;
+  for (Type type : op.shapes().getTypes()) {
+    totalExtents += type.cast<RankedShapeType>().getRank();
+  }
+
+  for (int64_t index : op.indices().getValues<int64_t>()) {
+    if (index >= totalExtents) {
+      return op.emitError() << "index " << index
+                            << " exceeds total number of extents of operands ("
+                            << totalExtents << ")";
+    }
+  }
+
+  return success();
+}
+
+LogicalResult GatherExtentsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // We can't infer the DimType of the result if there are no operands.
+  // If a user requires this, then they should manually specify the return type.
+  // We could in theory use an index type here (the default).
+  assert(!operands.empty() && "inferring return type for empty operands");
+  auto indices = getRequiredAttr<DenseIntElementsAttr>(attributes, "indices")
+                     .getValues<int64_t>();
+  auto inputExtents = getConcatenatedExtents(operands);
+  SmallVector<int64_t, 6> resultExtents;
+  for (auto index : indices) {
+    resultExtents.push_back(inputExtents[index]);
+  }
+  inferredReturnTypes.push_back(RankedShapeType::get(resultExtents, context));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// shapex.to_extent_tensor
+//===----------------------------------------------------------------------===//
+
+LogicalResult ToExtentTensorOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto inputType = operands[0].getType().cast<RankedShapeType>();
+  inferredReturnTypes.push_back(
+      RankedTensorType::get({inputType.getRank()}, IndexType::get(context)));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// shapex.from_extent_tensor
+//===----------------------------------------------------------------------===//
+
+static bool isValidTensorOfExtents(RankedTensorType type) {
+  // If the tensor of extents does not have a static shape, that would imply
+  // that the tensor whose shape it describes is unranked.
+  return type.getRank() == 1 && type.hasStaticShape();
+}
+
+LogicalResult FromExtentTensorOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto inputType = operands[0].getType().dyn_cast<RankedTensorType>();
+  if (!inputType || !isValidTensorOfExtents(inputType)) {
+    return failure();
+  }
+  SmallVector<int64_t, 6> extents(inputType.getDimSize(0),
+                                  static_cast<int64_t>(-1));
+  inferredReturnTypes.push_back(RankedShapeType::get(extents, context));
+  return success();
+}
+
+bool FromExtentTensorOp::isCompatibleReturnTypes(ArrayRef<Type> lhs,
+                                                 ArrayRef<Type> rhs) {
+  auto lhsRs = lhs[0].cast<RankedShapeType>();
+  auto rhsRs = rhs[0].cast<RankedShapeType>();
+  return succeeded(
+      verifyCompatibleShape(lhsRs.getAllDims(), rhsRs.getAllDims()));
+}
+
 #define GET_OP_CLASSES
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.cpp.inc"
 
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.h b/iree/compiler/Dialect/Shape/IR/ShapeOps.h
index e684ddd..348c33d 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.h
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.h
@@ -21,6 +21,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffects.h"
 
 namespace mlir {
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.td b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
index 26ec1c8..5a4a57b 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.td
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
@@ -17,6 +17,7 @@
 
 include "iree/compiler/Dialect/Shape/IR/ShapeBase.td"
 include "mlir/Interfaces/SideEffects.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/OpAsmInterface.td"
 
 //===----------------------------------------------------------------------===//
@@ -145,6 +146,53 @@
   }];
 }
 
+// TODO(silvasean): What if the shape is an error shape?
+def Shape_ToExtentTensorOp : Shape_PureOp<"to_extent_tensor",
+     [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Convert a ranked shape to a tensor of extents.";
+  let description = [{
+    Convert a !shapex.ranked_shape to a rank-1 tensor of integers.
+
+    Examples:
+    %t0 = "shapex.to_extent_tensor"(%rs0)
+      : (!shapex.ranked_shape<[3,?,5]>)
+      -> tensor<3xi32>
+    The resulting tensor will, for example, have elements [3,4,5] if the
+    dynamic dimension is 4 at runtime.
+  }];
+  let arguments = (ins Shape_RankedShape:$shape);
+  let results = (outs Shape_ExtentTensor:$extent_tensor);
+
+  // TODO: Custom parser/printer
+  let parser = ?;
+  let printer = ?;
+}
+
+def Shape_FromExtentTensorOp : Shape_PureOp<"from_extent_tensor",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Convert a tensor of extents to a ranked shape.";
+  let description = [{
+    Convert a rank-1 tensor of integers to a !shapex.ranked_shape.
+
+    Examples:
+    %t0 = "shapex.from_dimension_tensor"(%rs0)
+      : (tensor<3xi32>)
+      -> !shapex.ranked_shape<[?,?,?],i32>
+  }];
+  let arguments = (ins Shape_ExtentTensor:$extent_tensor);
+  let results = (outs Shape_RankedShape:$shape);
+
+  let hasCanonicalizer = 1;
+  let extraClassDeclaration = [{
+    // Declaration for overridden method from InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
+  }];
+
+  // TODO: Custom parser/printer
+  let parser = ?;
+  let printer = ?;
+}
+
 def Shape_ConstRankedShapeOp : Shape_PureOp<"const_ranked_shape",
     [ConstantLike, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
   let summary = "A constant ranked_shape.";
@@ -315,4 +363,85 @@
   let printer = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// Shape manipulations.
+//===----------------------------------------------------------------------===//
+
+def Shape_GatherExtentsOp : Shape_PureOp<"gather_extents",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Gather extents across shapes.";
+  let description = [{
+    Gathers extents across the !shapex.ranked_shape's in `shapes`.
+
+    This op conceptually performs the following operation:
+    1. The extents of all shapes in `shapes` are concatenated together into
+       a single list.
+    2. The resulting shape is constructed by extracting extents from the
+       combined list according to `indices`.
+    In pseudocode:
+    ```
+    shapes = ... # a list of lists of extents
+    # Example: shapes = [[3,-1],[2,7]]
+    extents = [extent for shape in shapes for extent in shape]
+    # Example: extents = [3,-1,2,7]
+    # or to use another terminology: `extents = flatmap(shapes)`
+    results = [extents[index] for index in indices]
+    ```
+
+    A large class of shape manipulations can be canonicalized into this op,
+    including:
+    - taking slices of shapes
+    - concatenating shapes
+    - permuting shapes
+    The intuition behind this op is that eventually each extent will be
+    exploded into its own SSA value, at which point this op merely becomes
+    an identification of each SSA value of the output extents with an
+    SSA value of the input extents.
+    This op has the useful property that it is closed under composition with
+    itself, thus allowing an arbitrarily complex subgraph consisting of just
+    this op to be folded together.
+
+    Some examples of shape transfer functions captured with this op:
+
+    - Taking the last two extents of a shape:
+      - [d0,d1,d2,d3] indices=[2,3] -> [d2,d3]
+    - Concatenating three shapes:
+      - [d0,d1] [d2,d3] [d4,d5] indices=[0,1,2,3,4,5] -> [d0,d1,d2,d3,d4,d5]
+    - Shape transfer function for transpose with permutation [0,2,1]:
+      - [d0,d1,d2] indices=[0,2,1] -> [d0,d2,d1]
+    - Shape transfer function for outer product of a vector with itself:
+      - Initial state: [d0] [d1] indices=[0,1] -> [d0,d1]
+      - Canonicalized to a single-operand op after observing that both inputs
+        are the same !shapex.ranked_shape value: [d0] indices=[0,0] -> [d0,d0]
+    - Shape transfer function for matmul with a batch dimension on the LHS:
+      - [d0,d1,d2] [d3,d4] indices=[0,1,4] -> [d0,d1,d4]
+
+    This op is somewhat inspired by the LLVM `shufflevector` instruction.
+
+    Possible future pretty syntax for single-arg case:
+    %rs = shapex.gather_extents %0[0,2,1] : !shapex.ranked_shape<[5,6,7]>
+    Consider a pretty syntax for "concat":
+    %rs = shapex.gather_extents concat(%0, %1) : !shapex.ranked_shape<[5,6,7]>, !shapex.ranked_shape<[8,9]>
+
+  }];
+  let arguments = (ins
+    Variadic<Shape_RankedShape>:$shapes,
+    I64ElementsAttr:$indices
+  );
+  let results = (outs
+    Shape_RankedShape:$result
+  );
+
+  let verifier = [{ return verify$cppClass(*this); }];
+
+  // TODO: Custom parser/printer
+  let parser = ?;
+  let printer = ?;
+
+  let extraClassDeclaration = [{
+    static SmallVector<int64_t, 6> getConcatenatedExtents(ValueRange values);
+  }];
+}
+
+
 #endif  // IREE_DIALECT_SHAPE_OPS
diff --git a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
index e0b0c1b..7f5395b 100644
--- a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
+++ b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
@@ -41,3 +41,10 @@
   %0 = shapex.ranked_dim %arg0[2] : !shapex.ranked_shape<[2,4]> -> index
   return
 }
+
+// -----
+
+func @compatible_from_extent_tensor(%arg0: tensor<1xindex>) {
+  %0 = "shapex.from_extent_tensor"(%arg0) : (tensor<1xindex>) -> !shapex.ranked_shape<[3]>
+  return
+}
diff --git a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
index 11c9039..1531b4b 100644
--- a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
@@ -143,6 +143,22 @@
   }
 };
 
+class ConvertDynamicBroadcastInDim
+    : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    xla_hlo::DynamicBroadcastInDimOpOperandAdaptor adapter(operands);
+    Value rankedShape = rewriter.create<Shape::FromExtentTensorOp>(
+        op.getLoc(), adapter.output_dimensions());
+    rewriter.replaceOpWithNewOp<Shape::RankedBroadcastInDimOp>(
+        op, op.getType(), adapter.operand(), rankedShape,
+        op.broadcast_dimensions());
+    return success();
+  }
+};
+
 class ConvertHLOToShapePass
     : public PassWrapper<ConvertHLOToShapePass, FunctionPass> {
   void runOnFunction() override {
@@ -153,6 +169,9 @@
     conversionTarget.addLegalDialect<StandardOpsDialect>();
     conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
 
+    conversionTarget.addIllegalOp<xla_hlo::DynamicBroadcastInDimOp>();
+    conversionPatterns.insert<ConvertDynamicBroadcastInDim>(&getContext());
+
 #define CONVERT_BINARY_ELEMENTWISE_OP(HloOpTy)                             \
   conversionTarget.addDynamicallyLegalOp<HloOpTy>(                         \
       [](HloOpTy op) { return IsSameRankedTypeBinaryElementwiseOp(op); }); \
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
index f73c3b9..894f746 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
@@ -103,6 +103,44 @@
                                            filteredResultExtents);
 }
 
+LogicalResult expandGatherExtentsOp(GatherExtentsOp op,
+                                    GatherExtentsOp::OperandAdaptor operands,
+                                    PatternRewriter &rewriter) {
+  // Calculate cumulative sums of the ranks of each operand, which allows
+  // us to map each index to its corresponding operand easily.
+  SmallVector<int64_t, 6> cumsum;
+  cumsum.push_back(0);
+  for (auto operand : operands.shapes()) {
+    auto rank = operand.getType().cast<Shape::RankedShapeType>().getRank();
+    cumsum.push_back(cumsum.back() + rank);
+  }
+
+  // For each index, extract the relevant extent from the operands.
+  SmallVector<Value, 6> extents;
+  for (auto index : op.indices().getValues<int64_t>()) {
+    auto it = llvm::upper_bound(cumsum, index) - 1;
+    auto operandNum = std::distance(cumsum.begin(), it);
+    auto dimNum = index - *it;
+    auto extent = rewriter.create<Shape::RankedDimOp>(
+        op.getLoc(), operands.shapes()[operandNum], dimNum);
+    extents.push_back(extent);
+  }
+
+  // Due to a quirk of MakeRankedShapeOp, we only want the dynamic
+  // dimensions.
+  SmallVector<Value, 6> onlyDynamicExtents;
+  auto resultType = op.result().getType().cast<Shape::RankedShapeType>();
+  for (int i = 0, e = resultType.getRank(); i < e; i++) {
+    if (resultType.isDimDynamic(i)) {
+      onlyDynamicExtents.push_back(extents[i]);
+    }
+  }
+
+  rewriter.replaceOpWithNewOp<Shape::MakeRankedShapeOp>(op, resultType,
+                                                        onlyDynamicExtents);
+  return success();
+}
+
 LogicalResult expandRankedBroadcastShapePattern(
     RankedBroadcastShapeOp bcastOp,
     RankedBroadcastShapeOp::OperandAdaptor operands,
@@ -234,6 +272,7 @@
   // We explicitly want to convert these ops, eliminating them.
   target.addIllegalOp<GetRankedShapeOp>();
   target.addIllegalOp<RankedBroadcastShapeOp>();
+  target.addIllegalOp<GatherExtentsOp>();
 }
 
 void populateMaterializeShapeCalculationsConversionPatterns(
@@ -241,6 +280,8 @@
   // Fallback patterns.
   insertConversionPattern(patterns, context, expandRankedBroadcastShapePattern,
                           /*benefit=*/1);
+  insertConversionPattern(patterns, context, expandGatherExtentsOp,
+                          /*benefit=*/1);
   insertConversionPattern(patterns, context, materializeRankedShapePattern,
                           /*benefit=*/1);
 
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
index 96bdd6b..6c59951 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
@@ -67,6 +67,7 @@
     GetRankedShapeOp::getCanonicalizationPatterns(patterns, context);
     MakeRankedShapeOp::getCanonicalizationPatterns(patterns, context);
     RankedDimOp::getCanonicalizationPatterns(patterns, context);
+    RankedDimsOp::getCanonicalizationPatterns(patterns, context);
     TieShapeOp::getCanonicalizationPatterns(patterns, context);
     applyPatternsAndFoldGreedily(getOperation(), patterns);
   }
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/convert-hlo-to-shape-dialect.mlir b/iree/compiler/Dialect/Shape/Transforms/test/convert-hlo-to-shape-dialect.mlir
index 9819e1f..369325c 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/convert-hlo-to-shape-dialect.mlir
+++ b/iree/compiler/Dialect/Shape/Transforms/test/convert-hlo-to-shape-dialect.mlir
@@ -14,3 +14,14 @@
   // CHECK-DAG: return %[[SUM]]
   return %0 : tensor<?x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[SHAPE:.+]] = "shapex.from_extent_tensor"(%arg1) : (tensor<2xindex>) -> !shapex.ranked_shape<[?,?]>
+  // CHECK-DAG: %[[BROADCASTED:.+]] = "shapex.ranked_broadcast_in_dim"(%arg0, %0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: return %[[BROADCASTED]]
+  %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[1]> : tensor<1xi64>}: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir b/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
index 3fbc213..8a1d1aa 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
+++ b/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
@@ -50,3 +50,38 @@
   %2 = shapex.get_ranked_shape %1 : tensor<?x2xf32> -> !shapex.ranked_shape<[?,2]>
   return %1, %2 : tensor<?x2xf32>, !shapex.ranked_shape<[?,2]>
 }
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: index, %arg1: index) -> (index, index, index) {
+  %0 = shapex.make_ranked_shape %arg0, %arg1 : (index, index) -> !shapex.ranked_shape<[?,?]>
+  %1 = "shapex.gather_extents"(%0) {indices = dense<[1, 1, 0]> : tensor<3xi64>} : (!shapex.ranked_shape<[?,?]>) -> !shapex.ranked_shape<[?,?,?]>
+  %2:3 = shapex.ranked_dims %1 : !shapex.ranked_shape<[?,?,?]> -> index, index, index
+  // CHECK: return %arg1, %arg1, %arg0
+  return %2#0, %2#1, %2#2 : index, index, index
+}
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = shapex.make_ranked_shape %arg0 : (index) -> !shapex.ranked_shape<[?]>
+  %1 = shapex.make_ranked_shape %arg1 : (index) -> !shapex.ranked_shape<[?]>
+  %2 = shapex.make_ranked_shape %arg2 : (index) -> !shapex.ranked_shape<[?]>
+  %gathered = "shapex.gather_extents"(%0, %1, %2) {indices = dense<[2, 2, 1, 0]> : tensor<4xi64>} : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?,?,?,?]>
+  %extents:4 = shapex.ranked_dims %gathered : !shapex.ranked_shape<[?,?,?,?]> -> index, index, index, index
+  // CHECK: return %arg2, %arg2, %arg1, %arg0
+  return %extents#0, %extents#1, %extents#2, %extents#3 : index, index, index, index
+}
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index, index, index) {
+  %0 = shapex.make_ranked_shape %arg0 : (index) -> !shapex.ranked_shape<[?,2]>
+  %1 = shapex.make_ranked_shape %arg1 : (index) -> !shapex.ranked_shape<[3,?]>
+  %2 = shapex.make_ranked_shape %arg2 : (index) -> !shapex.ranked_shape<[7,?]>
+  %gathered = "shapex.gather_extents"(%0, %1, %2) {indices = dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi64>} : (!shapex.ranked_shape<[?,2]>, !shapex.ranked_shape<[3,?]>, !shapex.ranked_shape<[7,?]>) -> !shapex.ranked_shape<[?,2,3,?,7,?]>
+  %extents:6 = shapex.ranked_dims %gathered : !shapex.ranked_shape<[?,2,3,?,7,?]> -> index, index, index, index, index, index
+  // CHECK: return %arg0, %c2{{.*}}, %c3{{.*}}, %arg1, %c7{{.*}}, %arg2
+  return %extents#0, %extents#1, %extents#2, %extents#3, %extents#4, %extents#5 : index, index, index, index, index, index
+}
+
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index c780e14..08a5c4e 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -100,6 +100,7 @@
         "@llvm-project//mlir:SDBM",
         "@llvm-project//mlir:SPIRVDialect",
         "@llvm-project//mlir:SPIRVLowering",
+        "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:StandardToSPIRVConversions",
         "@llvm-project//mlir:Transforms",
@@ -152,6 +153,7 @@
         "//iree/compiler/Dialect/HAL/Transforms",
         "//iree/compiler/Dialect/IREE/IR",
         "//iree/compiler/Dialect/IREE/Transforms",
+        "//iree/compiler/Dialect/Shape/Conversion",
         "//iree/compiler/Dialect/Shape/IR",
         "//iree/compiler/Dialect/Shape/Transforms",
         "//iree/compiler/Dialect/VM/Analysis",
diff --git a/iree/tools/init_dialects.h b/iree/tools/init_dialects.h
index a91128d..3ce37dc 100644
--- a/iree/tools/init_dialects.h
+++ b/iree/tools/init_dialects.h
@@ -35,6 +35,7 @@
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SDBM/SDBMDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Dialect.h"
@@ -51,6 +52,7 @@
     registerDialect<linalg::LinalgDialect>();
     registerDialect<loop::LoopOpsDialect>();
     registerDialect<quant::QuantizationDialect>();
+    registerDialect<shape::ShapeDialect>();
     registerDialect<spirv::SPIRVDialect>();
     registerDialect<StandardOpsDialect>();
     registerDialect<vector::VectorDialect>();