[vmla] Dynamic shape lowering for dot_general

- Cleanly separate VMLA/Transforms/PreConversionLowering.cpp which does preparatory tensor->tensor rewrites.

- restructure conversion so xla_hlo.broadcast_in_dim and xla_hlo.dynamic_broadcast_in_dim are lowered via shapex.ranked_broadcast_in_dim

- add a shape transfer function plugin for VMLA (only vmla.batch.matmul.pseudo for now); also fix up the docs on vmla.batch.matmul.pseudo.

- add a shape transfer function for xla_hlo.dynamic_reshape

- augment matrix_ops_test.py to exercise the dynamic-shape cases end-to-end from Python. This completes the end-to-end dynamic matmul path (though some significant issues remain, e.g. error handling).

Also, fix up CMake build for convert-shape-to-shapex pass.

PiperOrigin-RevId: 309129583
diff --git a/integrations/tensorflow/e2e/matrix_ops_test.py b/integrations/tensorflow/e2e/matrix_ops_test.py
index ea38b2a..10293da 100644
--- a/integrations/tensorflow/e2e/matrix_ops_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_test.py
@@ -49,12 +49,26 @@
     return tf.matmul(lhs, rhs)
 
   @tf.function(input_signature=[
-      tf.TensorSpec([1, 7, 4, 2], tf.float32),
-      tf.TensorSpec([7, 1, 2, 4], tf.float32),
+      tf.TensorSpec([None, None, 4, 2], tf.float32),
+      tf.TensorSpec([None, None, 2, 4], tf.float32),
   ])
   def matmul_high_rank_batch(self, lhs, rhs):
     return tf.matmul(lhs, rhs)
 
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, None, None], tf.float32),
+      tf.TensorSpec([None, None, None], tf.float32),
+  ])
+  def matmul_dynamic(self, lhs, rhs):
+    return tf.matmul(lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, None, None], tf.float32),
+      tf.TensorSpec([None, None], tf.float32),
+  ])
+  def matmul_dynamic_lhs_batch(self, lhs, rhs):
+    return tf.matmul(lhs, rhs)
+
 
 @tf_test_utils.compile_modules(
     backends=["tf", "iree_vmla"], mat=MatrixOpsModule)
@@ -63,31 +77,55 @@
   def test_basic_matmul(self):
     m = self.modules.mat.all
     dst = m.basic_matmul(tf.random.uniform([4, 2]), tf.random.uniform([2, 4]))
-    dst.print().assert_all_close()
+    dst.assert_all_close()
 
   def test_matmul_lhs_batch(self):
     m = self.modules.mat.all
     dst = m.matmul_lhs_batch(
         tf.random.uniform([3, 4, 2]), tf.random.uniform([2, 4]))
-    dst.print().assert_all_close()
+    dst.assert_all_close()
 
   def test_matmul_rhs_batch(self):
     m = self.modules.mat.all
     dst = m.matmul_rhs_batch(
         tf.random.uniform([4, 2]), tf.random.uniform([3, 2, 4]))
-    dst.print().assert_all_close()
+    dst.assert_all_close()
 
   def test_matmul_broadcast_singleton_dimension(self):
     m = self.modules.mat.all
     dst = m.matmul_broadcast_singleton_dimension(
         tf.random.uniform([1, 4, 2]), tf.random.uniform([3, 2, 4]))
-    dst.print().assert_all_close()
+    dst.assert_all_close()
 
   def test_matmul_high_rank_batch(self):
     m = self.modules.mat.all
     dst = m.matmul_high_rank_batch(
         tf.random.uniform([1, 7, 4, 2]), tf.random.uniform([7, 1, 2, 4]))
-    dst.print().assert_all_close()
+    dst.assert_all_close()
+
+  def test_matmul_dynamic_matching_batch(self):
+    m = self.modules.mat.all
+    dst = m.matmul_dynamic(
+        tf.random.uniform([2, 2, 3]), tf.random.uniform([2, 3, 4]))
+    dst.assert_all_close()
+
+  def test_matmul_dynamic_broadcast_lhs(self):
+    m = self.modules.mat.all
+    dst = m.matmul_dynamic(
+        tf.random.uniform([1, 2, 3]), tf.random.uniform([2, 3, 4]))
+    dst.assert_all_close()
+
+  def test_matmul_dynamic_broadcast_rhs(self):
+    m = self.modules.mat.all
+    dst = m.matmul_dynamic(
+        tf.random.uniform([2, 2, 3]), tf.random.uniform([1, 3, 4]))
+    dst.assert_all_close()
+
+  def test_matmul_dynamic_rank_broadcasting(self):
+    m = self.modules.mat.all
+    dst = m.matmul_dynamic_lhs_batch(
+        tf.random.uniform([7, 2, 3]), tf.random.uniform([3, 4]))
+    dst.assert_all_close()
 
 
 if __name__ == "__main__":
diff --git a/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt b/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
new file mode 100644
index 0000000..85d92cd
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
@@ -0,0 +1,42 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    ConvertShapeToShapex
+  SRCS
+    "ConvertShapeToShapex.cpp"
+  DEPS
+    MLIRDialect
+    MLIRIR
+    MLIRPass
+    MLIRShape
+    MLIRTransforms
+    iree::compiler::Dialect::Shape::IR
+  ALWAYSLINK
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    Conversion
+  HDRS
+    "Passes.h"
+  DEPS
+    ::ConvertShapeToShapex
+    MLIRPass
+  PUBLIC
+)
diff --git a/iree/compiler/Dialect/Shape/Conversion/Passes.h b/iree/compiler/Dialect/Shape/Conversion/Passes.h
index 750bb86..6104d40 100644
--- a/iree/compiler/Dialect/Shape/Conversion/Passes.h
+++ b/iree/compiler/Dialect/Shape/Conversion/Passes.h
@@ -23,6 +23,14 @@
 // Convert `shape` dialect to `shapex` dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToShapexPass();
 
