Merge pull request #4571 from google/staging-hal-rewrite

diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index c6836a5..e619cd0 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,16 +4,16 @@
 4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
 b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
-39e3ba5be1a67c1e1ed900a5bc8134d0c8374d73 third_party/llvm-bazel
-f3f3c9c2549a268e602be8730990b552e30cc932 third_party/llvm-project
+77871f43e449ad492bf8b94dee453670ac15e158 third_party/llvm-bazel
+b92a39ac1319c796777bca19a3af2856acbc69c1 third_party/llvm-project
 4e501d8c6e2d834999301a2492adefe5ddbdc0cb third_party/mlir-emitc
-471fc63c11205639dab25345aea1f85831ef4cb9 third_party/mlir-hlo
+2b72ddc6b2b4d670bcd1ffa3f4652468b419f986 third_party/mlir-hlo
 2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft
 d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
 2887692065c38ef6617f423feafc6b69dd0a0681 third_party/ruy
 685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
 f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
-8f595c955848e24b3edca51f78a68943e3c26f50 third_party/tensorflow
+16613a70ef36b103e7c1ffa903d541814b62c109 third_party/tensorflow
 9c3dac3ed2bd647b8d63f197fed058fee97a7e1e third_party/tracy
 9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
 3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
diff --git a/bindings/python/pyiree/jax/README.md b/bindings/python/pyiree/jax/README.md
index d7bf18c..b6b6542 100644
--- a/bindings/python/pyiree/jax/README.md
+++ b/bindings/python/pyiree/jax/README.md
@@ -44,11 +44,10 @@
 
 Install the Android NDK according to the
 [Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake)
-doc, and then ensure the following environment variables are set:
+doc, and then ensure the following environment variable is set:
 
 ```shell
 export ANDROID_NDK=# NDK install location
-export IREE_LLVMAOT_LINKER_PATH="${ANDROID_NDK?}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android29-clang++"
 ```
 
 The code below assumes that you have `flax` installed.
@@ -62,6 +61,10 @@
 import flax
 from flax import linen as nn
 
+import os
+# Configure the linker to target Android (Python does not expand ${ANDROID_NDK}).
+os.environ["IREE_LLVMAOT_LINKER_PATH"] = os.environ["ANDROID_NDK"] + "/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android29-clang++ -static-libstdc++"
+
 
 class MLP(nn.Module):
 
diff --git a/bindings/python/pyiree/jax/frontend.py b/bindings/python/pyiree/jax/frontend.py
index d49edfe..0429f36 100644
--- a/bindings/python/pyiree/jax/frontend.py
+++ b/bindings/python/pyiree/jax/frontend.py
@@ -83,9 +83,9 @@
     self._options = options
     self._memoized_signatures = {}
 
-  def _get_signature(self, args_flat):
+  def _get_signature(self, args_flat, in_tree):
     args_flat = [rt.normalize_value(arg) for arg in args_flat]
-    return tuple((arg.shape, arg.dtype) for arg in args_flat)
+    return tuple((arg.shape, arg.dtype) for arg in args_flat) + (in_tree,)
 
   def _wrap_and_compile(self, signature, args_flat, in_tree):
     """Compiles the function for the given signature."""
@@ -110,7 +110,7 @@
   def _get_compiled_artifacts(self, args, kwargs):
     """Returns the binary, loaded rt module and out_tree."""
     args_flat, in_tree = jax.tree_flatten((args, kwargs))
-    signature = self._get_signature(args_flat)
+    signature = self._get_signature(args_flat, in_tree)
 
     if signature not in self._memoized_signatures:
       self._wrap_and_compile(signature, args_flat, in_tree)
diff --git a/bindings/python/pyiree/jax/frontend_test.py b/bindings/python/pyiree/jax/frontend_test.py
index dbc59de..153976d 100644
--- a/bindings/python/pyiree/jax/frontend_test.py
+++ b/bindings/python/pyiree/jax/frontend_test.py
@@ -31,6 +31,42 @@
   return np.random.normal(0, 1, shape).astype(np.float32)
 
 
+@jax.tree_util.register_pytree_node_class
+class SqrtNode:
+
+  def __init__(self, x, y):
+    self.x = x
+    self.y = y
+
+  def apply(self, z):
+    return self.x * jnp.sqrt(self.y * z)
+
+  def tree_flatten(self):
+    return ((self.x, self.y), None)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    return cls(*children)
+
+
+@jax.tree_util.register_pytree_node_class
+class SquareNode:
+
+  def __init__(self, x, y):
+    self.x = x
+    self.y = y
+
+  def apply(self, z):
+    return self.x * (self.y * z)**2
+
+  def tree_flatten(self):
+    return ((self.x, self.y), None)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    return cls(*children)
+
+
 class JAXFrontendTest(unittest.TestCase):
 
   def test_aot_pytree(self):
@@ -161,6 +197,20 @@
 
     self.assertEqual(add_sqrt_four(2), 4)
 
+  def test_jit_pytree_method(self):
+
+    @iree.jax.jit
+    def apply_node(node, z):
+      return node.apply(z)
+
+    expected_sqrt = apply_node._function(SqrtNode(2, 3), 4)
+    compiled_sqrt = apply_node(SqrtNode(2, 3), 4)
+    np.testing.assert_allclose(compiled_sqrt, expected_sqrt)
+
+    expected_square = apply_node._function(SquareNode(2, 3), 4)
+    compiled_square = apply_node(SquareNode(2, 3), 4)
+    np.testing.assert_allclose(compiled_square, expected_square)
+
 
 if __name__ == "__main__":
   unittest.main()
diff --git a/bindings/python/pyiree/rt/system_api.py b/bindings/python/pyiree/rt/system_api.py
index 5bcf2f5..7df27b2 100644
--- a/bindings/python/pyiree/rt/system_api.py
+++ b/bindings/python/pyiree/rt/system_api.py
@@ -36,7 +36,7 @@
 import os
 import sys
 
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 from . import binding as _binding
 
@@ -133,19 +133,10 @@
   return _global_config
 
 
-def normalize_value(value: Any) -> Optional[np.ndarray]:
-  """Normalizes the given value for input to (or comparison with) IREE."""
-  if value is None:
-    # Exclude None from falling through to blanket np.asarray conversion.
-    return value
-
-  array = np.asarray(value)
-  if isinstance(value, (bool, int, float, list, tuple)):
-    # Manually convert ints and floats to 32 bits.
-    if array.dtype == np.float64:
-      array = array.astype(np.float32)
-    elif array.dtype == np.int64:
-      array = array.astype(np.int32)
+def _bool_to_int8(
+    array: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any, ...]]]:
+  if not isinstance(array, np.ndarray):
+    return array
 
   # IREE models booleans as i8s.
   # TODO: This cast should be moved into the function abi. If it's possible to
@@ -153,6 +144,26 @@
   # type should also be recast to np.bool at that level.
   if array.dtype == np.bool:
     array = array.astype(np.int8)
+  return array
+
+
+def normalize_value(
+    value: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any, ...]]]:
+  """Normalizes the given value for input to (or comparison with) IREE."""
+  if value is None:
+    # Exclude None from falling through to blanket np.asarray conversion.
+    return value
+
+  if isinstance(value, (list, tuple)):
+    return value
+
+  array = np.asarray(value)
+  if isinstance(value, (bool, int, float)):
+    # Manually convert ints and floats to 32 bits.
+    if array.dtype == np.float64:
+      array = array.astype(np.float32)
+    elif array.dtype == np.int64:
+      array = array.astype(np.int32)
 
   return array
 
@@ -172,6 +183,8 @@
     # Convert tensors, device arrays, ints, ... to IREE-friendly inputs.
     args = [normalize_value(value) for value in args]
     kwargs = {k: normalize_value(v) for k, v in kwargs.items()}
+    args = [_bool_to_int8(value) for value in args]
+    kwargs = {k: _bool_to_int8(v) for k, v in kwargs.items()}
 
     # NOTE: This is just doing sync dispatch right now. In the future,
     # this should default to async and potentially have some kind of policy
diff --git a/build_tools/github_actions/build_dist.py b/build_tools/github_actions/build_dist.py
index c8f335e..924acfc 100644
--- a/build_tools/github_actions/build_dist.py
+++ b/build_tools/github_actions/build_dist.py
@@ -47,9 +47,12 @@
 
   python ./main_checkout/build_tools/github_actions/build_dist.py main-dist
   python ./main_checkout/build_tools/github_actions/build_dist.py py-runtime-pkg
-  python ./main_checkout/build_tools/github_actions/build_dist.py py-xla-compiler-tools-pkg
-  python ./main_checkout/build_tools/github_actions/build_dist.py py-tflite-compiler-tools-pkg
-  python ./main_checkout/build_tools/github_actions/build_dist.py py-tf-compiler-tools-pkg
+  python ./main_checkout/build_tools/github_actions/build_dist.py \
+    py-xla-compiler-tools-pkg
+  python ./main_checkout/build_tools/github_actions/build_dist.py \
+    py-tflite-compiler-tools-pkg
+  python ./main_checkout/build_tools/github_actions/build_dist.py \
+    py-tf-compiler-tools-pkg
 
 
 That is not a perfect approximation but is close.
@@ -81,14 +84,14 @@
 
 # Load version info.
 def load_version_info():
-  with open(os.path.join(IREESRC_DIR, 'version_info.json'), 'rt') as f:
+  with open(os.path.join(IREESRC_DIR, "version_info.json"), "rt") as f:
     return json.load(f)
 
 
 try:
   version_info = load_version_info()
 except FileNotFoundError:
-  print('version_info.json found. Using defaults')
+  print("version_info.json not found. Using defaults")
   version_info = {
       "package-version": "0.1dev1",
       "package-suffix": "-dev",
diff --git a/docs/design_docs/codegen_passes.md b/docs/design_docs/codegen_passes.md
index 4a1335c..b36db1c 100644
--- a/docs/design_docs/codegen_passes.md
+++ b/docs/design_docs/codegen_passes.md
@@ -415,11 +415,12 @@
 while the former are to be executed collectively by workitems within a
 workgroup, the latter have to be executed by all workitems across workgroups.
 One way to distinguish these two operations is to use the marker mechanism in
-Linalg ([LinalgMarker][LinalgTilingPatterns]). This is a `StrAttr` whose value
-can be used to encode the scope of the operation. For example, in Snippet 7
-above, the tiled `linalg.matmul` operation has a marker `workgroup` to indicate
-that this operation needs to be executed by a workgroup in a collective manner.
-At this time, the code-generation pipeline uses only the `workgroup` marker.
+Linalg ([LinalgTransformationFilter][LinalgTilingPatterns]). This is a `StrAttr`
+whose value can be used to encode the scope of the operation. For example, in
+Snippet 7 above, the tiled `linalg.matmul` operation has a marker `workgroup` to
+indicate that this operation needs to be executed by a workgroup in a collective
+manner. At this time, the code-generation pipeline uses only the `workgroup`
+marker.
 
 __Roadmap Note__ : Markers are meant to be short-lived, ideally set and consumed
 within the same pass. In the current pipeline the lifetime spans passes to allow
diff --git a/docs/get_started/getting_started_android_cmake.md b/docs/get_started/getting_started_android_cmake.md
index 183dd3a..022443c 100644
--- a/docs/get_started/getting_started_android_cmake.md
+++ b/docs/get_started/getting_started_android_cmake.md
@@ -81,7 +81,7 @@
 ```shell
 $ cmake -G Ninja -B ../iree-build-android/ \
   -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK?}/build/cmake/android.toolchain.cmake" \
-  -DIREE_HOST_BINARY_ROOT=../iree-build-host/install \
+  -DIREE_HOST_BINARY_ROOT=$(realpath ../iree-build-host/install) \
   -DANDROID_ABI="arm64-v8a" \
   -DANDROID_PLATFORM=android-29 \
   -DIREE_BUILD_COMPILER=OFF \
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
index 0ddcfbe..02c5102 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
@@ -17,7 +17,7 @@
 #include "experimental/ModelBuilder/ModelRunner.h"
 #include "experimental/ModelBuilder/VulkanWrapperPass.h"
 #include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -250,8 +251,8 @@
   return a == b;
 }
 
-static MatmulCodegenStrategy createPowerVRStrategy(int tileM, int tileN,
-                                                   int tileK, int warpSize) {
+static linalg::CodegenStrategy createPowerVRStrategy(int tileM, int tileN,
+                                                     int tileK, int warpSize) {
   const std::array<int64_t, 3> nativeSize = {1, 1, 1};
   linalg::LinalgLoopDistributionOptions WIDistribute;
   linalg::LinalgLoopDistributionOptions WGDistribute;
@@ -274,7 +275,7 @@
                    b.create<ConstantIndexOp>(loc, 1)};
     return procInfo;
   };
-  MatmulCodegenStrategy strategy;
+  linalg::CodegenStrategy strategy;
   SmallVector<int64_t, 2> promotionList;
   // promote matrix B
   promotionList.push_back(1);
@@ -301,13 +302,16 @@
           .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
           .setTileSizes({1, tileN, tileK})
           .setDistributionOptions(WIDistribute));
-  strategy.vectorize<linalg::MatmulOp>().unrollVector<vector::ContractionOp>(
-      nativeSize);
+  strategy.vectorize<linalg::MatmulOp>()
+      // TODO: Upstream to core.
+      // .unrollVector<vector::ContractionOp>(nativeSize)
+      ;
+  (void)nativeSize;
   return strategy;
 }
 
-static MatmulCodegenStrategy createMaliStrategy(int tileM, int tileN, int tileK,
-                                                int warpSize) {
+static linalg::CodegenStrategy createMaliStrategy(int tileM, int tileN,
+                                                  int tileK, int warpSize) {
   const std::array<int64_t, 3> nativeSize = {1, 4, 1};
   linalg::LinalgLoopDistributionOptions WIDistribute;
   linalg::LinalgLoopDistributionOptions WGDistribute;
@@ -330,7 +334,7 @@
                    b.create<ConstantIndexOp>(loc, 1)};
     return procInfo;
   };
-  MatmulCodegenStrategy strategy;
+  linalg::CodegenStrategy strategy;
   strategy
       .tile<linalg::MatmulOp>(
           linalg::LinalgTilingOptions()
@@ -344,13 +348,16 @@
           .setTileSizes({tileM, tileN / warpSize, tileK})
           .setDistributionOptions(WIDistribute));
   strategy.vectorize<linalg::MatmulOp>()
-      .unrollVector<vector::TransferReadOp>({1, 4})
-      .unrollVector<vector::ContractionOp>(nativeSize);
+      // TODO: Upstream to core.
+      // .unrollVector<vector::TransferReadOp>({1, 4})
+      // .unrollVector<vector::ContractionOp>(nativeSize)
+      ;
+  (void)nativeSize;
   return strategy;
 }
 
-static MatmulCodegenStrategy createTuringStrategy(int tileM, int tileN,
-                                                  int tileK) {
+static linalg::CodegenStrategy createTuringStrategy(int tileM, int tileN,
+                                                    int tileK) {
   std::array<int64_t, 3> nativeSize;
   if (matType == "i8xi8xi32")
     nativeSize = {16, 16, 32};
@@ -372,7 +379,7 @@
       linalg::DistributionMethod::CyclicNumProcsEqNumIters};
   SGDistribute.procInfo = getSubgroupIds;
 
-  MatmulCodegenStrategy strategy;
+  linalg::CodegenStrategy strategy;
   strategy
       .tile<linalg::MatmulOp>(
           linalg::LinalgTilingOptions()
@@ -398,8 +405,11 @@
                     {tileM / numSubgroupY, tileN / numSubgroupX, tileK})
                 .setDistributionOptions(SGDistribute));
   }
-  strategy.vectorize<linalg::MatmulOp>().unrollVector<vector::ContractionOp>(
-      nativeSize);
+  strategy.vectorize<linalg::MatmulOp>()
+      // TODO: Upstream to core.
+      // .unrollVector<vector::ContractionOp>(nativeSize)
+      ;
+  (void)nativeSize;
   return strategy;
 }
 
@@ -449,7 +459,7 @@
                      ModelRunner::Target::GPUTarget);
   CompilationOptions options;
   options.loweringPasses = [&](mlir::PassManager &pm) {
-    MatmulCodegenStrategy strategy;
+    linalg::CodegenStrategy strategy;
 
     if (target == "powerVR") {
       strategy = createPowerVRStrategy(tileM, tileN, tileK, warpSize);
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index 9848826..bbc0b63 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -19,33 +19,34 @@
 
 // clang-format on
 #include <string>
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/ExecutionEngine/CRunnerUtils.h"
-#include "mlir/ExecutionEngine/RunnerUtils.h"
+
 #include "experimental/ModelBuilder/ModelBuilder.h"
 #include "experimental/ModelBuilder/ModelRunner.h"
 #include "experimental/ModelBuilder/VulkanWrapperPass.h"
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/InitLLVM.h"
+#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/ExecutionEngine/RunnerUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/Parser.h"
-#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
-#include "mlir/Pass/PassManager.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
-#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
-#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
-#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
-#include "mlir/Dialect/GPU/Passes.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 
 using namespace mlir;                    // NOLINT
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index b39f7a6..2c9b478 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -236,11 +236,16 @@
 
   # Base check for numpy arrays.
   elif isinstance(ref, np.ndarray):
-    if ref.dtype != tar.dtype:
+    # Ignore np.bool != np.int8 because the IREE python runtime awkwardly
+    # returns np.int8s instead of np.bools.
+    if ref.dtype != tar.dtype and not (
+        (ref.dtype == np.bool_ and tar.dtype == np.int8) or
+        (ref.dtype == np.int8 and tar.dtype == np.bool_)):
       error = ("Expected ref and tar to have the same dtype, but got "
                f"'{ref.dtype}' and '{tar.dtype}'")
       logging.error(error)
       return False, error
+
     if ref.size == tar.size == 0:
       return True, None
 
diff --git a/iree/compiler/Conversion/CodegenUtils/BUILD b/iree/compiler/Conversion/CodegenUtils/BUILD
index 4b58a54..ea3c83c 100644
--- a/iree/compiler/Conversion/CodegenUtils/BUILD
+++ b/iree/compiler/Conversion/CodegenUtils/BUILD
@@ -27,14 +27,14 @@
         "FunctionUtils.cpp",
         "GetNumWorkgroups.cpp",
         "MarkerUtils.cpp",
-        "MatmulCodegenStrategy.cpp",
+        "TransformUtils.cpp",
     ],
     hdrs = [
         "ForOpCanonicalization.h",
         "FunctionUtils.h",
         "GetNumWorkgroups.h",
         "MarkerUtils.h",
-        "MatmulCodegenStrategy.h",
+        "TransformUtils.h",
     ],
     deps = [
         "//iree/compiler/Dialect/HAL/IR",
diff --git a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
index bf8ddd7..a4589fc 100644
--- a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
+++ b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
@@ -22,13 +22,13 @@
     "FunctionUtils.h"
     "GetNumWorkgroups.h"
     "MarkerUtils.h"
-    "MatmulCodegenStrategy.h"
+    "TransformUtils.h"
   SRCS
     "ForOpCanonicalization.cpp"
     "FunctionUtils.cpp"
     "GetNumWorkgroups.cpp"
     "MarkerUtils.cpp"
-    "MatmulCodegenStrategy.cpp"
+    "TransformUtils.cpp"
   DEPS
     LLVMSupport
     MLIRAffine
diff --git a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
deleted file mode 100644
index b36f4f7..0000000
--- a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
+++ /dev/null
@@ -1,286 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// -----------------------------------------------------------------------------
-// This code will be removed once this gets upstreamed to common mlir.
-// Please try to limit changes in this code only minor changes.
-
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
-
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/SCF/Utils.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Dialect/Vector/VectorTransforms.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dominance.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/Visitors.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/LoopUtils.h"
-#include "mlir/Transforms/Passes.h"
-
-using namespace mlir;          // NOLINT
-using namespace mlir::linalg;  // NOLINT
-
-#define DEBUG_TYPE "matmul-codegen-strategy"
-
-//===----------------------------------------------------------------------===//
-// TODO: Cleanup and upstream these to go into core. Please ignore for now !
-//===----------------------------------------------------------------------===//
-static void hoistRedundantCopies(FuncOp func) {
-  bool changed = true;
-  while (changed) {
-    changed = false;
-    func.walk([&](linalg::FillOp op) {
-      auto loop = op->getParentOfType<scf::ForOp>();
-      if (!loop) return;
-
-      for (auto operand : op.getOperands())
-        if (!loop.isDefinedOutsideOfLoop(operand)) return;
-
-      // Hoist fill before.
-      op.getOperation()->moveBefore(loop);
-      changed = true;
-    });
-
-    func.walk([&](linalg::CopyOp op) {
-      auto loop = op->getParentOfType<scf::ForOp>();
-      if (!loop) return;
-
-      for (auto operand : op.getOperands())
-        if (!loop.isDefinedOutsideOfLoop(operand)) return;
-
-      Value sourceView = op.getInput(0);
-      while (auto subViewOp = sourceView.getDefiningOp<SubViewOp>())
-        sourceView = subViewOp.getViewSource();
-
-      // Source traces back to a block argument.
-      if (sourceView.isa<BlockArgument>()) {
-        op.getOperation()->moveBefore(loop);
-      } else {
-        assert(sourceView.getDefiningOp<ViewOp>() ||
-               sourceView.getDefiningOp<AllocOp>() ||
-               sourceView.getDefiningOp<AllocaOp>());
-        op.getOperation()->moveAfter(loop);
-      }
-      changed = true;
-    });
-  }
-}
-
-/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing:
-///   `%lb + %step * new_dim` where
-/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an
-/// AffineDimExpr depending on whether the value is constant or not.
-/// 2. the AffineExpr for %step is either an AffineConstantExpr or an
-/// AffineSymbolExpr depending on whether the value is constant or not.
-///
-static void substitute(scf::ForOp forOp, SmallVectorImpl<AffineExpr> &exprs,
-                       SmallVectorImpl<Value> &dims,
-                       SmallVectorImpl<Value> &symbols) {
-  MLIRContext *ctx = forOp.getContext();
-  auto lbConstant = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
-  AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx)
-                             : getAffineDimExpr(dims.size(), ctx);
-
-  auto stepConstant = forOp.step().getDefiningOp<ConstantIndexOp>();
-  AffineExpr step = stepConstant
-                        ? getAffineConstantExpr(stepConstant.getValue(), ctx)
-                        : getAffineSymbolExpr(symbols.size(), ctx);
-
-  if (!lbConstant) dims.push_back(forOp.lowerBound());
-  if (!stepConstant) symbols.push_back(forOp.step());
-  exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx));
-
-  auto ubConstant = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
-  AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx)
-                             : getAffineDimExpr(dims.size(), ctx);
-  if (!ubConstant) dims.push_back(forOp.upperBound());
-  exprs.push_back(ub);
-
-  dims.push_back(forOp.getInductionVar());
-}
-
-/// Substitue dimensions coming from forOp or AffineMin. Return false if it has
-/// unknown dimension operands.
-static bool substitute(AffineMinOp minOp, SmallVectorImpl<AffineExpr> &exprs,
-                       SmallVectorImpl<Value> &dims,
-                       SmallVectorImpl<Value> &symbols) {
-  if (minOp.getDimOperands().empty()) return false;
-  for (Value v : minOp.getDimOperands()) {
-    if (auto forOp = scf::getForInductionVarOwner(v)) {
-      substitute(forOp, exprs, dims, symbols);
-      continue;
-    }
-    if (auto parentMinOp = v.getDefiningOp<AffineMinOp>()) {
-      substitute(parentMinOp, exprs, dims, symbols);
-      continue;
-    }
-    // If couldn't substitue the dimension give up and use the original map.
-    return false;
-  }
-  return true;
-}
-
-LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite(
-    AffineMinOp minOp, PatternRewriter &rewriter) const {
-  LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: "
-                          << *minOp.getOperation() << "\n");
-
-  int64_t min = std::numeric_limits<int64_t>::max();
-  for (auto e : minOp.map().getResults())
-    if (auto cstExpr = e.dyn_cast<AffineConstantExpr>())
-      min = std::min(min, cstExpr.getValue());
-  if (min == std::numeric_limits<int64_t>::max()) return failure();
-
-  MLIRContext *ctx = minOp.getContext();
-  AffineMap map;
-  SmallVector<Value, 4> operands;
-  SmallVector<AffineExpr, 4> exprs;
-  SmallVector<Value, 4> dims, symbols;
-  if (substitute(minOp, exprs, dims, symbols)) {
-    operands = dims;
-    operands.append(symbols.begin(), symbols.end());
-
-    map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx);
-    LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n");
-  } else {
-    map = minOp.getAffineMap();
-    operands = minOp.getDimOperands();
-    operands.append(minOp.getSymbolOperands().begin(),
-                    minOp.getSymbolOperands().end());
-  }
-  SmallVector<AffineExpr, 4> modExprs;
-  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx)
-    modExprs.push_back(getAffineDimExpr(idx, ctx) % min);
-  map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map);
-  canonicalizeMapAndOperands(&map, &operands);
-  map = simplifyAffineMap(map);
-
-  LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n";
-             llvm::interleaveComma(operands, llvm::dbgs()));
-
-  if (!llvm::all_of(map.getResults(), [](AffineExpr e) {
-        if (auto cst = e.dyn_cast<AffineConstantExpr>())
-          return cst.getValue() == 0;
-        return false;
-      }))
-    return failure();
-
-  rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, min);
-  return success();
-}
-//===----------------------------------------------------------------------===//
-// END TODO
-//===----------------------------------------------------------------------===//
-
-void MatmulCodegenStrategy::transform(FuncOp func) const {
-  MLIRContext *context = func.getContext();
-  // Emplace patterns one at a time while also maintaining a simple chained
-  // state transition.
-  unsigned stepCount = 0;
-  SmallVector<FrozenRewritePatternList, 4> stage1Patterns;
-  auto zeroState = Identifier::get(std::to_string(stepCount), context);
-  auto currentState = zeroState;
-  for (auto &t : transformationSequence) {
-    auto nextState = Identifier::get(std::to_string(++stepCount), context);
-    auto marker = (currentState == zeroState)
-                      ? linalg::LinalgMarker({}, nextState)
-                      : linalg::LinalgMarker(currentState, nextState);
-    stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
-    currentState = nextState;
-  }
-
-  OwningRewritePatternList stage2Patterns =
-      linalg::getLinalgTilingCanonicalizationPatterns(context);
-  // Add extra patterns to canonicalize AffineMin in combination with scf loops
-  // operations after tiling.
-  stage2Patterns.insert<AffineMinCanonicalizationPattern,
-                        AffineMinSCFCanonicalizationPattern>(context);
-
-  auto stage3Transforms = [](Operation *op) {
-    promoteSingleIterationLoops(cast<FuncOp>(op));
-    return success();
-  };
-  linalg::applyStagedPatterns(func, stage1Patterns, std::move(stage2Patterns),
-                              stage3Transforms);
-
-  auto postStageTransforms = [this](Operation *op) {
-    // Run LICM and hoisting patterns after all the stages as we want to
-    // unrolling before moving transfer ops out of the loop.
-    if (hoistInvariantCode) {
-      PassManager pm(op->getContext());
-      pm.addPass(createLoopInvariantCodeMotionPass());
-      if (failed(pm.run(op->getParentOfType<ModuleOp>())))
-        llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
-      hoistViewAllocOps(cast<FuncOp>(op));
-      hoistRedundantVectorTransfers(cast<FuncOp>(op));
-      hoistRedundantCopies(cast<FuncOp>(op));
-    }
-    OwningRewritePatternList patterns;
-    vector::populateVectorSlicesLoweringPatterns(patterns, op->getContext());
-    applyPatternsAndFoldGreedily(op, std::move(patterns));
-  };
-  postStageTransforms(func);
-  if (lowering != nullptr) lowering(func);
-}
-
-// Parametric lowering of vector contract for CPU target.
-static void cpuLowering(
-    FuncOp func, const vector::VectorTransformsOptions &vectorTransformsOptions,
-    const VectorTransferToSCFOptions &vectorToSCFOptions) {
-  // Programmatic controlled lowering of vector.contract only.
-  MLIRContext *context = func.getContext();
-  OwningRewritePatternList vectorContractLoweringPatterns;
-  vectorContractLoweringPatterns
-      .insert<ContractionOpToOuterProductOpLowering,
-              ContractionOpToMatmulOpLowering, ContractionOpLowering>(
-          vectorTransformsOptions, context);
-
-  applyPatternsAndFoldGreedily(func, std::move(vectorContractLoweringPatterns));
-
-  // Programmatic controlled lowering of vector.transfer only.
-  OwningRewritePatternList vectorToLoopsPatterns;
-  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
-                                        vectorToSCFOptions);
-  applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
-}
-
-MatmulCodegenStrategy &MatmulCodegenStrategy::setDefaultCPULowering() {
-  auto lowering = [this](FuncOp func) {
-    cpuLowering(func, vectorTransformsOptions, vectorToSCFOptions);
-  };
-  return setLoweringFunction(lowering);
-}
diff --git a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h
deleted file mode 100644
index e93f50b..0000000
--- a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h
+++ /dev/null
@@ -1,244 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_MATMULCODEGENSTRATEGY_H_
-#define IREE_COMPILER_CONVERSION_CODEGENUTILS_MATMULCODEGENSTRATEGY_H_
-
-#include <functional>
-
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSwitch.h"
-#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Dialect/Vector/VectorTransforms.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-
-/// Abstract Transformation class applied in a sequence that also handles state
-/// through markers.
-struct Transformation {
-  virtual ~Transformation() = default;
-  virtual OwningRewritePatternList buildRewritePatterns(
-      MLIRContext *context, linalg::LinalgMarker m) = 0;
-  linalg::LinalgMarker marker;
-};
-
-template <typename VectorOpType>
-struct UnrollVector : public Transformation {
-  explicit UnrollVector(ArrayRef<int64_t> targetShape)
-      : targetShape(targetShape.begin(), targetShape.end()) {}
-
-  OwningRewritePatternList buildRewritePatterns(
-      MLIRContext *ctx, linalg::LinalgMarker m) override {
-    OwningRewritePatternList vectorUnrollPatterns;
-    vectorUnrollPatterns.insert<vector::UnrollVectorPattern>(
-        ctx, vector::UnrollVectorOptions()
-                 .setNativeShape(targetShape)
-                 .setFilterConstraint([](Operation *op) {
-                   return success(isa<VectorOpType>(op));
-                 }));
-    vector::populateVectorToVectorCanonicalizationPatterns(vectorUnrollPatterns,
-                                                           ctx);
-    vector::populateVectorToVectorTransformationPatterns(vectorUnrollPatterns,
-                                                         ctx);
-    return vectorUnrollPatterns;
-  }
-
- private:
-  SmallVector<int64_t, 4> targetShape;
-};
-
-/// Promotion transformation enqueues a particular stage-1 pattern for
-/// `Tile<LinalgOpType>`with the appropriate `options`.
-// TODO: variadic LinalgOpTypes.
-template <typename LinalgOpType>
-struct Tile : public Transformation {
-  explicit Tile(linalg::LinalgTilingOptions options) : options(options) {}
-
-  OwningRewritePatternList buildRewritePatterns(
-      MLIRContext *context, linalg::LinalgMarker m) override {
-    OwningRewritePatternList tilingPatterns;
-    tilingPatterns.insert<linalg::LinalgTilingPattern<LinalgOpType>>(
-        context, options, m);
-    return tilingPatterns;
-  }
-
- private:
-  linalg::LinalgTilingOptions options;
-};
-
-/// Promotion transformation enqueues a particular stage-1 pattern for
-/// `Promote<LinalgOpType>`with the appropriate `options`.
-// TODO: variadic LinalgOpTypes.
-template <typename LinalgOpType>
-struct Promote : public Transformation {
-  explicit Promote(linalg::LinalgPromotionOptions options) : options(options) {}
-
-  OwningRewritePatternList buildRewritePatterns(
-      MLIRContext *context, linalg::LinalgMarker m) override {
-    OwningRewritePatternList promotionPatterns;
-    promotionPatterns.insert<linalg::LinalgPromotionPattern<LinalgOpType>>(
-        context, options, m);
-    return promotionPatterns;
-  }
-
- private:
-  linalg::LinalgPromotionOptions options;
-};
-
-/// Vectorization transformation enqueues a particular stage-1 pattern for
-/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
-/// transfer rewrite forwarding patterns.
-// TODO: variadic LinalgOpTypes.
-template <typename LinalgOpType>
-struct Vectorize : public Transformation {
-  OwningRewritePatternList buildRewritePatterns(
-      MLIRContext *context, linalg::LinalgMarker m) override {
-    OwningRewritePatternList vectorizationPatterns;
-    // FillOp may interfere with forwarding patterns atm, so we bump up the
-    // priority of LinalgCopyVTRForwardingPattern /
-    // LinalgCopyVTWForwardingPattern.
-    vectorizationPatterns
-        .insert<linalg::LinalgVectorizationPattern<LinalgOpType>>(context, m);
-    vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
-                                 linalg::LinalgCopyVTWForwardingPattern>(
-        context, /*benefit=*/2);
-    return vectorizationPatterns;
-  }
-};
-
-/// Matmul-specific strategy object controls how a linalg.matmul is
-/// progressively lowered.
-/// The strategy uses a 3-level staged patterns strategy which allows ordering
-/// transformations by using the Linalg `applyStagedPatterns` function, where:
-///   1. The first stage consists of the successive `tile`, `promote` and
-///   `vectorize` patterns, applied sequentially.
-///   2. The second stage consists of common local canonicalization patterns
-///   that are applied eagerly after each stage-1 pattern.
-///   3. the third stage consists of more global transformation, also applied
-///   eagerly, after all stage-2 patterns. Such more global transformations
-struct MatmulCodegenStrategy {
-  /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
-  /// `options`.
-  template <typename LinalgOpType>
-  MatmulCodegenStrategy &tile(linalg::LinalgTilingOptions options) {
-    transformationSequence.emplace_back(new Tile<LinalgOpType>(options));
-    return *this;
-  }
-  /// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
-  /// with tiling `options`.
-  template <typename LinalgOpType>
-  MatmulCodegenStrategy &tileIf(bool b, linalg::LinalgTilingOptions options) {
-    return b ? tile<LinalgOpType>(options) : *this;
-  }
-  /// Append a pattern to add a level of promotion for `LinalgOpType` with
-  /// promotion `options`.
-  template <typename LinalgOpType>
-  MatmulCodegenStrategy &promote(linalg::LinalgPromotionOptions options) {
-    transformationSequence.emplace_back(new Promote<LinalgOpType>(options));
-    return *this;
-  }
-  /// Conditionally append a pattern to add a level of promotion for
-  /// `LinalgOpType` with promotion `options`.
-  template <typename LinalgOpType>
-  MatmulCodegenStrategy &promoteIf(bool b,
-                                   linalg::LinalgPromotionOptions options) {
-    return b ? promote<LinalgOpType>(options) : *this;
-    return *this;
-  }
-  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
-  template <typename LinalgOpType>
-  MatmulCodegenStrategy &vectorize() {
-    transformationSequence.emplace_back(new Vectorize<LinalgOpType>());
-    return *this;
-  }
-  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
-  /// operation.
-  template <typename LinalgOpType>
-  MatmulCodegenStrategy &vectorizeIf(bool b) {
-    return b ? vectorize<LinalgOpType>() : *this;
-    return *this;
-  }
-  /// Configure the post staged-patterns late vector transformations.
-  MatmulCodegenStrategy &setVectorTransformsOptions(
-      vector::VectorTransformsOptions options) {
-    vectorTransformsOptions = options;
-    return *this;
-  }
-  /// Configure the post staged-patterns late vector.transfer to scf conversion.
-  MatmulCodegenStrategy &setVectorTransferToSCFOptions(
-      VectorTransferToSCFOptions options) {
-    vectorToSCFOptions = options;
-    return *this;
-  }
-  /// Configure the post staged-patterns late vector.transfer to scf conversion.
-  MatmulCodegenStrategy &setHoistInvariantCode(bool b) {
-    hoistInvariantCode = b;
-    return *this;
-  }
-
-  /// Apply the transformation patterns in sequence with cleanup transformations
-  /// interleaved.
-  void transform(FuncOp func) const;
-
-  /// Set a function applying the lowering strategy. Different target need to
-  /// use different lowering.
-  MatmulCodegenStrategy &setLoweringFunction(std::function<void(FuncOp)> f) {
-    lowering = f;
-    return *this;
-  }
-
-  /// Append a pattern to unroll a `VectorOpType` to smaller vector operations.
-  template <typename VectorOpType>
-  MatmulCodegenStrategy &unrollVector(ArrayRef<int64_t> targetShape) {
-    transformationSequence.emplace_back(
-        new UnrollVector<VectorOpType>(targetShape));
-    return *this;
-  }
-  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
-  /// operation.
-  template <typename VectorOpType>
-  MatmulCodegenStrategy &unrollVectorIf(bool b, ArrayRef<int64_t> targetShape) {
-    return b ? unrollVector<VectorOpType>(targetShape) : *this;
-    return *this;
-  }
-
-  // Enable default lowering strategy for CPU.
-  MatmulCodegenStrategy &setDefaultCPULowering();
-
- private:
-  LogicalResult postPatternTransforms(Operation *func) const;
-
-  std::function<void(FuncOp)> lowering = nullptr;
-  bool hoistInvariantCode = false;
-  vector::VectorTransformsOptions vectorTransformsOptions;
-  VectorTransferToSCFOptions vectorToSCFOptions;
-  SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
-};
-
-/// Perform folding of chains of AffineMinOp.
-struct AffineMinCanonicalizationPattern
-    : public mlir::OpRewritePattern<mlir::AffineMinOp> {
-  using OpRewritePattern<mlir::AffineMinOp>::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      mlir::AffineMinOp minOp, mlir::PatternRewriter &rewriter) const override;
-};
-}  // namespace mlir
-
-#endif  // IREE_COMPILER_CONVERSION_CODEGENUTILS_MATMULCODEGENSTRATEGY_H_
diff --git a/iree/compiler/Conversion/CodegenUtils/TransformUtils.cpp b/iree/compiler/Conversion/CodegenUtils/TransformUtils.cpp
new file mode 100644
index 0000000..6aaca75
--- /dev/null
+++ b/iree/compiler/Conversion/CodegenUtils/TransformUtils.cpp
@@ -0,0 +1,166 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// -----------------------------------------------------------------------------
+// This code will be removed once this gets upstreamed to common mlir.
+// Please limit changes to this code to minor ones.
+
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;          // NOLINT
+using namespace mlir::linalg;  // NOLINT
+
+#define DEBUG_TYPE "linalg-transform-utils"
+
+//===----------------------------------------------------------------------===//
+// TODO: Cleanup and upstream these to go into core. Please ignore for now !
+//===----------------------------------------------------------------------===//
+/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing:
+///   `%lb + %step * new_dim` where
+/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an
+/// AffineDimExpr depending on whether the value is constant or not.
+/// 2. the AffineExpr for %step is either an AffineConstantExpr or an
+/// AffineSymbolExpr depending on whether the value is constant or not.
+///
+static void substitute(scf::ForOp forOp, SmallVectorImpl<AffineExpr> &exprs,
+                       SmallVectorImpl<Value> &dims,
+                       SmallVectorImpl<Value> &symbols) {
+  MLIRContext *ctx = forOp.getContext();
+  auto lbConstant = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
+  AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx)
+                             : getAffineDimExpr(dims.size(), ctx);
+
+  auto stepConstant = forOp.step().getDefiningOp<ConstantIndexOp>();
+  AffineExpr step = stepConstant
+                        ? getAffineConstantExpr(stepConstant.getValue(), ctx)
+                        : getAffineSymbolExpr(symbols.size(), ctx);
+
+  if (!lbConstant) dims.push_back(forOp.lowerBound());
+  if (!stepConstant) symbols.push_back(forOp.step());
+  exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx));
+
+  auto ubConstant = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
+  AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx)
+                             : getAffineDimExpr(dims.size(), ctx);
+  if (!ubConstant) dims.push_back(forOp.upperBound());
+  exprs.push_back(ub);
+
+  dims.push_back(forOp.getInductionVar());
+}
+
+/// Substitute dimensions coming from forOp or AffineMin. Return false if it
+/// has no dimension operands or a dimension operand it cannot substitute.
+static bool substitute(AffineMinOp minOp, SmallVectorImpl<AffineExpr> &exprs,
+                       SmallVectorImpl<Value> &dims,
+                       SmallVectorImpl<Value> &symbols) {
+  if (minOp.getDimOperands().empty()) return false;
+  for (Value v : minOp.getDimOperands()) {
+    if (auto forOp = scf::getForInductionVarOwner(v)) {
+      substitute(forOp, exprs, dims, symbols);
+      continue;
+    }
+    if (auto parentMinOp = v.getDefiningOp<AffineMinOp>()) {
+      if (!substitute(parentMinOp, exprs, dims, symbols)) return false;
+      continue;
+    }
+    // Unsubstitutable dim: give up so the caller uses the original map.
+    return false;
+  }
+  return true;
+}
+
+LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite(
+    AffineMinOp minOp, PatternRewriter &rewriter) const {
+  LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: "
+                          << *minOp.getOperation() << "\n");
+
+  int64_t min = std::numeric_limits<int64_t>::max();
+  for (auto e : minOp.map().getResults())
+    if (auto cstExpr = e.dyn_cast<AffineConstantExpr>())
+      min = std::min(min, cstExpr.getValue());
+  if (min == std::numeric_limits<int64_t>::max()) return failure();
+
+  MLIRContext *ctx = minOp.getContext();
+  AffineMap map;
+  SmallVector<Value, 4> operands;
+  SmallVector<AffineExpr, 4> exprs;
+  SmallVector<Value, 4> dims, symbols;
+  if (substitute(minOp, exprs, dims, symbols)) {
+    operands = dims;
+    operands.append(symbols.begin(), symbols.end());
+
+    map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx);
+    LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n");
+  } else {
+    map = minOp.getAffineMap();
+    operands = minOp.getDimOperands();
+    operands.append(minOp.getSymbolOperands().begin(),
+                    minOp.getSymbolOperands().end());
+  }
+  SmallVector<AffineExpr, 4> modExprs;
+  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx)
+    modExprs.push_back(getAffineDimExpr(idx, ctx) % min);
+  map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map);
+  canonicalizeMapAndOperands(&map, &operands);
+  map = simplifyAffineMap(map);
+
+  LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n";
+             llvm::interleaveComma(operands, llvm::dbgs()));
+
+  if (!llvm::all_of(map.getResults(), [](AffineExpr e) {
+        if (auto cst = e.dyn_cast<AffineConstantExpr>())
+          return cst.getValue() == 0;
+        return false;
+      }))
+    return failure();
+
+  rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, min);
+  return success();
+}
+//===----------------------------------------------------------------------===//
+// END TODO
+//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/CodegenUtils/TransformUtils.h b/iree/compiler/Conversion/CodegenUtils/TransformUtils.h
new file mode 100644
index 0000000..5648fcb
--- /dev/null
+++ b/iree/compiler/Conversion/CodegenUtils/TransformUtils.h
@@ -0,0 +1,39 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_TRANSFORMUTILS_H_
+#define IREE_COMPILER_CONVERSION_CODEGENUTILS_TRANSFORMUTILS_H_
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+/// Perform folding of chains of AffineMinOp.
+struct AffineMinCanonicalizationPattern
+    : public mlir::OpRewritePattern<mlir::AffineMinOp> {
+  using OpRewritePattern<mlir::AffineMinOp>::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::AffineMinOp minOp, mlir::PatternRewriter &rewriter) const override;
+};
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CONVERSION_CODEGENUTILS_TRANSFORMUTILS_H_
diff --git a/iree/compiler/Conversion/Common/Transforms.cpp b/iree/compiler/Conversion/Common/Transforms.cpp
index f82b904..d5f3573 100644
--- a/iree/compiler/Conversion/Common/Transforms.cpp
+++ b/iree/compiler/Conversion/Common/Transforms.cpp
@@ -23,11 +23,12 @@
 #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
 #include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
 #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
 #include "iree/compiler/Conversion/Common/Attributes.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
diff --git a/iree/compiler/Conversion/HLOToLinalg/BUILD b/iree/compiler/Conversion/HLOToLinalg/BUILD
index 406abdc..e1f880c 100644
--- a/iree/compiler/Conversion/HLOToLinalg/BUILD
+++ b/iree/compiler/Conversion/HLOToLinalg/BUILD
@@ -42,7 +42,6 @@
         "@llvm-project//mlir:Transforms",
         "@mlir-hlo//:hlo",
         "@mlir-hlo//:legalize_to_linalg",
-        "@mlir-hlo//:map_lmhlo_to_scalar_op",
     ],
 )
 
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index f0b15c2..e00182a 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -33,7 +33,6 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -636,11 +635,9 @@
                       ArrayRef<Value> resultBuffers,
                       ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
-    auto subViewOp = rewriter.create<SubViewOp>(
-        loc, inputBuffers[0], extractFromI64ArrayAttr(op.static_offsets()),
-        extractFromI64ArrayAttr(op.static_sizes()),
-        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
-        op.strides());
+    auto subViewOp =
+        rewriter.create<SubViewOp>(loc, inputBuffers[0], op.getMixedOffsets(),
+                                   op.getMixedSizes(), op.getMixedStrides());
     rewriter.create<linalg::CopyOp>(loc, subViewOp, resultBuffers[0]);
     return success();
   }
@@ -749,236 +746,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// mhlo.reduce conversion patterns and utility functions.
-//===----------------------------------------------------------------------===//
-
-/// Returns a permutation AffineMap that puts all reduction dimensions to the
-/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
-/// if `rank` is 4 and `reductionDims` is {1, 3}, then
-/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
-/// the AffineMap is returned.
-static AffineMap getTransposeMapForReduction(MLIRContext *context, int rank,
-                                             ArrayRef<int> reductionDims) {
-  llvm::SmallSetVector<int, 4> s;
-  for (auto dim : reductionDims) s.insert(dim);
-
-  SmallVector<unsigned, 4> permutation;
-  for (int i = 0; i < rank; ++i) {
-    if (!s.count(i)) permutation.push_back(i);
-  }
-  for (auto dim : reductionDims) permutation.push_back(dim);
-
-  auto map = AffineMap::getPermutationMap(permutation, context);
-  return inversePermutation(map);
-}
-
-/// Checks whether an op is wthin an xla-hlo reduce region. During conversion,
-/// the body of the reduce gets moved into a linalg.indexed_generic op. So check
-/// if the op is within a linalg.indexed_generic op.
-static bool isWithinReduceOpRegion(Operation *op) {
-  return isa<linalg::IndexedGenericOp>(op->getParentOp());
-}
-
-namespace {
-
-/// Type converter for converting the region of an mhlo::reduce op.
-class ReduceRegionTypeConverter : public TypeConverter {
- public:
-  Type convertType(Type type) const {
-    if (type.isSignlessIntOrFloat()) {
-      return type;
-    } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
-      if (tensorType.getRank() == 0) return tensorType.getElementType();
-    }
-    return nullptr;
-  }
-};
-
-/// Converts the mhlo.reduce op on tensors to a linalg.indexed_generic op on
-/// buffers. Expects that the reduce op is the only op within the dispatch
-/// function. This pattern also fuses std.constant operations which are defining
-/// ops of the init value with the linalg.indexed_generic op.
-struct ReduceOpConversion
-    : public ConvertToLinalgBufferOp<ReduceOpConversion, mhlo::ReduceOp> {
-  using ConvertToLinalgBufferOp<ReduceOpConversion,
-                                mhlo::ReduceOp>::ConvertToLinalgBufferOp;
-  LogicalResult apply(mhlo::ReduceOp reduceOp, ArrayRef<Value> inputBuffers,
-                      ArrayRef<Value> resultBuffers,
-                      ConversionPatternRewriter &rewriter) const;
-
- private:
-  ReduceRegionTypeConverter converter;
-};
-
-/// Base class for converting operations within the reduction op region. Derived
-/// classes implement the following static method to implement the conversion.
-///
-///   static Operation *apply(OpTy op, ArrayRef<Value> operands,
-///                           ConversionPatternRewriter &rewriter);
-template <typename DerivedTy, typename OpTy>
-struct ReduceRegionOpConversion : public OpConversionPattern<OpTy> {
-  using OpConversionPattern<OpTy>::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      OpTy op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    // Only convert it if it is within a reduce op region.
-    if (!isWithinReduceOpRegion(op)) return failure();
-    Operation *replacement = DerivedTy::apply(op, operands, rewriter);
-    if (!replacement) return failure();
-    rewriter.replaceOp(op, replacement->getResults());
-    return success();
-  }
-
- protected:
-  ReduceRegionTypeConverter converter;
-};
-
-/// Converts XLA ops within reduce region to standard ops.
-template <typename OpTy>
-struct ReduceRegionXLAOpConversion final
-    : public ReduceRegionOpConversion<ReduceRegionXLAOpConversion<OpTy>, OpTy> {
-  using ReduceRegionOpConversion<ReduceRegionXLAOpConversion<OpTy>,
-                                 OpTy>::ReduceRegionOpConversion;
-  static Operation *apply(OpTy op, ArrayRef<Value> operands,
-                          ConversionPatternRewriter &rewriter) {
-    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
-        op, operands[0].getType(), operands, &rewriter);
-    return result.getDefiningOp();
-  }
-};
-
-/// Converts mhlo.return to within a reduce region to a linalg.yield.
-struct ReduceRegionReturnOpConversion final
-    : public ReduceRegionOpConversion<ReduceRegionReturnOpConversion,
-                                      mhlo::ReturnOp> {
-  using ReduceRegionOpConversion<ReduceRegionReturnOpConversion,
-                                 mhlo::ReturnOp>::ReduceRegionOpConversion;
-  static Operation *apply(mhlo::ReturnOp op, ArrayRef<Value> operands,
-                          ConversionPatternRewriter &rewriter) {
-    return rewriter.create<linalg::YieldOp>(op.getLoc(), operands[0]);
-  }
-};
-}  // namespace
-
-LogicalResult ReduceOpConversion::apply(
-    mhlo::ReduceOp reduceOp, ArrayRef<Value> inputBuffers,
-    ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
-  if (reduceOp.getNumOperands() != 2) return failure();
-  Value src = *reduceOp.operands().begin();
-  Value initVal = *reduceOp.init_values().begin();
-  if (reduceOp.getNumResults() != 1) return failure();
-
-  auto srcArgType = src.getType().template cast<ShapedType>();
-  unsigned nInputRank = srcArgType.getRank();
-  if (!nInputRank) return failure();
-
-  // Get the reduction dimension. For now expects only a single reduction
-  // dimension.
-  auto loc = reduceOp.getLoc();
-  DenseIntElementsAttr dimensionsAttr = reduceOp.dimensions();
-  SmallVector<int, 4> reductionDims;
-  for (const auto &dim : dimensionsAttr.getIntValues()) {
-    reductionDims.push_back(dim.getSExtValue());
-  }
-
-  // Check if initVal is constant. If so, inline the value into the region.
-  Attribute initConstVal = getInitValueAsConst(initVal);
-  if (initConstVal) {
-    if (initVal.hasOneUse()) rewriter.eraseOp(initVal.getDefiningOp());
-    initVal = rewriter.create<ConstantOp>(initVal.getDefiningOp()->getLoc(),
-                                          initConstVal);
-  }
-
-  // Prepare indexing maps for linalg generic op. The elements are for src,
-  // initial value and dst, respectively.
-  // Transpose `src` to make the reduction loops be the innermost, because it's
-  // easier to fully utilize processors.
-  SmallVector<AffineMap, 3> indexingMaps;
-  indexingMaps.emplace_back(getTransposeMapForReduction(
-      rewriter.getContext(), nInputRank, reductionDims));
-  if (!initConstVal) {
-    indexingMaps.emplace_back(
-        AffineMap::get(nInputRank, /*symbolCount=*/0, rewriter.getContext()));
-  }
-  // The indexing map of `dst` should drop the reduction loops. Since the
-  // reduction loops now are all in the innermost, drops `reductionDims.size()`
-  // dimensions. We don't need an inverse permutation here because they are the
-  // same.
-  SmallVector<AffineExpr, 4> exprs;
-  for (int i = 0, e = nInputRank - reductionDims.size(); i < e; ++i) {
-    exprs.push_back(rewriter.getAffineDimExpr(i));
-  }
-  indexingMaps.emplace_back(
-      exprs.empty()
-          ? AffineMap::get(nInputRank, /*symbolCount=*/0, rewriter.getContext())
-          : AffineMap::get(nInputRank, /*symbolCount=*/0, exprs,
-                           rewriter.getContext()));
-
-  SmallVector<Type, 2> resultTypes = {};
-  SmallVector<Value, 2> inputs = {inputBuffers[0]};
-  if (!initConstVal) {
-    inputs.push_back(inputBuffers[1]);
-  }
-  if (failed(zeroFillBuffer(loc, resultBuffers[0], rewriter))) {
-    rewriter.notifyMatchFailure(reduceOp, "failed to zero fill result buffer");
-    return failure();
-  }
-  auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
-      loc, /*resultTensorTypes=*/resultTypes, /*inputs=*/inputs,
-      /*outputBuffers=*/resultBuffers, indexingMaps,
-      getParallelAndReductionIterators(nInputRank, reductionDims.size()));
-
-  rewriter.inlineRegionBefore(reduceOp.body(), linalgOp.region(),
-                              linalgOp.region().end());
-  {
-    OpBuilder::InsertionGuard regionGuard(rewriter);
-
-    // Convert the signature of the body. The reduce op region apply function
-    // has a signature (lhs, rhs) -> output, all of the same tensor type t. This
-    // is converted to a function with the same signature but with element
-    // types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will be
-    // converted to "(f32, f32, f32)".
-    TypeConverter::SignatureConversion signatureConverter(2);
-    Type argType = linalgOp.region().front().getArgument(0).getType();
-    Type convertedType = converter.convertType(argType);
-    Type indexType = rewriter.getIndexType();
-    for (unsigned i = 0; i < nInputRank; ++i) {
-      signatureConverter.addInputs(indexType);
-    }
-    signatureConverter.addInputs(0, convertedType);
-    if (!initConstVal) signatureConverter.addInputs(convertedType);
-    signatureConverter.addInputs(1, convertedType);
-    Block *entryBlock = rewriter.applySignatureConversion(&linalgOp.region(),
-                                                          signatureConverter);
-
-    // The indexed generic op generated here combines the input value with the
-    // init value for the zero-th iteration of the reduction loop. This is
-    // yielded by the region to model a store of the value to the output. The
-    // input value with the output value for all other iterations.
-    unsigned numArgs = entryBlock->getNumArguments();
-    BlockArgument blockDstArg = entryBlock->getArgument(numArgs - 1);
-    rewriter.setInsertionPointToStart(entryBlock);
-    Value initArg =
-        initConstVal ? initVal : entryBlock->getArgument(numArgs - 2);
-    // The reduction dimensions are the innermost loops now, compare all
-    // reduction indices to zero. If they are all zero, it's the first time to
-    // update the output element, i.e., we should take initial value to compute
-    // with the input element.
-    Value zero = rewriter.create<ConstantOp>(
-        loc, indexType, rewriter.getIntegerAttr(indexType, 0));
-    Value cond = rewriter.create<ConstantOp>(loc, rewriter.getBoolAttr(true));
-    for (int i = nInputRank - reductionDims.size(); i < nInputRank; ++i) {
-      Value isZero = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq,
-                                             entryBlock->getArgument(i), zero);
-      cond = rewriter.create<AndOp>(loc, cond, isZero);
-    }
-    Value lhs = rewriter.create<SelectOp>(loc, cond, initArg, blockDstArg);
-    rewriter.replaceUsesOfBlockArgument(blockDstArg, lhs);
-  }
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
 // Linalg op on tensors to linalg op on buffers conversion base class.
 //===----------------------------------------------------------------------===//
 
@@ -1472,14 +1239,9 @@
                   DynamicTensorFromElementsOpConversion, InitTensorOpConversion,
                   LinalgOpOnTensorConversion<linalg::GenericOp>,
                   LinalgOpOnTensorConversion<linalg::IndexedGenericOp>,
-                  PadOpConversion, ReduceOpConversion, ReduceWindowOpConversion,
+                  PadOpConversion, ReduceWindowOpConversion,
                   SubTensorOpConversion, TensorReshapeOpConversion>(
       context, resultTensorToBufferMap);
-  // Reduce region operation conversions.
-  patterns.insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
-                  ReduceRegionXLAOpConversion<mhlo::MinOp>,
-                  ReduceRegionXLAOpConversion<mhlo::MaxOp>,
-                  ReduceRegionReturnOpConversion>(context);
 }
 
 void ConvertHLOToLinalgOnBuffersPass::runOnFunction() {
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index 0d4f162..af7797b 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -175,9 +175,13 @@
       staticStrides.push_back(stride);
     }
     rewriter.replaceOpWithNewOp<SubTensorOp>(
-        op, args[0], staticOffsets, staticSizes, staticStrides,
+        op, op.getType(), args[0],
         /*offsets=*/ValueRange{},
-        /*sizes=*/ValueRange{}, /*strides=*/ValueRange{});
+        /*sizes=*/ValueRange{},
+        /*strides=*/ValueRange{}, rewriter.getI64ArrayAttr(staticOffsets),
+        rewriter.getI64ArrayAttr(staticSizes),
+        rewriter.getI64ArrayAttr(staticStrides));
+
     return success();
   }
 };
@@ -199,8 +203,7 @@
         Optional<ConversionTarget::DynamicLegalityCallbackFn>(
             [](Operation *op) {
               auto parentOp = op->getParentRegion()->getParentOp();
-              return isa<mhlo::ReduceOp>(parentOp) ||
-                     isa<mhlo::ReduceWindowOp>(parentOp);
+              return isa<mhlo::ReduceWindowOp>(parentOp);
             }));
     // Let the rest fall through.
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/reduce.mlir b/iree/compiler/Conversion/HLOToLinalg/test/reduce.mlir
index f7882f5..361758e 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/reduce.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/reduce.mlir
@@ -1,25 +1,25 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-pipeline %s | IreeFileCheck %s
 
 // CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 module {
-  //      CHECK: func @reduction_entry
+  //      CHECK: func @reduce_add
   //  CHECK-DAG: %[[ARG2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<5xf32>
   //  CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<5x4xf32>
   //  CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-  //      CHECK: linalg.indexed_generic
+  //      CHECK: %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
+  //      CHECK: linalg.fill(%[[ARG2]], %[[INIT]])
+  //      CHECK: linalg.generic
   // CHECK-SAME: indexing_maps
-  // CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]]
+  // CHECK-SAME: #[[MAP0]], #[[MAP1]]
   // CHECK-SAME: iterator_types = ["parallel", "reduction"]}
-  // CHECK-SAME:   ins(%[[ARG0]], %[[ARG1]] : memref<5x4xf32>, memref<f32>
+  // CHECK-SAME:   ins(%[[ARG0]] : memref<5x4xf32>
   // CHECK-SAME:   outs(%[[ARG2]] : memref<5xf32>
-  // CHECK-NEXT: ^{{.+}}(%{{.+}}, %[[IDX:.+]]: index, %[[SRC:.+]]: f32, %[[INIT:.+]]: f32, %[[DST:.+]]: f32):
-  //      CHECK:   %[[OPERAND:.+]] = select %{{.+}}, %[[INIT]], %[[DST]] : f32
-  // CHECK-NEXT:   %[[RES:.+]] = addf %[[SRC]], %[[OPERAND]] : f32
+  // CHECK-NEXT: ^{{.+}}(%[[SRC:.+]]: f32, %[[DST:.+]]: f32):
+  // CHECK-NEXT:   %[[RES:.+]] = addf %[[SRC]], %[[DST]] : f32
   // CHECK-NEXT:   linalg.yield %[[RES]] : f32
   // CHECK-NEXT: }
-  func @reduction_entry() {
+  func @reduce_add() {
     %c0 = constant 0 : index
     %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<5x4xf32>
     %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
@@ -43,7 +43,7 @@
 module {
   //      CHECK:   %[[COND:.+]] = cmpf olt, %{{.+}}, %{{.+}} : f32
   // CHECK-NEXT:   select %[[COND]], %{{.+}}, %{{.+}} : f32
-  func @reduction_entry() {
+  func @reduce_minimum() {
     %c0 = constant 0 : index
     %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<5x4xf32>
     %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
@@ -67,7 +67,7 @@
 module {
   //      CHECK:   %[[COND:.+]] = cmpf ogt, %{{.+}}, %{{.+}} : f32
   // CHECK-NEXT:   select %[[COND]], %{{.+}}, %{{.+}} : f32
-  func @reduction_entry() {
+  func @reduce_maximum() {
     %c0 = constant 0 : index
     %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<5x4xf32>
     %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
@@ -88,50 +88,26 @@
 
 // -----
 
-module {
-  //      CHECK:   %[[COND:.+]] = cmpf ogt, %{{.+}}, %{{.+}} : f32
-  // CHECK-NEXT:   select %[[COND]], %{{.+}}, %{{.+}} : f32
-  func @reduction_entry() {
-    %c0 = constant 0 : index
-    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<5x4xf32>
-    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
-    %2 = "mhlo.reduce"(%0, %1) ({
-    ^bb0(%arg3: tensor<f32>, %arg4 : tensor<f32>):
-      %3 = mhlo.maximum %arg3, %arg4 : tensor<f32>
-      "mhlo.return"(%3) : (tensor<f32>) -> ()
-    }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xf32>, tensor<f32>) -> tensor<4xf32>
-    hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<4xf32>
-    return
-  }
-  hal.interface @legacy_io attributes {sym_visibility = "private"} {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
-  }
-}
-
-// -----
-
 // CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 module {
-  //      CHECK: func @reduction_entry
-  //      CHECK: %[[ARG2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xf32>
-  //      CHECK: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<5x4xf32>
-  //      CHECK: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-  //      CHECK: linalg.indexed_generic {
+  //      CHECK: func @reduce_dim0
+  //  CHECK-DAG: %[[ARG2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xf32>
+  //  CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<5x4xf32>
+  //  CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
+  //      CHECK: %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
+  //      CHECK: linalg.fill(%[[ARG2]], %[[INIT]])
+  //      CHECK: linalg.generic
   // CHECK-SAME: indexing_maps
-  // CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]]
+  // CHECK-SAME: #[[MAP0]], #[[MAP1]]
   // CHECK-SAME: iterator_types = ["parallel", "reduction"]}
-  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<5x4xf32>, memref<f32>)
-  // CHECK-SAME: outs(%[[ARG2]] : memref<4xf32>)
-  // CHECK-NEXT: ^{{.+}}(%{{.+}}, %[[IDX:.+]]: index, %[[SRC:.+]]: f32, %[[INIT:.+]]: f32, %[[DST:.+]]: f32):
-  //      CHECK:   %[[OPERAND:.+]] = select %{{.+}}, %[[INIT]], %[[DST]] : f32
-  // CHECK-NEXT:   %[[RES:.+]] = addf %[[SRC]], %[[OPERAND]] : f32
+  // CHECK-SAME:   ins(%[[ARG0]] : memref<5x4xf32>
+  // CHECK-SAME:   outs(%[[ARG2]] : memref<4xf32>
+  // CHECK-NEXT: ^{{.+}}(%[[SRC:.+]]: f32, %[[DST:.+]]: f32):
+  // CHECK-NEXT:   %[[RES:.+]] = addf %[[SRC]], %[[DST]] : f32
   // CHECK-NEXT:   linalg.yield %[[RES]] : f32
   // CHECK-NEXT: }
-  func @reduction_entry() {
+  func @reduce_dim0() {
     %c0 = constant 0 : index
     %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<5x4xf32>
     %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
@@ -152,22 +128,34 @@
 
 // -----
 
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 module {
-  // CHECK-LABEL: func @reduce_init_const
+  //      CHECK: func @reduce_init_const
+  //  CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2xf32>
+  //  CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<2x10xf32>
+  //  CHECK-DAG: %[[CST:.+]] = constant 0xFF800000 : f32
+  //      CHECK: linalg.fill(%[[OUT]], %[[CST]])
+  //      CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps
+  // CHECK-SAME: #[[MAP0]], #[[MAP1]]
+  // CHECK-SAME: iterator_types = ["parallel", "reduction"]}
+  // CHECK-SAME:   ins(%[[IN]] : memref<2x10xf32>
+  // CHECK-SAME:   outs(%[[OUT]] : memref<2xf32>
+  // CHECK-NEXT: ^{{.+}}(%[[SRC:.+]]: f32, %[[DST:.+]]: f32):
+  // CHECK-NEXT:   %[[RES:.+]] = addf %[[SRC]], %[[DST]] : f32
+  // CHECK-NEXT:   linalg.yield %[[RES]] : f32
+  // CHECK-NEXT: }
   func @reduce_init_const() {
     %c0 = constant 0 : index
-    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x10xf32>
-    // CHECK: %[[CST:.+]] = constant 0xFF800000 : f32
-    // CHECK: linalg.indexed_generic
-    // CHECK: ^{{.+}}(%{{.+}}: index, %[[DIM:.+]]: index, %{{.+}}: f32, %[[OUTPUT:.+]]: f32):
-    // CHECK: select %{{.+}}, %[[CST]], %[[OUTPUT]] : f32
+    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x10xf32>
     %cst = constant dense<0xFF800000> : tensor<f32>
     %1 = "mhlo.reduce"(%0, %cst) ({
     ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
       %2 = mhlo.add %arg2, %arg3 {name = "maximum.21"} : tensor<f32>
       "mhlo.return"(%2) : (tensor<f32>) -> ()
-    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
-    hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0 : tensor<1xf32>
+    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x10xf32>, tensor<f32>) -> tensor<2xf32>
+    hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0 : tensor<2xf32>
     return
   }
   hal.interface @legacy_io attributes {sym_visibility = "private"} {
@@ -179,30 +167,25 @@
 // -----
 
 // CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> ()>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0)>
 module {
-  //      CHECK: func @reduction_multi_dimensions
+  //      CHECK: func @reduce_multi_dimensions
   //      CHECK: %[[ARG2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xf32>
   //      CHECK: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<5x4x3xf32>
   //      CHECK: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-  //      CHECK: linalg.indexed_generic {
+  //      CHECK: %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
+  //      CHECK: linalg.fill(%[[ARG2]], %[[INIT]])
+  //      CHECK: linalg.generic {
   // CHECK-SAME: indexing_maps
-  // CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]]
+  // CHECK-SAME: #[[MAP0]], #[[MAP1]]
   // CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]}
-  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<5x4x3xf32>, memref<f32>)
+  // CHECK-SAME: ins(%[[ARG0]] : memref<5x4x3xf32>)
   // CHECK-SAME: outs(%[[ARG2]] : memref<4xf32>)
-  // CHECK-NEXT: ^{{.+}}(%{{.+}}, %[[IDX:.+]]: index, %[[SRC:.+]]: f32, %[[INIT:.+]]: f32, %[[DST:.+]]: f32):
-  //      CHECK:   %[[TRUE:.+]] = constant true
-  //      CHECK:   %[[CMP1:.+]] = cmpi
-  //      CHECK:   %[[COND1:.+]] = and %[[TRUE]], %[[CMP1]]
-  //      CHECK:   %[[CMP2:.+]] = cmpi
-  //      CHECK:   %[[COND2:.+]] = and %[[COND1]], %[[CMP2]]
-  // CHECK-NEXT:   %[[OPERAND:.+]] = select %[[COND2]], %[[INIT]], %[[DST]] : f32
-  // CHECK-NEXT:   %[[RES:.+]] = addf %[[SRC]], %[[OPERAND]] : f32
+  // CHECK-NEXT: ^{{.+}}(%[[SRC:.+]]: f32, %[[DST:.+]]: f32):
+  // CHECK-NEXT:   %[[RES:.+]] = addf %[[SRC]], %[[DST]] : f32
   // CHECK-NEXT:   linalg.yield %[[RES]] : f32
   // CHECK-NEXT: }
-  func @reduction_multi_dimensions() {
+  func @reduce_multi_dimensions() {
     %c0 = constant 0 : index
     %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<5x4x3xf32>
     %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgBufferizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgBufferizePass.cpp
index b8c02ce..59c272e 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgBufferizePass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgBufferizePass.cpp
@@ -343,11 +343,8 @@
                                   loadOp.queryBindingOp(), /*typeErase=*/true);
   Value buffer = phOp.getResult();
   Value subview =
-      b.create<SubViewOp>(loadOp->getLoc(), buffer,
-                          extractFromI64ArrayAttr(loadOp.static_offsets()),
-                          extractFromI64ArrayAttr(loadOp.static_sizes()),
-                          extractFromI64ArrayAttr(loadOp.static_strides()),
-                          loadOp.offsets(), loadOp.sizes(), loadOp.strides());
+      b.create<SubViewOp>(loadOp->getLoc(), buffer, loadOp.getMixedOffsets(),
+                          loadOp.getMixedSizes(), loadOp.getMixedStrides());
   bvm.map(loadOp.result(), subview);
   // TODO(nicolasvasilache): kill tie_shape with fire.
   mapAllTieShapeUsesAndReplaceDimUsesOf(loadOp.result(), subview, bvm);
@@ -417,12 +414,9 @@
       createPlaceholderOp(b, storeOp.getLoc(), storeOp, storeOp.operand(),
                           storeOp.queryBindingOp(), /*typeErase=*/true);
   Value buffer = phOp.getResult();
-  Value subview = b.create<SubViewOp>(
-      storeOp->getLoc(), buffer,
-      extractFromI64ArrayAttr(storeOp.static_offsets()),
-      extractFromI64ArrayAttr(storeOp.static_sizes()),
-      extractFromI64ArrayAttr(storeOp.static_strides()), storeOp.offsets(),
-      storeOp.sizes(), storeOp.strides());
+  Value subview =
+      b.create<SubViewOp>(storeOp->getLoc(), buffer, storeOp.getMixedOffsets(),
+                          storeOp.getMixedSizes(), storeOp.getMixedStrides());
   b.create<linalg::CopyOp>(storeOp->getLoc(),
                            iterativeLookup(bvm, storeOp.operand()), subview);
   storeOp->erase();
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp
index d156538..fa52276 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp
@@ -13,9 +13,10 @@
 // limitations under the License.
 
 #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
@@ -50,7 +51,7 @@
     : public linalg::LinalgBaseTilingPattern {
   using Base = linalg::LinalgBaseTilingPattern;
   TileAndDistributeOnTensorsPattern(linalg::LinalgTilingOptions options,
-                                    linalg::LinalgMarker marker,
+                                    linalg::LinalgTransformationFilter marker,
                                     PatternBenefit benefit = 1)
       : Base(options, marker, benefit) {}
 
@@ -108,8 +109,9 @@
     // SPMD loops.
     patterns.insert<TileAndDistributeOnTensorsPattern>(
         linalgTilingOptions,
-        linalg::LinalgMarker(ArrayRef<Identifier>(),
-                             Identifier::get(getWorkgroupMarker(), context)));
+        linalg::LinalgTransformationFilter(
+            ArrayRef<Identifier>(),
+            Identifier::get(getWorkgroupMarker(), context)));
     // Add canonicalization patterns.
     linalg::populateLinalgTilingCanonicalizationPatterns(patterns, context);
     patterns.insert<AffineMinCanonicalizationPattern>(context);
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
index 547ece7..cd1f2b5 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
@@ -15,12 +15,12 @@
 #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
 #include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
 #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
 #include "iree/compiler/Conversion/Common/Attributes.h"
 #include "iree/compiler/Conversion/Common/Transforms.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp
index 98777c1..6e6a540 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp
@@ -13,9 +13,10 @@
 // limitations under the License.
 
 #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -31,7 +32,8 @@
 struct TileWorkgroups : public linalg::LinalgBaseTilingPattern {
   using Base = linalg::LinalgBaseTilingPattern;
   TileWorkgroups(MLIRContext *context, linalg::LinalgTilingOptions options,
-                 linalg::LinalgMarker marker, PatternBenefit benefit = 1)
+                 linalg::LinalgTransformationFilter marker,
+                 PatternBenefit benefit = 1)
       : Base(LinalgOpTy::getOperationName(), context, options, marker,
              benefit) {}
   LogicalResult matchAndRewrite(Operation *op,
@@ -77,7 +79,7 @@
               return TileSizeFn::get<TilingLevel::Level1Tiles>(
                   cpuKernelDispatch, builder, operation);
             }),
-        linalg::LinalgMarker(
+        linalg::LinalgTransformationFilter(
             Identifier::get(getWorkgroupMarker(), context),
             Identifier::get(getWorkgroupL1TileMarker(), context)));
 
@@ -96,7 +98,7 @@
               return TileSizeFn::get<TilingLevel::Level2Tiles>(
                   cpuKernelDispatch, builder, operation);
             }),
-        linalg::LinalgMarker(
+        linalg::LinalgTransformationFilter(
             Identifier::get(getWorkgroupL1TileMarker(), context),
             Identifier::get(getVectorizeMarker(), context)));
 
@@ -120,8 +122,9 @@
     vectorizationPatterns
         .insert<linalg::LinalgVectorizationPattern<linalg::MatmulOp>,
                 linalg::LinalgVectorizationPattern<linalg::BatchMatmulOp>>(
-            context, linalg::LinalgMarker(
-                         Identifier::get(getVectorizeMarker(), context)));
+            context, linalg::LinalgVectorizationOptions(),
+            linalg::LinalgTransformationFilter(
+                Identifier::get(getVectorizeMarker(), context)));
     applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns));
   }
 
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp
index 00295e7..ea17d54 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp
@@ -62,8 +62,9 @@
         .insert<linalg::LinalgVectorizationPattern<linalg::MatmulOp>,
                 linalg::LinalgVectorizationPattern<linalg::BatchMatmulOp>,
                 linalg::LinalgVectorizationPattern<linalg::GenericOp>>(
-            context, linalg::LinalgMarker(ArrayRef<Identifier>(
-                         Identifier::get(getWorkgroupMarker(), context))));
+            context, linalg::LinalgVectorizationOptions(),
+            linalg::LinalgTransformationFilter(ArrayRef<Identifier>(
+                Identifier::get(getWorkgroupMarker(), context))));
     applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns));
 
     LLVM_DEBUG({
diff --git a/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp
index db898a4..5154426 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp
@@ -38,10 +38,10 @@
   auto context = funcOp.getContext();
 
   auto marker = Identifier::get("generalized_from_conv", context);
-  linalg::LinalgMarker firstStepMarker(
+  linalg::LinalgTransformationFilter firstStepMarker(
       /*matchDisjunction=*/ArrayRef<Identifier>(),
       /*replacement=*/marker);
-  linalg::LinalgMarker secondStepMarker(
+  linalg::LinalgTransformationFilter secondStepMarker(
       /*matchDisjunction=*/marker,
       /*replacement=*/llvm::None);
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 0ee0b0f..735b91f 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -57,15 +57,15 @@
 //===----------------------------------------------------------------------===//
 
 /// Returns a Linalg marker that replaces existing markers.
-linalg::LinalgMarker getLinalgReplaceMarker(StringRef maker,
-                                            MLIRContext *context) {
-  return linalg::LinalgMarker(ArrayRef<Identifier>(),
-                              Identifier::get(maker, context));
+linalg::LinalgTransformationFilter getLinalgReplaceMarker(
+    StringRef maker, MLIRContext *context) {
+  return linalg::LinalgTransformationFilter(ArrayRef<Identifier>(),
+                                            Identifier::get(maker, context));
 }
 
 /// Returns a Linalg marker that matches any of the `matchMarkers` and replaces
 /// it with `replaceMarker`.
-linalg::LinalgMarker getLinalgMatchAndReplaceMarker(
+linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
     ArrayRef<StringRef> matchMarkers, StringRef replaceMarker,
     MLIRContext *context) {
   SmallVector<Identifier, 2> markers;
@@ -73,7 +73,8 @@
   for (StringRef marker : matchMarkers) {
     markers.emplace_back(Identifier::get(marker, context));
   }
-  return linalg::LinalgMarker(markers, Identifier::get(replaceMarker, context));
+  return linalg::LinalgTransformationFilter(
+      markers, Identifier::get(replaceMarker, context));
 }
 
 /// Returns the distribution options for operations when targeting workgroups.
@@ -138,7 +139,7 @@
     : public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
   PromoteMatmulSubviewsPattern(MLIRContext *context,
                                linalg::LinalgPromotionOptions options,
-                               linalg::LinalgMarker marker,
+                               linalg::LinalgTransformationFilter marker,
                                PatternBenefit benefit = 1)
       : linalg::LinalgPromotionPattern<linalg::MatmulOp>(
             context,
@@ -163,7 +164,7 @@
     : public linalg::LinalgPromotionPattern<linalg::ConvOp> {
   PromoteConvSubviewsPattern(MLIRContext *context,
                              linalg::LinalgPromotionOptions options,
-                             linalg::LinalgMarker marker,
+                             linalg::LinalgTransformationFilter marker,
                              PatternBenefit benefit = 1)
       : linalg::LinalgPromotionPattern<linalg::ConvOp>(
             context,
@@ -216,7 +217,7 @@
   using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
   TileMatmulSubgroupPattern(MLIRContext *context,
                             linalg::LinalgTilingOptions options,
-                            linalg::LinalgMarker marker,
+                            linalg::LinalgTransformationFilter marker,
                             PatternBenefit benefit = 1)
       : Base(context, options, marker, benefit) {}
 };
@@ -329,8 +330,9 @@
                   linalg::LinalgVectorizationPattern<linalg::BatchMatmulOp>,
                   linalg::LinalgVectorizationPattern<linalg::FillOp>,
                   linalg::LinalgVectorizationPattern<linalg::GenericOp>>(
-      context,
-      linalg::LinalgMarker(Identifier::get(getVectorizeMarker(), context)));
+      context, linalg::LinalgVectorizationOptions(),
+      linalg::LinalgTransformationFilter(
+          Identifier::get(getVectorizeMarker(), context)));
 }
 
 //====---------------------------------------------------------------------===//
@@ -387,10 +389,10 @@
 // Patterns to tile convolution window dimensions
 //====---------------------------------------------------------------------===//
 
-static void populateTilingConvFilterPatterns(MLIRContext *context,
-                                             OwningRewritePatternList &patterns,
-                                             const LaunchConfig &launchConfig,
-                                             linalg::LinalgMarker marker) {
+static void populateTilingConvFilterPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    const LaunchConfig &launchConfig,
+    linalg::LinalgTransformationFilter marker) {
   auto getTileSizeFn = [&launchConfig](OpBuilder &builder, Operation *op) {
     SmallVector<Value, 4> tileSizes;
     ArrayRef<int64_t> fourthLevel = launchConfig.getTileSizes(op, 3);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp
index 9183933..2bd8f94 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp
@@ -11,7 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -46,7 +46,7 @@
   FuncOp fn = getFunction();
   SmallVector<uint32_t, 3> vUnrollSize(unrollSize.begin(), unrollSize.end());
   if (vUnrollSize.size() != 3) signalPassFailure();
-  MatmulCodegenStrategy strategy;
+  linalg::CodegenStrategy strategy;
   strategy
       .tile<linalg::MatmulOp>(
           linalg::LinalgTilingOptions()
@@ -55,9 +55,11 @@
               //.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
               .setTileSizes({wgTileSize, wgTileSize, wgTileSize}))
       .setHoistInvariantCode(enableLICM)
-      .vectorize<linalg::MatmulOp>()
-      .unrollVector<vector::ContractionOp>(
-          {vUnrollSize[0], vUnrollSize[1], vUnrollSize[2]});
+      .vectorize<linalg::LinalgOp>()
+      // TODO upstream to the core CodegenStrategy
+      // .unrollVector<vector::ContractionOp>(
+      //     {vUnrollSize[0], vUnrollSize[1], vUnrollSize[2]})
+      ;
   strategy.transform(fn);
 }
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index b2d378b..8574774 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -22,13 +22,14 @@
 
 #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
 #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -195,8 +196,9 @@
   OwningRewritePatternList vectorizationPatterns;
   vectorizationPatterns
       .insert<linalg::LinalgVectorizationPattern<linalg::CopyOp>>(
-          context, linalg::LinalgMarker(
-                       Identifier::get(getVectorizeMarker(), context), {}));
+          context, linalg::LinalgVectorizationOptions(),
+          linalg::LinalgTransformationFilter(
+              Identifier::get(getVectorizeMarker(), context), {}));
   applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns));
 }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 0fe9904..d001fa8 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -234,7 +234,7 @@
     : public linalg::LinalgBaseTilingPattern {
   using Base = linalg::LinalgBaseTilingPattern;
   TileAndDistributeOnTensorsPattern(linalg::LinalgTilingOptions options,
-                                    linalg::LinalgMarker marker,
+                                    linalg::LinalgTransformationFilter marker,
                                     PatternBenefit benefit = 1)
       : Base(options, marker, benefit) {}
 
@@ -418,8 +418,8 @@
   patterns.insert<TileAndDistributeOnTensorsPattern>(
       linalgTilingOptions,
       // TODO(nicolavasilache): use refactored `getWorkgroupMarker()`
-      linalg::LinalgMarker(ArrayRef<Identifier>(),
-                           Identifier::get("workgroup", context)));
+      linalg::LinalgTransformationFilter(
+          ArrayRef<Identifier>(), Identifier::get("workgroup", context)));
 
   // Add canonicalization patterns.
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns, context);
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 1b0bcb1..beb1a72 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -2334,7 +2334,7 @@
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrRanks() {
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
       unsigned resultRank = getResult().getType().cast<ShapedType>().getRank();
       return {resultRank, resultRank, resultRank};
     }
@@ -2412,7 +2412,7 @@
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrRanks() {
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
       unsigned rank = operand().getType().cast<ShapedType>().getRank();
       return {rank, rank, rank};
     }
diff --git a/third_party/llvm-bazel b/third_party/llvm-bazel
index 39e3ba5..77871f4 160000
--- a/third_party/llvm-bazel
+++ b/third_party/llvm-bazel
@@ -1 +1 @@
-Subproject commit 39e3ba5be1a67c1e1ed900a5bc8134d0c8374d73
+Subproject commit 77871f43e449ad492bf8b94dee453670ac15e158
diff --git a/third_party/llvm-project b/third_party/llvm-project
index f3f3c9c..b92a39a 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit f3f3c9c2549a268e602be8730990b552e30cc932
+Subproject commit b92a39ac1319c796777bca19a3af2856acbc69c1
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 471fc63..2b72ddc 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 471fc63c11205639dab25345aea1f85831ef4cb9
+Subproject commit 2b72ddc6b2b4d670bcd1ffa3f4652468b419f986
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 8f595c9..16613a7 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 8f595c955848e24b3edca51f78a68943e3c26f50
+Subproject commit 16613a70ef36b103e7c1ffa903d541814b62c109