+namespace Shape {
+
+inline void registerShapeConversionPasses() {
+  createConvertShapeToShapexPass();
+}
+
+}  // namespace Shape
+
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "${_GLOB_X_MLIR}"
+  DATA
+    iree::tools::IreeFileCheck
+    iree::tools::iree-opt
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD b/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD
new file mode 100644
index 0000000..559e5c0
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD
@@ -0,0 +1,34 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "VMLAShapeBuilder",
+    srcs = [
+        "VMLAShapeBuilder.cpp",
+    ],
+    hdrs = [
+        "VMLAShapeBuilder.h",
+    ],
+    deps = [
+        "//iree/compiler/Dialect/Shape/IR",
+        "//iree/compiler/Dialect/VMLA/IR",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+    ],
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt
new file mode 100644
index 0000000..33e2494
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt
@@ -0,0 +1,30 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    VMLAShapeBuilder
+  HDRS
+    "VMLAShapeBuilder.h"
+  SRCS
+    "VMLAShapeBuilder.cpp"
+  DEPS
+    LLVMSupport
+    MLIRIR
+    iree::compiler::Dialect::Shape::IR
+    iree::compiler::Dialect::VMLA::IR
+  PUBLIC
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp
new file mode 100644
index 0000000..ad71a21
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp
@@ -0,0 +1,68 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h"
+
+#include "iree/compiler/Dialect/Shape/IR/Builders.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Optional.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir::iree_compiler::Shape;
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VMLA {
+namespace {
+
+Value rewriteBatchMatMulPseudoOp(RankedShapeType resultShape,
+                                 BatchMatMulPseudoOp op, OpBuilder &builder) {
+  auto lhsShape = builder.create<GetRankedShapeOp>(op.getLoc(), op.lhs());
+  auto rhsShape = builder.create<GetRankedShapeOp>(op.getLoc(), op.rhs());
+  SmallVector<Value, 6> extents;
+  // Batch dimension (already been established to match between both operands,
+  // so arbitrarily use the LHS).
+  extents.push_back(builder.create<RankedDimOp>(op.getLoc(), lhsShape, 0));
+  // RHS free dimension.
+  extents.push_back(builder.create<RankedDimOp>(op.getLoc(), rhsShape, 1));
+  // LHS free dimension.
+  extents.push_back(builder.create<RankedDimOp>(op.getLoc(), lhsShape, 1));
+  // Due to a quirk of MakeRankedShapeOp, we only pass in the dynamic dims.
+  // So prune them down here.
+  SmallVector<Value, 6> onlyDynamicExtents;
+  for (int i = 0; i < 3; i++) {
+    if (resultShape.isDimDynamic(i)) {
+      onlyDynamicExtents.push_back(extents[i]);
+    }
+  }
+  return builder.create<MakeRankedShapeOp>(op.getLoc(), resultShape,
+                                           onlyDynamicExtents);
+}
+
+}  // namespace
+
+void populateVMLACustomOpShapeBuilder(CustomOpShapeBuilderList &builders) {
+  auto &b = builders.make<CallbackCustomOpShapeBuilder>();
+  b.insertOpRankedShapeBuilder<BatchMatMulPseudoOp>(rewriteBatchMatMulPseudoOp);
+}
+
+}  // namespace VMLA
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h
new file mode 100644
index 0000000..a80de01
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h
@@ -0,0 +1,34 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
+#define IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VMLA {
+// Creates a custom op shape builder for VMLA ops that are not otherwise
+// supported through traits or other declarative means.
+void populateVMLACustomOpShapeBuilder(
+    iree_compiler::Shape::CustomOpShapeBuilderList &builders);
+
+}  // namespace VMLA
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD
new file mode 100644
index 0000000..14281d1
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = glob(["*.mlir"]),
+    data = [
+        "//iree/tools:IreeFileCheck",
+        "//iree/tools:iree-opt",
+    ],
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "${_GLOB_X_MLIR}"
+  DATA
+    iree::tools::IreeFileCheck
+    iree::tools::iree-opt
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir
new file mode 100644
index 0000000..70ba1cb
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir
@@ -0,0 +1,19 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-shape-materialize-calculations %s | IreeFileCheck %s
+
+// -----
+// CHECK-LABEL: func @batch.matmul.pseudo
+func @batch.matmul.pseudo(
+  %lhs: tensor<?x?x?xf32>, %rhs: tensor<?x?x?xf32>,
+  %lhsShape: !shapex.ranked_shape<[?,?,?]>, %rhsShape: !shapex.ranked_shape<[?,?,?]>
+) -> !shapex.ranked_shape<[?,?,?]> {
+  %lhsTied = shapex.tie_shape %lhs, %lhsShape : tensor<?x?x?xf32>, !shapex.ranked_shape<[?,?,?]>
+  %rhsTied = shapex.tie_shape %rhs, %rhsShape : tensor<?x?x?xf32>, !shapex.ranked_shape<[?,?,?]>
+  %0 = "vmla.batch.matmul.pseudo"(%lhsTied, %rhsTied) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK-DAG: %[[BATCH:.+]] = shapex.ranked_dim %arg2[0]
+  // CHECK-DAG: %[[FLHS:.+]] = shapex.ranked_dim %arg2[1]
+  // CHECK-DAG: %[[FRHS:.+]] = shapex.ranked_dim %arg3[1]
+  // CHECK-DAG: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[BATCH]], %[[FRHS]], %[[FLHS]]
+  // CHECK-DAG: return %[[SHAPE]]
+  %1 = shapex.get_ranked_shape %0 : tensor<?x?x?xf32> -> !shapex.ranked_shape<[?,?,?]>
+  return %1 : !shapex.ranked_shape<[?,?,?]>
+}
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
index f20127f..142cccc 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
@@ -277,6 +277,12 @@
   return builder.create<MakeRankedShapeOp>(loc, resultShape, outputExtents);
 }
 
+Value rewriteDynamicReshape(RankedShapeType resultShape, DynamicReshapeOp op,
+                            OpBuilder &builder) {
+  return builder.create<FromExtentTensorOp>(op.getLoc(), resultShape,
+                                            op.output_shape());
+}
+
 }  // namespace
 
 // Creates a custom op shape builder for XLA-HLO ops that are not otherwise
@@ -306,6 +312,8 @@
   b.insertOpRankedShapeBuilder<ReduceOp>(rewriteReduce);
   b.insertOpRankedShapeBuilder<TransposeOp>(rewriteTranspose);
   b.insertOpRankedShapeBuilder<xla_hlo::DotGeneralOp>(rewriteDotGeneral);
+  b.insertOpRankedShapeBuilder<xla_hlo::DynamicReshapeOp>(
+      rewriteDynamicReshape);
 }
 
 }  // namespace xla_hlo
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir b/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
index 2200907..59882e5 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
@@ -35,3 +35,14 @@
   %1 = shapex.get_ranked_shape %0 : tensor<?x?x?xf32> -> !shapex.ranked_shape<[?,?,?]>
   return %1 : !shapex.ranked_shape<[?,?,?]>
 }
+
+// -----
+
+// CHECK-LABEL: func @dynamic_reshape
+func @dynamic_reshape(%arg0: tensor<?xf32>, %arg1: tensor<2xindex>) -> !shapex.ranked_shape<[?,?]> {
+  // CHECK-DAG: %[[SHAPE:.+]] = "shapex.from_extent_tensor"(%arg1)
+  // CHECK-DAG: return %[[SHAPE]]
+  %0 = "xla_hlo.dynamic_reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  %1 = shapex.get_ranked_shape %0 : tensor<?x?xf32> -> !shapex.ranked_shape<[?,?]>
+  return %1 : !shapex.ranked_shape<[?,?]>
+}
diff --git a/iree/compiler/Dialect/Shape/Transforms/BUILD b/iree/compiler/Dialect/Shape/Transforms/BUILD
index e389397..7de0280 100644
--- a/iree/compiler/Dialect/Shape/Transforms/BUILD
+++ b/iree/compiler/Dialect/Shape/Transforms/BUILD
@@ -34,6 +34,7 @@
     ],
     deps = [
         "//iree/compiler/Dialect/Shape/IR",
+        "//iree/compiler/Dialect/Shape/Plugins/VMLA:VMLAShapeBuilder",
         "//iree/compiler/Dialect/Shape/Plugins/XLA:XlaHloShapeBuilder",
         "//iree/compiler/Dialect/Shape/Utils:TypeConversion",
         "//iree/compiler/Utils",
diff --git a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
index d4aa7c8..0a6ee0a 100644
--- a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
@@ -37,6 +37,7 @@
     MLIRSupport
     MLIRTransforms
     iree::compiler::Dialect::Shape::IR
+    iree::compiler::Dialect::Shape::Plugins::VMLA::VMLAShapeBuilder
     iree::compiler::Dialect::Shape::Plugins::XLA::XlaHloShapeBuilder
     iree::compiler::Dialect::Shape::Utils::TypeConversion
     iree::compiler::Utils
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
index 894f746..739e899 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
@@ -17,6 +17,7 @@
 #include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h"
 #include "iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Patterns.h"
 #include "iree/compiler/Utils/PatternUtils.h"
@@ -47,6 +48,7 @@
   static CustomOpShapeBuilderList globalBuilders = ([]() {
     CustomOpShapeBuilderList builders;
     xla_hlo::populateXlaHloCustomOpShapeBuilder(builders);
+    IREE::VMLA::populateVMLACustomOpShapeBuilder(builders);
     return builders;
   })();
   return &globalBuilders;
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
index 8d7a8a6..14fab6d 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
@@ -21,7 +21,6 @@
     name = "HLOToVMLA",
     srcs = [
         "ConvertConvOps.cpp",
-        "ConvertDotOps.cpp",
         "ConvertHLOToVMLA.cpp",
         "ConvertReductionOps.cpp",
     ],
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
index 0da3cca..b47ddce 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
@@ -21,7 +21,6 @@
     "ConvertHLOToVMLA.h"
   SRCS
     "ConvertConvOps.cpp"
-    "ConvertDotOps.cpp"
     "ConvertHLOToVMLA.cpp"
     "ConvertReductionOps.cpp"
   DEPS
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertDotOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertDotOps.cpp
deleted file mode 100644
index cfbc07f..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertDotOps.cpp
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/BitVector.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Convert instances of `xla_hlo.dot` to `xla_hlo.dot_general`.
-//
-// TODO(silvasean): This logically is part of a future HLO client -> HLO server
-// type of pass in the xla_hlo dialect proper.
-struct CanonicalizeDotOp : public OpRewritePattern<xla_hlo::DotOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(xla_hlo::DotOp op,
-                                PatternRewriter &rewriter) const override {
-    Value lhs = op.lhs();
-    Value rhs = op.rhs();
-    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
-    RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
-    if (!lhsType || !rhsType) {
-      return failure();
-    }
-    if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
-      return failure();
-    }
-    // TODO(silvasean): Move this helper to MLIR core.
-    auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
-      auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
-                                        rewriter.getIntegerType(64));
-      return DenseIntElementsAttr::get(type, integers);
-    };
-    auto dimensionNumbers = xla_hlo::DotDimensionNumbers::get(
-        /*lhs_batching_dimensions=*/make1DElementsAttr({}),
-        /*rhs_batching_dimensions=*/make1DElementsAttr({}),
-        /*lhs_contracting_dimensions=*/make1DElementsAttr({1}),
-        /*rhs_contracting_dimensions=*/make1DElementsAttr({0}),
-        rewriter.getContext());
-    rewriter.replaceOpWithNewOp<xla_hlo::DotGeneralOp>(
-        op, op.getType(), lhs, rhs, dimensionNumbers,
-        op.precision_config().hasValue() ? op.precision_config().getValue()
-                                         : nullptr);
-    return success();
-  }
-};
-
-// Inserts transposes on the operands of DotGeneralOp's such that the resulting
-// batch dimensions are all the leading dimensions and all the contracting
-// dimensions are all the trailing dimensions.
-//
-// Furthermore, all batch, contracting, and free dimensions are flattened into
-// single dimensions, with an appropriate reshape back to the original
-// dimensions.
-//
-// This results in a very simple corresponding VMLA op in the runtime.
-// [1 batch dimension, 1 free dimension, 1 contracting dimension].
-//
-// The result doesn't have a DotGeneralOp, but rather a
-// VMLA::BatchMatMulPseudoOp which represents this transformation.
-//
-// TODO(silvasean): Move this to a "prepare" pass and test separately.
-struct CanonicalizeDotGeneralOp
-    : public OpRewritePattern<xla_hlo::DotGeneralOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(xla_hlo::DotGeneralOp op,
-                                PatternRewriter &rewriter) const override {
-    Value lhs = op.lhs();
-    Value rhs = op.rhs();
-    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
-    RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
-    Type elementType = lhsType.getElementType();
-    if (!lhsType || !rhsType) {
-      return failure();
-    }
-    // TODO(silvasean): Extend to support dynamic shapes.
-    // This op is a really good case for testing our e2e dynamic shape support.
-    // There's interesting questions at the TF2XLA level too.
-    if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) {
-      emitWarning(op.getLoc())
-          << "unimplemented: vmla dynamic shapes for xla_hlo.dot_general";
-      return failure();
-    }
-    xla_hlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
-    auto extract1DVector = [](DenseIntElementsAttr elements) {
-      SmallVector<int64_t, 6> ret;
-      for (const APInt &element : elements) {
-        ret.push_back(element.getLimitedValue());
-      }
-      return ret;
-    };
-    auto lhsBatchingDims =
-        extract1DVector(dimNumbers.lhs_batching_dimensions());
-    auto rhsBatchingDims =
-        extract1DVector(dimNumbers.rhs_batching_dimensions());
-    auto lhsContractingDims =
-        extract1DVector(dimNumbers.lhs_contracting_dimensions());
-    auto rhsContractingDims =
-        extract1DVector(dimNumbers.rhs_contracting_dimensions());
-    // TODO(silvasean): Move this helper to MLIR core.
-    auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
-      auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
-                                        rewriter.getIntegerType(64));
-      return DenseIntElementsAttr::get(type, integers);
-    };
-    auto totalElements = [](ArrayRef<int64_t> extents) {
-      int64_t numElements = 1;
-      for (int64_t extent : extents) {
-        numElements *= extent;
-      }
-      return numElements;
-    };
-    auto handleOneSide = [&](ArrayRef<int64_t> batchingDims,
-                             ArrayRef<int64_t> contractingDims, Value &value,
-                             RankedTensorType &type,
-                             SmallVectorImpl<int64_t> &outFreeDims,
-                             SmallVectorImpl<int64_t> &outFreeDimExtents,
-                             SmallVectorImpl<int64_t> &outBatchingDimExtents) {
-      outBatchingDimExtents.clear();
-      RankedTensorType untransposedType = type;
-      SmallVector<int64_t, 6> permutation;
-      llvm::BitVector freeDims(untransposedType.getRank(), true);
-      SmallVector<int64_t, 6> contractingDimExtents;
-      for (auto dims : {batchingDims, contractingDims}) {
-        for (int64_t dim : dims) {
-          freeDims.reset(dim);
-        }
-      }
-      for (int64_t dim : batchingDims) {
-        permutation.push_back(dim);
-        outBatchingDimExtents.push_back(untransposedType.getDimSize(dim));
-      }
-      for (int64_t dim : freeDims.set_bits()) {
-        permutation.push_back(dim);
-        outFreeDims.push_back(dim);
-        outFreeDimExtents.push_back(untransposedType.getDimSize(dim));
-      }
-      for (int64_t dim : contractingDims) {
-        permutation.push_back(dim);
-        contractingDimExtents.push_back(untransposedType.getDimSize(dim));
-      }
-      // Construct the type that the transpose will result in.
-      SmallVector<int64_t, 6> transposeShape;
-      for (int64_t index : permutation) {
-        transposeShape.push_back(type.getDimSize(index));
-      }
-      auto transposeType = RankedTensorType::get(transposeShape, elementType);
-      auto transpose = rewriter.create<xla_hlo::TransposeOp>(
-          op.getLoc(), transposeType, value, make1DElementsAttr(permutation));
-
-      auto reshapeType =
-          RankedTensorType::get({totalElements(outBatchingDimExtents),
-                                 totalElements(outFreeDimExtents),
-                                 totalElements(contractingDimExtents)},
-                                elementType);
-      value = rewriter.create<xla_hlo::ReshapeOp>(op.getLoc(), reshapeType,
-                                                  transpose);
-    };
-    SmallVector<int64_t, 6> batchingDimExtents;
-    SmallVector<int64_t, 6> lhsFreeDims;
-    SmallVector<int64_t, 6> lhsFreeDimExtents;
-    handleOneSide(lhsBatchingDims, lhsContractingDims, lhs, lhsType,
-                  lhsFreeDims, lhsFreeDimExtents, batchingDimExtents);
-    SmallVector<int64_t, 6> rhsFreeDims;
-    SmallVector<int64_t, 6> rhsFreeDimExtents;
-    handleOneSide(rhsBatchingDims, rhsContractingDims, rhs, rhsType,
-                  rhsFreeDims, rhsFreeDimExtents, batchingDimExtents);
-
-    auto dstShape = llvm::to_vector<6>(llvm::makeArrayRef(
-        {totalElements(batchingDimExtents), totalElements(rhsFreeDimExtents),
-         totalElements(lhsFreeDimExtents)}));
-    auto dstType = RankedTensorType::get(dstShape, elementType);
-    Value dst = rewriter.create<IREE::VMLA::BatchMatMulPseudoOp>(
-        op.getLoc(), dstType, lhs, rhs);
-    RankedTensorType transposeType = RankedTensorType::get(
-        {dstShape[0], dstShape[2], dstShape[1]}, elementType);
-    auto transpose = rewriter.create<xla_hlo::TransposeOp>(
-        op.getLoc(), transposeType, dst, make1DElementsAttr({0, 2, 1}));
-    auto reshapeShape = batchingDimExtents;
-    reshapeShape.append(lhsFreeDimExtents.begin(), lhsFreeDimExtents.end());
-    reshapeShape.append(rhsFreeDimExtents.begin(), rhsFreeDimExtents.end());
-    auto reshapeType = RankedTensorType::get(reshapeShape, elementType);
-    rewriter.replaceOpWithNewOp<xla_hlo::ReshapeOp>(op, reshapeType, transpose);
-    return success();
-  }
-};
-
-}  // namespace
-
-void populateHLODotToVMLAPatterns(MLIRContext *context,
-                                  OwningRewritePatternList &patterns,
-                                  TypeConverter &typeConverter) {
-  // Tensor-level preparation for lowering to the runtime BatchMatMul op.
-  patterns.insert<CanonicalizeDotGeneralOp>(context);
-  patterns.insert<CanonicalizeDotOp>(context);
-
-  // Lowering from tensor ops to VMLA runtime ops.
-  patterns.insert<VMLAOpConversion<IREE::VMLA::BatchMatMulPseudoOp,
-                                   IREE::VMLA::BatchMatMulOp>>(context,
-                                                               typeConverter);
-}
-
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 1d2ba93..4aafac6 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -60,7 +60,10 @@
   LogicalResult matchAndRewrite(
       SRC srcOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    if (srcOp.getOperand().hasOneUse()) {
+    // xla_hlo::DynamicReshape has multiple operands, so we cannot just say
+    // `getOperand()`. But `getOperand(0)` doesn't work for the other
+    // single-operand ops. So use the raw Operation to get the operand.
+    if (srcOp.getOperation()->getOperand(0).hasOneUse()) {
       // Can directly pass through the input buffer as we don't need to clone
       // for other users.
       rewriter.replaceOp(srcOp, operands[0]);
@@ -76,15 +79,22 @@
   }
 };
 
-// Converts a broadcast_in_dim op to either a broadcast or a tile depending on
-// the input shape.
+// Converts a shapex.ranked_broadcast_in_dim op to either a broadcast or a tile
+// depending on the input shape.
+//
+// We assume that xla_hlo.broadcast_in_dim and xla_hlo.dynamic_broadcast_in_dim
+// have been legalized into that op.
+//
+// Note that shapex.ranked_broadcast_in_dim is not strictly speaking an HLO op,
+// but we would like HLO to eventually have something like it, and the shapex
+// dialect is currently where we have it stuffed.
 struct BroadcastInDimOpConversion
-    : public OpConversionPattern<xla_hlo::BroadcastInDimOp> {
+    : public OpConversionPattern<Shape::RankedBroadcastInDimOp> {
   BroadcastInDimOpConversion(MLIRContext *context, TypeConverter &typeConverter)
       : OpConversionPattern(context), typeConverter(typeConverter) {}
 
   LogicalResult matchAndRewrite(
-      xla_hlo::BroadcastInDimOp srcOp, ArrayRef<Value> operands,
+      Shape::RankedBroadcastInDimOp srcOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     auto srcShape = VMLAConversionTarget::getTensorShape(
         srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
@@ -454,8 +464,10 @@
   // xla_hlo.reduce and xla_hlo.reduce_window.
   populateHLOReductionToVMLAPatterns(context, patterns, typeConverter);
 
-  // xla_hlo.dot and xla_hlo.dot_general.
-  populateHLODotToVMLAPatterns(context, patterns, typeConverter);
+  // vmla.batch.matmul.pseudo
+  patterns.insert<VMLAOpConversion<IREE::VMLA::BatchMatMulPseudoOp,
+                                   IREE::VMLA::BatchMatMulOp>>(context,
+                                                               typeConverter);
 
   // Simple 1:1 conversion patterns using the automated trait-based converter.
   // Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
@@ -534,6 +546,7 @@
   patterns.insert<IdentityOpConversion<xla_hlo::BitcastConvertOp>>(context);
   patterns.insert<IdentityOpConversion<xla_hlo::CopyOp>>(context);
   patterns.insert<IdentityOpConversion<xla_hlo::ReshapeOp>>(context);
+  patterns.insert<IdentityOpConversion<xla_hlo::DynamicReshapeOp>>(context);
 
   // Conversions that don't have a 1:1 mapping, mostly involving buffer views
   // or transfers.
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
index 44b0a3d..aaccde9 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
@@ -2,6 +2,7 @@
 
 // CHECK-LABEL: @broadcast_in_dim_2D_3D
 func @broadcast_in_dim_2D_3D() -> tensor<3x2x4xi32> attributes { sym_visibility = "private" } {
+  %rs3_2_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
   %input = constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi32>
   // CHECK-DAG: %[[SRC:.+]] = "vmla.constant"
   // CHECK-DAG: %[[SRC_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2,4]>
@@ -9,7 +10,7 @@
   // CHECK-DAG: %[[DST_SIZE:.+]] = constant 96 : index
   // CHECK-DAG: %[[DST:.+]] = vmla.buffer.alloc byte_length = %[[DST_SIZE]] : !vmla.buffer
   // CHECK-DAG: "vmla.tile"(%[[SRC]], %[[SRC_SHAPE]], %[[DST]], %[[DST_SHAPE]]) {element_type = i32}
-  %0 = "xla_hlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
+  %0 = "shapex.ranked_broadcast_in_dim"(%input, %rs3_2_4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>, !shapex.ranked_shape<[3,2,4]>) -> tensor<3x2x4xi32>
   // CHECK-NEXT: return %[[DST]] : !vmla.buffer
   return %0 : tensor<3x2x4xi32>
 }
@@ -18,6 +19,7 @@
 
 // CHECK-LABEL: @broadcast_in_dim_3D_scalar
 func @broadcast_in_dim_3D_scalar() -> tensor<3x2x4xi32> attributes { sym_visibility = "private" } {
+  %rs3_2_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
   // CHECK-DAG: %[[SRC:.+]] = "vmla.constant"
   // CHECK-DAG: %[[SRC_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[]>
   // CHECK-DAG: %[[DST_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
@@ -25,7 +27,7 @@
   %input = constant dense<42> : tensor<i32>
   // CHECK-NEXT: %[[DST:.+]] = vmla.buffer.alloc byte_length = %[[DST_SIZE]] : !vmla.buffer
   // CHECK-NEXT: "vmla.broadcast"(%[[SRC]], %[[SRC_SHAPE]], %[[DST]], %[[DST_SHAPE]]) {element_type = i32}
-  %0 = "xla_hlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<3x2x4xi32>
+  %0 = "shapex.ranked_broadcast_in_dim"(%input, %rs3_2_4) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, !shapex.ranked_shape<[3,2,4]>) -> tensor<3x2x4xi32>
   // CHECK-NEXT: return %[[DST]] : !vmla.buffer
   return %0 : tensor<3x2x4xi32>
 }
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dot_general.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dot_general.mlir
deleted file mode 100644
index d7fc32b..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dot_general.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// -----
-
-func @f(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> attributes {sym_visibility = "private"} {
-  // CHECK: vmla.batch.matmul
-  %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {
-    lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
-    lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
-    rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
-    rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>
-  }} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
-  return %0 : tensor<3x5xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index 5bfa951..8cefc91 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -379,7 +379,7 @@
     All operands are rank-3 with the following dimension structure:
     - lhs = [B, FLHS, C]
     - rhs = [B, FRHS, C]
-    - dst = [B, FLHS, FRHS]
+    - dst = [B, FRHS, FLHS]
     Where:
     - B = batch dimension
     - C = contracting dimension
diff --git a/iree/compiler/Dialect/VMLA/Transforms/BUILD b/iree/compiler/Dialect/VMLA/Transforms/BUILD
index 2981f40..7ba9688 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/BUILD
+++ b/iree/compiler/Dialect/VMLA/Transforms/BUILD
@@ -22,6 +22,7 @@
     srcs = [
         "Conversion.cpp",
         "Passes.cpp",
+        "PreConversionLowering.cpp",
         "UnrollReductions.cpp",
     ],
     hdrs = [
@@ -37,6 +38,7 @@
         "//iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA",
         "//iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA",
         "//iree/compiler/Dialect/VMLA/IR",
+        "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
index e980b46..80fad09 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@
   SRCS
     "Conversion.cpp"
     "Passes.cpp"
+    "PreConversionLowering.cpp"
     "UnrollReductions.cpp"
   DEPS
     LLVMSupport
@@ -39,6 +40,7 @@
     iree::compiler::Dialect::VMLA::Conversion::HLOToVMLA
     iree::compiler::Dialect::VMLA::Conversion::StandardToVMLA
     iree::compiler::Dialect::VMLA::IR
+    iree::compiler::Dialect::VMLA::IR::VMLADialect
     tensorflow::mlir_xla
   PUBLIC
 )
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
index b9cf4dc..00f806e 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
@@ -99,6 +99,21 @@
         });
     conversionTarget.addIllegalOp<Shape::RankedDimOp>();
     conversionTarget.addIllegalOp<Shape::RankedDimsOp>();
+    // XLA ops use tensors of extents, so we tend to launder back to
+    // !shapex.ranked_shape for most shape-related things. This is a problem
+    // because we don't have a lowering for the ops going back and forth between
+    // tensors of extents and !shapex.ranked_shape. So we mark this op as
+    // illegal and rely on our fold of `from_extent_tensor(to_extent_tensor(x))
+    // -> x` to eliminate these ops. Setting it illegal here triggers that fold.
+    // This is skating on thin ice.
+    // TODO(silvasean): Legalize ToExtentTensorOp and FromExtentTensorOp.
+    conversionTarget.addIllegalOp<Shape::FromExtentTensorOp>();
+    // RankedBroadcastInDimOp is logically something that should be an
+    // xla_hlo op (or in a dialect at a similar level of abstraction), but since
+    // it isn't technically in that dialect, we need to special-case mark it as
+    // illegal here.
+    // TODO(silvasean): Reconcile the dialect layering here.
+    conversionTarget.addIllegalOp<Shape::RankedBroadcastInDimOp>();
 
     if (failed(applyPartialConversion(getOperation(), conversionTarget,
                                       conversionPatterns, &typeConverter))) {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
index e0393f1..dc4cc6c 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
@@ -43,8 +43,10 @@
 
   // ---------------------------------------------------------------------------
   // Tensor-level rewrites.
-  // At this point, the computation is in tensor-level CFG form with the ability
-  // perform transformations that alter shapes.
+  // At this point, the computation is in tensor-level CFG form.
+  // There are no specific requirements on shape-related calculations at this
+  // point yet, so general tensor->tensor transformations in preparation
+  // for later conversion steps should go here.
   // ---------------------------------------------------------------------------
   // Legalize input types.
   // TODO(benvanik): legalize input.
@@ -56,6 +58,12 @@
   // Unroll multi-dimensional reductions to one reduction per dimension.
   passManager.addNestedPass<FuncOp>(createUnrollReductionsPass());
 
+  // Tensor-level pattern-based lowerings. Thrown into one pass for simplicity.
+  passManager.addNestedPass<FuncOp>(createPreConversionLoweringPass());
+
+  // Clean up the IR before going into shape-materialized IR.
+  passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+
   // ---------------------------------------------------------------------------
   // Shape calculation.
   // Pre-conditions:
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.h b/iree/compiler/Dialect/VMLA/Transforms/Passes.h
index dbb413e..7d7b74d 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.h
+++ b/iree/compiler/Dialect/VMLA/Transforms/Passes.h
@@ -52,6 +52,9 @@
 // dimension, from innermost to outermost.
 std::unique_ptr<OperationPass<FuncOp>> createUnrollReductionsPass();
 
+// Tensor-level pattern-based lowerings. Thrown into one pass for simplicity.
+std::unique_ptr<OperationPass<FuncOp>> createPreConversionLoweringPass();
+
 //===----------------------------------------------------------------------===//
 // Dialect conversion
 //===----------------------------------------------------------------------===//
@@ -66,6 +69,7 @@
 inline void registerVMLAPasses() {
   createUnrollReductionsPass();
   createConversionPass();
+  createPreConversionLoweringPass();
 }
 
 }  // namespace VMLA
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
new file mode 100644
index 0000000..610bd6c
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -0,0 +1,320 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitVector.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VMLA {
+
+namespace {
+
+// Convert instances of `xla_hlo.dot` to `xla_hlo.dot_general`.
+//
+// TODO(silvasean): This logically is part of a future HLO client -> HLO server
+// type of pass in the xla_hlo dialect proper.
+struct LowerDotOp : public OpRewritePattern<xla_hlo::DotOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(xla_hlo::DotOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
+    if (!lhsType || !rhsType) {
+      return failure();
+    }
+    if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
+      return failure();
+    }
+    // TODO(silvasean): Move this helper to MLIR core.
+    auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
+      auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
+                                        rewriter.getIntegerType(64));
+      return DenseIntElementsAttr::get(type, integers);
+    };
+    auto dimensionNumbers = xla_hlo::DotDimensionNumbers::get(
+        /*lhs_batching_dimensions=*/make1DElementsAttr({}),
+        /*rhs_batching_dimensions=*/make1DElementsAttr({}),
+        /*lhs_contracting_dimensions=*/make1DElementsAttr({1}),
+        /*rhs_contracting_dimensions=*/make1DElementsAttr({0}),
+        rewriter.getContext());
+    rewriter.replaceOpWithNewOp<xla_hlo::DotGeneralOp>(
+        op, op.getType(), lhs, rhs, dimensionNumbers,
+        op.precision_config().hasValue() ? op.precision_config().getValue()
+                                         : nullptr);
+    return success();
+  }
+};
+
+// Inserts transposes on the operands of DotGeneralOp's such that the resulting
+// batch dimensions are all the leading dimensions and all the contracting
+// dimensions are all the trailing dimensions.
+//
+// Furthermore, all batch, contracting, and free dimensions are flattened into
+// single dimensions, with an appropriate reshape back to the original
+// dimensions.
+//
+// This results in a very simple corresponding VMLA op in the runtime.
+// [1 batch dimension, 1 free dimension, 1 contracting dimension].
+//
+// The result doesn't have a DotGeneralOp, but rather a
+// VMLA::BatchMatMulPseudoOp which represents this transformation.
+//
+// TODO(silvasean): Move this to a "prepare" pass and test separately.
+struct LowerDotGeneralOp : public OpRewritePattern<xla_hlo::DotGeneralOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(xla_hlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
+    Type elementType = lhsType.getElementType();
+    if (!lhsType || !rhsType) {
+      return rewriter.notifyMatchFailure(op, "requires ranked types");
+    }
+    xla_hlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
+    auto extract1DVector = [](DenseIntElementsAttr elements) {
+      SmallVector<int64_t, 6> ret;
+      for (const APInt &element : elements) {
+        ret.push_back(element.getLimitedValue());
+      }
+      return ret;
+    };
+    auto lhsBatchingDims =
+        extract1DVector(dimNumbers.lhs_batching_dimensions());
+    auto rhsBatchingDims =
+        extract1DVector(dimNumbers.rhs_batching_dimensions());
+    auto lhsContractingDims =
+        extract1DVector(dimNumbers.lhs_contracting_dimensions());
+    auto rhsContractingDims =
+        extract1DVector(dimNumbers.rhs_contracting_dimensions());
+    // TODO(silvasean): Move this helper to MLIR core.
+    auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
+      auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
+                                        rewriter.getIntegerType(64));
+      return DenseIntElementsAttr::get(type, integers);
+    };
+    auto totalElements = [&](ArrayRef<Value> extents) {
+      Value numElements = rewriter.create<mlir::ConstantOp>(
+          op.getLoc(), IntegerAttr::get(rewriter.getIndexType(), 1));
+      for (Value extent : extents) {
+        numElements =
+            rewriter.create<mlir::MulIOp>(op.getLoc(), numElements, extent);
+      }
+      return numElements;
+    };
+    auto handleOneSide = [&](ArrayRef<int64_t> batchingDims,
+                             ArrayRef<int64_t> contractingDims, Value &value,
+                             RankedTensorType &type,
+                             SmallVectorImpl<int64_t> &outFreeDims,
+                             SmallVectorImpl<Value> &outFreeDimExtents,
+                             SmallVectorImpl<Value> &outBatchingDimExtents) {
+      outBatchingDimExtents.clear();
+      RankedTensorType untransposedType = type;
+      SmallVector<int64_t, 6> permutation;
+      llvm::BitVector freeDims(untransposedType.getRank(), true);
+      SmallVector<Value, 6> contractingDimExtents;
+      Value valueShape =
+          rewriter.create<Shape::GetRankedShapeOp>(op.getLoc(), value);
+      auto getExtentValue = [&](int64_t dim) {
+        return rewriter.create<Shape::RankedDimOp>(op.getLoc(), valueShape,
+                                                   dim);
+      };
+      for (auto dims : {batchingDims, contractingDims}) {
+        for (int64_t dim : dims) {
+          freeDims.reset(dim);
+        }
+      }
+      for (int64_t dim : batchingDims) {
+        permutation.push_back(dim);
+        outBatchingDimExtents.push_back(getExtentValue(dim));
+      }
+      for (int64_t dim : freeDims.set_bits()) {
+        permutation.push_back(dim);
+        outFreeDims.push_back(dim);
+        outFreeDimExtents.push_back(getExtentValue(dim));
+      }
+      for (int64_t dim : contractingDims) {
+        permutation.push_back(dim);
+        contractingDimExtents.push_back(getExtentValue(dim));
+      }
+      // Construct the type that the transpose will result in.
+      SmallVector<int64_t, 6> transposeStaticShape;
+      for (int64_t index : permutation) {
+        (void)index;
+        transposeStaticShape.push_back(-1);
+      }
+      auto transposeType =
+          RankedTensorType::get(transposeStaticShape, elementType);
+      auto transpose = rewriter.create<xla_hlo::TransposeOp>(
+          op.getLoc(), transposeType, value, make1DElementsAttr(permutation));
+
+      SmallVector<Value, 6> reshapeShape;
+      reshapeShape.push_back(totalElements(outBatchingDimExtents));
+      reshapeShape.push_back(totalElements(outFreeDimExtents));
+      reshapeShape.push_back(totalElements(contractingDimExtents));
+      auto reshapeType = RankedTensorType::get(
+          {static_cast<int64_t>(-1), static_cast<int64_t>(-1),
+           static_cast<int64_t>(-1)},
+          elementType);
+      auto reshapeRankedShape = rewriter.create<Shape::MakeRankedShapeOp>(
+          op.getLoc(),
+          Shape::RankedShapeType::get(reshapeType.getShape(),
+                                      rewriter.getContext()),
+          reshapeShape);
+      auto reshapeShapeExtentTensor = rewriter.create<Shape::ToExtentTensorOp>(
+          op.getLoc(), reshapeRankedShape);
+      value = rewriter.create<xla_hlo::DynamicReshapeOp>(
+          op.getLoc(), reshapeType, transpose, reshapeShapeExtentTensor);
+    };
+    SmallVector<Value, 6> batchingDimExtents;
+    SmallVector<int64_t, 6> lhsFreeDims;
+    SmallVector<Value, 6> lhsFreeDimExtents;
+    handleOneSide(lhsBatchingDims, lhsContractingDims, lhs, lhsType,
+                  lhsFreeDims, lhsFreeDimExtents, batchingDimExtents);
+    SmallVector<int64_t, 6> rhsFreeDims;
+    SmallVector<Value, 6> rhsFreeDimExtents;
+    handleOneSide(rhsBatchingDims, rhsContractingDims, rhs, rhsType,
+                  rhsFreeDims, rhsFreeDimExtents, batchingDimExtents);
+
+    auto dstStaticShape = llvm::to_vector<6>(
+        llvm::makeArrayRef({static_cast<int64_t>(-1), static_cast<int64_t>(-1),
+                            static_cast<int64_t>(-1)}));
+    auto dstType = RankedTensorType::get(dstStaticShape, elementType);
+    Value dst = rewriter.create<IREE::VMLA::BatchMatMulPseudoOp>(
+        op.getLoc(), dstType, lhs, rhs);
+    RankedTensorType transposeType = RankedTensorType::get(
+        {dstStaticShape[0], dstStaticShape[2], dstStaticShape[1]}, elementType);
+    auto transpose = rewriter.create<xla_hlo::TransposeOp>(
+        op.getLoc(), transposeType, dst, make1DElementsAttr({0, 2, 1}));
+    auto reshapeShape = batchingDimExtents;
+    reshapeShape.append(lhsFreeDimExtents.begin(), lhsFreeDimExtents.end());
+    reshapeShape.append(rhsFreeDimExtents.begin(), rhsFreeDimExtents.end());
+    SmallVector<int64_t, 6> reshapeStaticShape;
+    for (int i = 0, e = batchingDimExtents.size() + lhsFreeDimExtents.size() +
+                        rhsFreeDimExtents.size();
+         i < e; i++) {
+      reshapeStaticShape.push_back(-1);
+    }
+    auto reshapeRankedShape = rewriter.create<Shape::MakeRankedShapeOp>(
+        op.getLoc(),
+        Shape::RankedShapeType::get(reshapeStaticShape, rewriter.getContext()),
+        reshapeShape);
+    auto reshapeShapeExtentTensor = rewriter.create<Shape::ToExtentTensorOp>(
+        op.getLoc(), reshapeRankedShape);
+    rewriter.replaceOpWithNewOp<xla_hlo::DynamicReshapeOp>(
+        op, op.getType(), transpose, reshapeShapeExtentTensor);
+    return success();
+  }
+};
+
+class LowerBroadcastInDimOp
+    : public OpRewritePattern<xla_hlo::BroadcastInDimOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(xla_hlo::BroadcastInDimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = op.getType().cast<RankedTensorType>();
+    auto shapeType =
+        Shape::RankedShapeType::get(type.getShape(), rewriter.getContext());
+    auto shape =
+        rewriter.create<Shape::ConstRankedShapeOp>(op.getLoc(), shapeType);
+    rewriter.replaceOpWithNewOp<Shape::RankedBroadcastInDimOp>(
+        op, op.getType(), op.operand(), shape, op.broadcast_dimensions());
+    return success();
+  }
+};
+
+// Lower xla_hlo::BroadcastOp via xla_hlo::BroadcastInDimOp.
+class LowerBroadcastOp : public OpRewritePattern<xla_hlo::BroadcastOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(xla_hlo::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = op.getOperand().getType().cast<RankedTensorType>();
+    auto resultType = op.getType().cast<RankedTensorType>();
+    auto broadcastDimensions = llvm::to_vector<6>(llvm::seq<int64_t>(
+        resultType.getRank() - type.getRank(), resultType.getRank()));
+    rewriter.replaceOpWithNewOp<xla_hlo::BroadcastInDimOp>(
+        op, op.getType(), op.getOperand(),
+        rewriter.getI64TensorAttr(broadcastDimensions));
+    return success();
+  }
+};
+
+class PreConversionLoweringPass
+    : public PassWrapper<PreConversionLoweringPass, OperationPass<FuncOp>> {
+ public:
+  void runOnOperation() {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<IREE::VMLA::VMLADialect>();
+    target.addLegalDialect<xla_hlo::XlaHloDialect>();
+    target.addLegalDialect<ShapeDialect>();
+
+    target.addIllegalOp<xla_hlo::DotGeneralOp>();
+    patterns.insert<LowerDotGeneralOp>(context);
+    target.addIllegalOp<xla_hlo::DotOp>();
+    patterns.insert<LowerDotOp>(context);
+    target.addIllegalOp<xla_hlo::BroadcastInDimOp>();
+    patterns.insert<LowerBroadcastInDimOp>(context);
+    target.addIllegalOp<xla_hlo::BroadcastOp>();
+    patterns.insert<LowerBroadcastOp>(context);
+
+    if (failed(applyPartialConversion(getOperation(), target, patterns))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+static PassRegistration<PreConversionLoweringPass> pass(
+    "iree-vmla-pre-conversion-lowering",
+    "Tensor-level pattern-based lowerings.");
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createPreConversionLoweringPass() {
+  return std::make_unique<PreConversionLoweringPass>();
+}
+
+}  // namespace VMLA
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
new file mode 100644
index 0000000..ecd94f7
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
@@ -0,0 +1,33 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering %s | IreeFileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> {
+  // CHECK: vmla.batch.matmul
+  %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {
+    lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+    lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
+    rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+    rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>
+  }} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
+  // CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs4_3)
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
+  return %0 : tensor<4x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<3xf32>) -> tensor<5x6x3xf32> {
+  // CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs5_6_3)
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[5, 6]> : tensor<2xi64>} : (tensor<3xf32>) -> tensor<5x6x3xf32>
+  return %0 : tensor<5x6x3xf32>
+}
diff --git a/iree/test/e2e/regression/dynamic_dot_general.mlir b/iree/test/e2e/regression/dynamic_dot_general.mlir
new file mode 100644
index 0000000..488136f
--- /dev/null
+++ b/iree/test/e2e/regression/dynamic_dot_general.mlir
@@ -0,0 +1,46 @@
+// RUN: iree-run-mlir %s -iree-hal-target-backends=vmla -input-value="2x2xf32=[[1.0, 0.0], [0.0, 1.0]]" -input-value="2x3xf32=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]" -input-value="2x2x2xf32=[[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]" -input-value="2x2x3xf32=[[[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]" | IreeFileCheck %s
+
+// TODO(silvasean): Extend the xla_ops directory test infra to support
+// testing dynamic shapes.
+
+// CHECK-LABEL: EXEC @basic_dot
+func @basic_dot(
+  %lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>,
+  %unused0: tensor<?x?x?xf32>, %unused1: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+  %0 = "xla_hlo.dot_general"(%lhs, %rhs) {dot_dimension_numbers={
+    lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+    lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
+    rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+    rhs_contracting_dimensions = dense<0> : tensor<1xi64>
+  }} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: 2x3xf32=[1 2 3][4 5 6]
+
+// CHECK-LABEL: EXEC @batch_dimension
+func @batch_dimension(
+  %unused0: tensor<?x?xf32>, %unused1: tensor<?x?xf32>,
+  %lhs: tensor<?x?x?xf32>, %rhs: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "xla_hlo.dot_general"(%lhs, %rhs) {dot_dimension_numbers={
+    lhs_batching_dimensions = dense<[0]> : tensor<1xi64>,
+    lhs_contracting_dimensions = dense<[2]> : tensor<1xi64>,
+    rhs_batching_dimensions = dense<[0]> : tensor<1xi64>,
+    rhs_contracting_dimensions = dense<[1]> : tensor<1xi64>
+  }} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK: 2x2x3xf32=[
+// CHECK-SAME:   [1.5 2.5 3.5][4.5 5.5 6.5]
+// CHECK-SAME: ][
+// CHECK-SAME:   [2 4 6][8 10 12]
+// CHECK-SAME: ]
+
+
+// TODO(silvasean): Add more tests when we have better test infra.
+// This is currently too verbose / unreadable. We should test:
+// - multiple contracting dimensions
+// - multiple batch dimensions
+// - multiple free dimensions
+// - intermingled batch, free, and contracting dimensions
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 08a5c4e..13a6ac7 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -74,6 +74,7 @@
         "//iree/compiler/Dialect/HAL/Transforms",
         "//iree/compiler/Dialect/IREE/IR",
         "//iree/compiler/Dialect/IREE/Transforms",
+        "//iree/compiler/Dialect/Shape/Conversion",
         "//iree/compiler/Dialect/Shape/IR",
         "//iree/compiler/Dialect/Shape/Transforms",
         "//iree/compiler/Dialect/VM/Analysis",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 7ba48ce..32f51ea 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -138,6 +138,7 @@
       iree::compiler::Dialect::HAL::Transforms
       iree::compiler::Dialect::IREE::IR
       iree::compiler::Dialect::IREE::Transforms
+      iree::compiler::Dialect::Shape::Conversion
       iree::compiler::Dialect::Shape::IR
       iree::compiler::Dialect::Shape::Transforms
       iree::compiler::Dialect::VM::Analysis
diff --git a/iree/tools/init_passes.h b/iree/tools/init_passes.h
index 1b5631f..ee0ac73 100644
--- a/iree/tools/init_passes.h
+++ b/iree/tools/init_passes.h
@@ -26,6 +26,7 @@
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
 #include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "iree/compiler/Dialect/VM/Analysis/TestPasses.h"
 #include "iree/compiler/Dialect/VM/Transforms/Passes.h"
@@ -175,6 +176,7 @@
   IREE::Flow::registerFlowAnalysisTestPasses();
   IREE::HAL::registerHALPasses();
   IREE::registerIreePasses();
+  Shape::registerShapeConversionPasses();
   Shape::registerShapePasses();
   IREE::VM::registerVMPasses();
   IREE::VM::registerVMAnalysisTestPasses();