Merge pull request #2351 from GMNGeoffrey:main-to-google

PiperOrigin-RevId: 318861045
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 9a89b24..137a304 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -3,7 +3,7 @@
 4c13807b7d43ff0946b7ffea0ae3aee9e611d778 third_party/dear_imgui
 a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
 f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-1becd298b82ed2f1a8ba5e61c5ad2ce7fe32d812 third_party/llvm-project
+e34523c87c3f1cfabcf741568dede026bbb12d3a third_party/llvm-project
 17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
 67f3ccebee84f3488b46a8d3ac005178c52ff264 third_party/mlir-emitc
 80d452484c5409444b0ec19383faa84bb7a4d351 third_party/pybind11
diff --git a/docs/GetStarted/getting_started_python.md b/docs/GetStarted/getting_started_python.md
index 1a0452b..3aec4e2 100644
--- a/docs/GetStarted/getting_started_python.md
+++ b/docs/GetStarted/getting_started_python.md
@@ -62,8 +62,7 @@
 
 See
 [start_colab_kernel.py](https://github.com/google/iree/blob/main/colab/start_colab_kernel.py)
-and
-[Using Colab](https://github.com/google/iree/blob/main/docs/using_colab.md)
+and [Using Colab](https://github.com/google/iree/blob/main/docs/using_colab.md)
 for setup instructions, then take a look through the
 [Colab directory](https://github.com/google/iree/tree/main/colab) for some
 sample notebooks.
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index 7f6d34d..ad27616 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -128,6 +128,7 @@
         "//experimental/ModelBuilder:ModelRunner",
         "//experimental/ModelBuilder:VulkanLaunchWrapper",
         "//iree/base:initializer",
+        "//iree/compiler/Conversion/CodegenUtils",
         "//iree/compiler/Conversion/LinalgToSPIRV",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AllPassesAndDialects",
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index 9c11fae..7b50f12 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -48,7 +48,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
 
 using namespace mlir;                    // NOLINT
 using namespace mlir::edsc;              // NOLINT
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 9164d9b..0a1e812 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -136,7 +136,7 @@
         np.float32)
     input_data = input_data.reshape(input_shape)
     self.modules.applications.all.predict(input_data).print().assert_all_close(
-        atol=1e-6)
+        atol=3e-5)
 
 
 if __name__ == '__main__':
diff --git a/iree/compiler/Conversion/CodegenUtils/BUILD b/iree/compiler/Conversion/CodegenUtils/BUILD
index a414146..3696239 100644
--- a/iree/compiler/Conversion/CodegenUtils/BUILD
+++ b/iree/compiler/Conversion/CodegenUtils/BUILD
@@ -23,12 +23,10 @@
     name = "CodegenUtils",
     srcs = [
         "FunctionUtils.cpp",
-        "MarkerUtils.cpp",
         "MatmulCodegenStrategy.cpp",
     ],
     hdrs = [
         "FunctionUtils.h",
-        "MarkerUtils.h",
         "MatmulCodegenStrategy.h",
     ],
     deps = [
diff --git a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
index 08a7b14..08dfe7c 100644
--- a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
+++ b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
@@ -19,11 +19,9 @@
     CodegenUtils
   HDRS
     "FunctionUtils.h"
-    "MarkerUtils.h"
     "MatmulCodegenStrategy.h"
   SRCS
     "FunctionUtils.cpp"
-    "MarkerUtils.cpp"
     "MatmulCodegenStrategy.cpp"
   DEPS
     LLVMSupport
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 6a54529..01c0cba 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -22,7 +22,6 @@
 
 #include <cstddef>
 
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
@@ -612,8 +611,6 @@
       cond ? rewriter.create<SelectOp>(loc, cond, inputVal, paddingVal)
            : inputVal;
   rewriter.create<linalg::YieldOp>(loc, result);
-
-  setNoTileMarker(linalgOp);
   return success();
 }
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index d2308b4..70df4d6 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -23,14 +23,18 @@
         "ConvertToGPUPass.cpp",
         "ConvertToSPIRVPass.cpp",
         "LinalgTileAndFusePass.cpp",
+        "MarkerUtils.cpp",
         "Passes.cpp",
         "SplitDispatchFunctionPass.cpp",
+        "Utils.cpp",
         "VectorToGPUPass.cpp",
     ],
     hdrs = [
         "Attributes.h",
+        "MarkerUtils.h",
         "MemorySpace.h",
         "Passes.h",
+        "Utils.h",
     ],
     deps = [
         "//iree/compiler/Conversion/CodegenUtils",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index ccc694c..b8821e2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -19,14 +19,18 @@
     LinalgToSPIRV
   HDRS
     "Attributes.h"
+    "MarkerUtils.h"
     "MemorySpace.h"
     "Passes.h"
+    "Utils.h"
   SRCS
     "ConvertToGPUPass.cpp"
     "ConvertToSPIRVPass.cpp"
     "LinalgTileAndFusePass.cpp"
+    "MarkerUtils.cpp"
     "Passes.cpp"
     "SplitDispatchFunctionPass.cpp"
+    "Utils.cpp"
     "VectorToGPUPass.cpp"
   DEPS
     LLVMSupport
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 8a85b6a..9009441 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -17,8 +17,9 @@
 // Partition computation within dispatch function to workgroups/workitems.
 //
 //===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
@@ -36,11 +37,21 @@
 namespace mlir {
 namespace iree_compiler {
 
+// TODO(#2134): Remove this flag/set it to false always when issue with
+// convolution is resolved (see bug for more details).
+// TODO(#2346): Make this a pass specific option.
+llvm::cl::opt<bool> useLegacyConvLowering{
+    "iree-codegen-use-legacy-conv-lowering",
+    llvm::cl::desc("Use conv lowering that does not assume 1:1 mapping "
+                   "between threads within a block and iterations of "
+                   "parallel loops distributed to the block"),
+    llvm::cl::init(true)};
+
 //===----------------------------------------------------------------------===//
 // Loop utilities
 //===----------------------------------------------------------------------===//
 
-/// Builds an empty loop.for operation. The default builder adds an entry basic
+/// Builds an empty scf.for operation. The default builder adds an entry basic
 /// block which needs to be avoided here.
 static scf::ForOp buildEmptyForOp(Location loc, OpBuilder &builder, Value lb,
                                   Value ub, Value step) {
@@ -50,6 +61,15 @@
   return cast<scf::ForOp>(builder.createOperation(state));
 }
 
+/// Builds an empty scf.if operation without the then and else blocks.
+static scf::IfOp buildEmptyIfOp(Location loc, OpBuilder &builder, Value cond) {
+  OperationState state(loc, scf::IfOp::getOperationName());
+  state.addOperands(cond);
+  state.addRegion();
+  state.addRegion();
+  return cast<scf::IfOp>(builder.createOperation(state));
+}
+
 namespace {
 struct LoopBounds {
   Value lb;
@@ -58,10 +78,10 @@
 };
 }  // namespace
 
-/// Replaces a loop.parallelOp with an optional loop.parallel op and nested
-/// loop.for operations. To create the loop.parallel op as the outermost loop,
+/// Replaces a scf.parallelOp with an optional scf.parallel op and nested
+/// scf.for operations. To create the scf.parallel op as the outermost loop,
 /// pass the lower bound, upper bound and steps in `newPLoopLbs`, `newPLoopUbs`,
-/// and `newPLoopStep` respectively. The bounds of the inner loop.for operations
+/// and `newPLoopStep` respectively. The bounds of the inner scf.for operations
 /// to be created are passed in `forLbs`, `forUbs`, and `forStep`. The
 /// `permutation` vector contains a mapping from the original loop order, to the
 /// loop order to be generated.
@@ -70,21 +90,21 @@
                                  ArrayRef<LoopBounds> newPLoopBounds,
                                  ArrayRef<LoopBounds> forBounds,
                                  ArrayRef<unsigned> permutation) {
-  assert(!forBounds.empty() && "unhandled case of no loop.for created");
+  assert(!forBounds.empty() && "unhandled case of no scf.for created");
   unsigned numLoops = pLoopOp.getNumLoops();
   Location loc = pLoopOp.getLoc();
   assert(forBounds.size() + newPLoopBounds.size() == numLoops &&
-         "cannot drop loops when splitting loop.parallel operation");
+         "cannot drop loops when splitting scf.parallel operation");
   assert(permutation.size() == numLoops);
   OpBuilder::InsertionGuard guard(rewriter);
 
-  // Need a signature conversion for the body of the loop.parallel operation,
+  // Need a signature conversion for the body of the scf.parallel operation,
   // before can it can be used as the body of the innermost loop created here.
   TypeConverter::SignatureConversion signatureConverter(numLoops);
   Operation *outermostLoop = nullptr;
   auto permuteIt = permutation.begin();
 
-  // Create the loop.parallel operation as the outermost loop, if specified.
+  // Create the scf.parallel operation as the outermost loop, if specified.
   if (!newPLoopBounds.empty()) {
     auto lbs = llvm::to_vector<2>(llvm::map_range(
         newPLoopBounds, [](LoopBounds bounds) -> Value { return bounds.lb; }));
@@ -101,7 +121,7 @@
     outermostLoop = newPLoop.getOperation();
   }
 
-  // Generate the nested loop.for operations with the bounds passed.
+  // Generate the nested scf.for operations with the bounds passed.
   for (auto it : enumerate(forBounds)) {
     Value lb = it.value().lb, ub = it.value().ub, step = it.value().step;
     if (it.index() != forBounds.size() - 1) {
@@ -110,7 +130,7 @@
       signatureConverter.remapInput(*permuteIt, forOp.getInductionVar());
       rewriter.setInsertionPointToStart(forOp.getBody());
     } else {
-      // For the last loop, move the body of the loop.parallel op as the body of
+      // For the last loop, move the body of the scf.parallel op as the body of
       // the loop after signature conversion.
       auto forOp = buildEmptyForOp(loc, rewriter, lb, ub, step);
       if (!outermostLoop) outermostLoop = forOp.getOperation();
@@ -127,8 +147,8 @@
   return outermostLoop;
 }
 
-/// Serializes the dimensions of the loop.parallel specified in
-/// `serializedDimensions`, by creating an nested loop.for operation for each
+/// Serializes the dimensions of the scf.parallel specified in
+/// `serializedDimensions`, by creating an nested scf.for operation for each
 /// dimension.
 // TODO(ravishankarm): Move this into LoopUtils.h in MLIR.
 static Operation *serializeDimensions(ConversionPatternRewriter &rewriter,
@@ -141,7 +161,7 @@
   serializedDimSet.insert(serializedDimensions.begin(),
                           serializedDimensions.end());
   assert(serializedDimSet.size() == serializedDimensions.size() &&
-         "cannot repeat dimensions during serialization of loop.parallel");
+         "cannot repeat dimensions during serialization of scf.parallel");
   SmallVector<LoopBounds, 2> newPLoopBounds, forBounds;
   SmallVector<unsigned, 2> permutation;
   auto lbs = pLoopOp.lowerBound();
@@ -174,16 +194,85 @@
   return serializeDimensions(rewriter, pLoopOp, serializedDimensions);
 }
 
+/// Collapses all loops in a scf.parallel into one scf.parallel operation. This
+/// is done by
+/// 1) Normalize the loop bounds to be [0, (ub - lb) / step)
+/// 2) Compute the total number of iterations.
+/// 3) From the induction variable of the modified loop, compute the values of
+///    the original induction variables by de-linearization.
+scf::ParallelOp collapseParallelLoops(ConversionPatternRewriter &rewriter,
+                                      scf::ParallelOp pLoopOp) {
+  if (pLoopOp.getNumReductions()) return nullptr;
+
+  unsigned numLoops = pLoopOp.getNumLoops();
+  if (numLoops == 1) return pLoopOp;
+
+  // Compute the number of iterations of each loops starting from the innermost.
+  Location loc = pLoopOp.getLoc();
+  Value totalNumIterations = rewriter.create<ConstantIndexOp>(loc, 1);
+
+  // Track the "stride" of each loop, i.e. product of the total number of
+  // iterations of the inner loops.
+  SmallVector<Value, 2> iterationStride;
+  iterationStride.resize(pLoopOp.getNumLoops());
+  auto lbs = pLoopOp.lowerBound();
+  auto ubs = pLoopOp.upperBound();
+  auto steps = pLoopOp.step();
+  for (int i = numLoops - 1; i >= 0; --i) {
+    Value lb = lbs[i], ub = ubs[i], step = steps[i];
+    Value iterCount = rewriter.create<SignedDivIOp>(
+        loc, rewriter.create<SubIOp>(loc, ub, lb), step);
+    iterationStride[i] = totalNumIterations;
+    totalNumIterations =
+        rewriter.create<MulIOp>(loc, totalNumIterations, iterCount);
+  }
+
+  // Create the collapsed parallel loop op with lowerbound 0, step 1 and upper
+  // bound being the totalNumIterations.
+  Value newLb = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value newStep = rewriter.create<ConstantIndexOp>(loc, 1);
+  scf::ParallelOp newPLoopOp =
+      rewriter.create<scf::ParallelOp>(loc, newLb, totalNumIterations, newStep);
+
+  // Build the body of the collapsed loop by cloning the original loop body. The
+  // replacement value of the induction variables of the original loop body,
+  // from the induction variable of the new loop, using
+  //   origLoopIv[i] = loopIv / iterationStride[i]
+  //   loopIv = loopIv % iterationStride[i]
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block &pLoopBody = pLoopOp.getLoopBody().front();
+  rewriter.setInsertionPointToStart(&newPLoopOp.getLoopBody().front());
+  Value loopIv = *newPLoopOp.getInductionVars().begin();
+  BlockAndValueMapping map;
+  for (int i : llvm::seq<int>(0, numLoops)) {
+    Value iterNum =
+        rewriter.create<SignedDivIOp>(loc, loopIv, iterationStride[i]);
+    Value newIv = rewriter.create<AddIOp>(
+        loc, lbs[i], rewriter.create<MulIOp>(loc, iterNum, steps[i]));
+    map.map(pLoopBody.getArgument(i), newIv);
+    loopIv = rewriter.create<SignedRemIOp>(loc, loopIv, iterationStride[i]);
+  }
+  for (Operation &op : pLoopBody.without_terminator()) {
+    rewriter.clone(op, map);
+  }
+  rewriter.eraseOp(pLoopOp);
+  return newPLoopOp;
+}
+
 //===----------------------------------------------------------------------===//
 // GPU processor ID mapping utilities
 //===----------------------------------------------------------------------===//
 
-/// Distribute loop.parallel to processors with the processors logically
+/// Distributes scf.parallel to processors with the processors logically
 /// arranged with same dimensionality as the number of loops, i.e. a
-/// loop.parallel with 2 loops to a 2D grid of processors. `processorIDs` and
+/// scf.parallel with 2 loops to a 2D grid of processors. `processorIDs` and
 /// `numProcessors` must be of same size as the number of loops and are the
 /// values to use for process ID and number of processors along each dimension
 /// in the distributed code.
+/// This method accounts for the case where the number of processors is not
+/// enough to execute the entire iteration space with one iteration mapped to
+/// each processor. So implements a block-cyclic distribution with each block
+/// size being equal to the number of processors.
 static LogicalResult mapToProcessors(ConversionPatternRewriter &rewriter,
                                      scf::ParallelOp pLoopOp,
                                      ArrayRef<Value> processorIDs,
@@ -212,6 +301,43 @@
   return success();
 }
 
+/// Distributes scf.parallel to processors with the processors logically
+/// arranged with same dimensionality as the number of loops, i.e. a
+/// scf.parallel with 2 loops to a 2D grid of processors. `processorIDs` must be
+/// of same size as the number of loops and are the values to use for process ID
+/// and number of processors along each dimension in the distributed code.  This
+/// method assumes that the number of processors is greater than or equal to the
+/// number of iterations. So just generates an if statement to mask of
+/// processors with no work.
+static LogicalResult mapToProcessorsAndGuard(
+    ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp,
+    ArrayRef<Value> processorIDs) {
+  unsigned numLoops = pLoopOp.getNumLoops();
+  Location loc = pLoopOp.getLoc();
+  assert(numLoops == processorIDs.size() &&
+         "expected as many ids as number of loops");
+  Value cond = nullptr;
+  TypeConverter::SignatureConversion signatureConverter(numLoops);
+  auto lbs = pLoopOp.lowerBound();
+  auto step = pLoopOp.step();
+  auto ubs = pLoopOp.upperBound();
+  for (unsigned i : llvm::seq<unsigned>(0, numLoops)) {
+    Value iterValue = rewriter.create<AddIOp>(
+        loc, lbs[i], rewriter.create<MulIOp>(loc, processorIDs[i], step[i]));
+    Value cmp =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iterValue, ubs[i]);
+    cond = (cond ? rewriter.create<AndOp>(loc, cond, cmp) : cmp);
+    signatureConverter.remapInput(i, iterValue);
+  }
+  scf::IfOp ifOp = buildEmptyIfOp(loc, rewriter, cond);
+  Region &pLoopOpRegion = pLoopOp.getLoopBody();
+  rewriter.applySignatureConversion(&pLoopOpRegion, signatureConverter);
+  Region &ifOpRegion = ifOp.getRegion(0);
+  rewriter.inlineRegionBefore(pLoopOpRegion, ifOpRegion, ifOpRegion.begin());
+  rewriter.eraseOp(pLoopOp);
+  return success();
+}
+
 namespace {
 struct ProcessorIdAndCount {
   Value id;
@@ -251,7 +377,24 @@
           rewriter.create<MulIOp>(loc, blockDim, gridDim)};
 }
 
-/// Distribute loop.parallel to processors where `IdOp` is used to get the
+template <typename GPUIdOp, typename GPUCountOp>
+static void getGPUProcessorIdsAndCounts(Location loc,
+                                        ConversionPatternRewriter &rewriter,
+                                        unsigned numDims,
+                                        MutableArrayRef<Value> id,
+                                        MutableArrayRef<Value> count) {
+  ArrayRef<StringRef> dims = {"x", "y", "z"};
+  assert(id.size() == numDims);
+  assert(count.size() == numDims);
+  for (unsigned i = 0; i < numDims; ++i) {
+    ProcessorIdAndCount idAndCount =
+        getGPUProcessorIdAndCount<GPUIdOp, GPUCountOp>(loc, dims[i], rewriter);
+    id[numDims - 1 - i] = idAndCount.id;
+    count[numDims - 1 - i] = idAndCount.count;
+  }
+}
+
+/// Distributes scf.parallel to processors where `IdOp` is used to get the
 /// processor ID and `DimOp` is used to get the number of processors along a
 /// dimension.
 template <typename GPUIdOp, typename GPUCountOp>
@@ -263,38 +406,51 @@
         cast<scf::ParallelOp>(serializeDimensionsFrom(rewriter, pLoopOp, 3));
     numLoops = 3;
   }
-  SmallVector<Value, 2> id, count;
-  id.reserve(numLoops);
-  count.reserve(numLoops);
-  ArrayRef<StringRef> dims = {"x", "y", "z"};
-  Location loc = pLoopOp.getLoc();
-  for (unsigned i = 0; i < numLoops; ++i) {
-    ProcessorIdAndCount idAndCount =
-        getGPUProcessorIdAndCount<GPUIdOp, GPUCountOp>(loc, dims[i], rewriter);
-    id.insert(id.begin(), idAndCount.id);
-    count.insert(count.begin(), idAndCount.count);
-  }
+  SmallVector<Value, 2> id(numLoops), count(numLoops);
+  getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(pLoopOp.getLoc(), rewriter,
+                                                   numLoops, id, count);
   return mapToProcessors(rewriter, pLoopOp, id, count);
 }
 
-/// Distribute the loop.parallel to workgroups.
+/// Distributes scf.parallel to processors where `IdOp` is used to get the
+/// processor ID and `DimOp` is used to get the number of processors along a
+/// dimension. Assumes that the number of processors will be less than equal to
+/// the number of iterations of the pLoopOp along all dimensions.
+template <typename GPUIdOp, typename GPUCountOp>
+static LogicalResult mapToProcessorsAndGuard(
+    ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp) {
+  unsigned numLoops = pLoopOp.getNumLoops();
+  if (numLoops > 3) {
+    pLoopOp =
+        cast<scf::ParallelOp>(serializeDimensionsFrom(rewriter, pLoopOp, 3));
+    numLoops = 3;
+  }
+  SmallVector<Value, 2> id(numLoops), count(numLoops);
+  getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(pLoopOp.getLoc(), rewriter,
+                                                   numLoops, id, count);
+  return mapToProcessorsAndGuard(rewriter, pLoopOp, id);
+}
+
+/// Distribute the scf.parallel to workgroups.
 static LogicalResult mapToWorkgroups(ConversionPatternRewriter &rewriter,
                                      scf::ParallelOp pLoopOp) {
   return mapToProcessor<gpu::BlockIdOp, gpu::GridDimOp>(rewriter, pLoopOp);
 }
 
-/// Distribute loop.parallel to workitems using local invocation ID.
+/// Distributes scf.parallel to workitems using local invocation ID.
 static LogicalResult mapToLocalInvocationId(ConversionPatternRewriter &rewriter,
                                             scf::ParallelOp pLoopOp) {
-  return mapToProcessor<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter, pLoopOp);
+  return mapToProcessorsAndGuard<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter,
+                                                                   pLoopOp);
 }
 
-/// Distribute loop.parallel to workitems using global invocation ID. The GPU
+/// Distributes scf.parallel to workitems using global invocation ID. The GPU
 /// dialect doesn't have a direct operation to do this. This could be done using
 /// id = blockIdx * blockDim + gridIdx. count = blockDim * gridDim.
 static LogicalResult mapToGlobalInvocationId(
     ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp) {
-  return mapToProcessor<GPUGlobalId, GPUGlobalCount>(rewriter, pLoopOp);
+  return mapToProcessorsAndGuard<GPUGlobalId, GPUGlobalCount>(rewriter,
+                                                              pLoopOp);
 }
 
 //===----------------------------------------------------------------------===//
@@ -304,10 +460,13 @@
 namespace {
 /// Pass to convert from tiled and fused linalg ops into gpu.func.
 struct ConvertToGPUPass : public PassWrapper<ConvertToGPUPass, FunctionPass> {
+  ConvertToGPUPass() = default;
+  ConvertToGPUPass(const ConvertToGPUPass &pass) {}
+
   void runOnFunction() override;
 };
 
-/// Pattern to map loop.parallel to workgroups.
+/// Pattern to map scf.parallel to workgroups.
 struct PartitionPLoopToWorkgroups
     : public OpConversionPattern<scf::ParallelOp> {
   using OpConversionPattern<scf::ParallelOp>::OpConversionPattern;
@@ -318,7 +477,7 @@
   }
 };
 
-/// Map tiled linalg op to workitems by lowering it to loop.parallel and
+/// Map tiled linalg op to workitems by lowering it to scf.parallel and
 /// partitioning it to workitems.
 template <typename LinalgOpTy>
 struct MapLinalgOpToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
@@ -342,6 +501,29 @@
   }
 };
 
+/// Legacy path for lowering tiled conv/pooling op to loops.
+// TODO(#2134): Remove this pattern. The default path of using
+// `MapLinalgOpToLocalInvocationId` seems to have a bug. It only shows up
+// currently on Resnet50. Remove this pattern after the bug is triaged/fixed.
+template <typename LinalgOpTy>
+struct MapConvPoolToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
+  using OpConversionPattern<LinalgOpTy>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      LinalgOpTy linalgOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!hasWorkItemMarker(linalgOp)) return failure();
+    Optional<linalg::LinalgLoops> loops =
+        linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
+    if (!loops) return failure();
+    scf::ParallelOp pLoopOp = cast<scf::ParallelOp>(loops.getValue()[0]);
+    if (failed(mapToProcessor<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter,
+                                                                pLoopOp)))
+      return failure();
+    rewriter.eraseOp(linalgOp);
+    return success();
+  }
+};
+
 /// Map linalg operation to execute on GPU in parallel by mapping the parallel
 /// loops to "GlobalInvocationId".
 template <typename LinalgOpTy>
@@ -351,19 +533,29 @@
   LogicalResult matchAndRewrite(
       LinalgOpTy linalgOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    // If marker exists and its not no-tile, do nothing.
-    if (hasMarker(linalgOp) && !hasNoTileMarker(linalgOp)) return failure();
+    // If marker exists do nothing.
+    if (hasMarker(linalgOp)) return failure();
     Optional<linalg::LinalgLoops> loops =
         linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
     if (!loops) return failure();
+
+    SmallVector<int64_t, 3> workgroupSize(3, 1);
     if (!loops.getValue().empty()) {
       scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
       // If there are parallel loops partition them to threads using global
       // invocation ID.
-      if (pLoopOp && failed(mapToGlobalInvocationId(rewriter, pLoopOp)))
-        return failure();
+      if (pLoopOp) {
+        pLoopOp = collapseParallelLoops(rewriter, pLoopOp);
+        if (!pLoopOp) return failure();
+        if (failed(mapToGlobalInvocationId(rewriter, pLoopOp)))
+          return rewriter.notifyMatchFailure(
+              linalgOp, "mapping to GlobalInvocationID failed");
+        workgroupSize = {32, 1, 1};
+      }
     }
     rewriter.eraseOp(linalgOp);
+    FuncOp funcOp = linalgOp.template getParentOfType<FuncOp>();
+    if (funcOp) updateWorkGroupSize(funcOp, workgroupSize);
     return success();
   }
 };
@@ -392,7 +584,7 @@
 
   MLIRContext *context = &getContext();
   ConversionTarget target(*context);
-  // After this pass Linalg and loop.parallel ops should be gone.
+  // After this pass Linalg and scf.parallel ops should be gone.
   target.addIllegalOp<scf::ParallelOp>();
   target.addIllegalDialect<linalg::LinalgDialect>();
   // Reshape ops are treated legal since they just change the way the underlying
@@ -404,27 +596,40 @@
 
   OwningRewritePatternList patterns;
 
-  patterns.insert<
   // clang-format off
-#define ADD_ALL_LINALG_PATTERNS(OP_NAME)                                       \
-  MapLinalgOpToGlobalInvocationId<OP_NAME>,                                    \
-  MapLinalgOpToLocalInvocationId<OP_NAME>
+  patterns.insert<
 
-    ADD_ALL_LINALG_PATTERNS(linalg::ConvOp),
+#define ADD_ALL_LINALG_PATTERNS(OP_NAME)                                       \
+    MapLinalgOpToGlobalInvocationId<OP_NAME>,                                  \
+    MapLinalgOpToLocalInvocationId<OP_NAME>
+
     ADD_ALL_LINALG_PATTERNS(linalg::CopyOp),
     ADD_ALL_LINALG_PATTERNS(linalg::FillOp),
     ADD_ALL_LINALG_PATTERNS(linalg::GenericOp),
     ADD_ALL_LINALG_PATTERNS(linalg::IndexedGenericOp),
-    ADD_ALL_LINALG_PATTERNS(linalg::MatmulOp),
-    ADD_ALL_LINALG_PATTERNS(linalg::PoolingMaxOp),
-    ADD_ALL_LINALG_PATTERNS(linalg::PoolingMinOp),
-    ADD_ALL_LINALG_PATTERNS(linalg::PoolingSumOp),
 
 #undef ADD_ALL_LINALG_PATTERNS
 
+#define ADD_ALL_CONV_POOL_PATTERNS(OP_NAME)                                    \
+    MapConvPoolToLocalInvocationId<OP_NAME>,                                   \
+    MapLinalgOpToGlobalInvocationId<OP_NAME>
+
+    ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingMaxOp),
+    ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingMinOp),
+    ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingSumOp),
+
+#undef ADD_ALL_CONV_POOL_PATTERNS
+
+    MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
     PartitionPLoopToWorkgroups, RemoveLinalgRange>(context);
   // clang-format on
 
+  patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::ConvOp>>(context);
+  if (useLegacyConvLowering)
+    patterns.insert<MapConvPoolToLocalInvocationId<linalg::ConvOp>>(context);
+  else
+    patterns.insert<MapLinalgOpToLocalInvocationId<linalg::ConvOp>>(context);
+
   if (failed(applyFullConversion(funcOp, target, patterns)))
     return signalPassFailure();
 }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index aeb0996..399b3f5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -21,7 +21,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
 #include "llvm/ADT/STLExtras.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 48e1d57..d50cb40 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -17,9 +17,10 @@
 // Implements a pass to tile and fuse linalg operations on buffers.
 //
 //===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -42,8 +43,6 @@
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-static constexpr unsigned kMaxWorkgroupRank = 3;
-
 static ArrayRef<int64_t> dropTrailingOnes(ArrayRef<int64_t> vector) {
   if (vector.empty()) return vector;
   auto numTrailingOnes = 0;
@@ -56,60 +55,14 @@
   return vector.drop_back(numTrailingOnes);
 }
 
-/// Returns the number of "outer" parallel loops specified in the `linalgOp`.
-static unsigned getNumOuterParallelLoops(linalg::LinalgOp linalgOp) {
-  if (auto convOp = dyn_cast<linalg::ConvOp>(linalgOp.getOperation())) {
-    Optional<DenseIntElementsAttr> padding = convOp.padding();
-    if (padding) return convOp.getNumBatchDimensions();
-  }
-  return linalgOp.iterator_types()
-      .getValue()
-      .take_while([](Attribute attr) {
-        return attr.cast<StringAttr>().getValue() ==
-               getParallelIteratorTypeName();
-      })
-      .size();
-}
-
-/// Updates the workgroup size used for the dispatch region.
-static LogicalResult updateWorkGroupSize(FuncOp funcOp,
-                                         ArrayRef<int64_t> workGroupSize) {
-  // Need to update both the surrounding FuncOp that has the spv.entry_point_abi
-  // attribute, and the hal.executable.
-  Region &body = funcOp.getBody();
-  if (!llvm::hasSingleElement(body))
-    return funcOp.emitError("unhandled dispatch function with multiple blocks");
-
-  SmallVector<int32_t, 3> workGroupSizeVec = llvm::to_vector<3>(llvm::map_range(
-      workGroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
-
-  // TODO(ravishankarm, antiagainst): We should have at most one scf.parallel
-  // op, but that is not the case till the splitting of kernels lands.
-  unsigned numParallelLoops = 0;
-  auto updateNumParallelLoops = [&numParallelLoops](unsigned nPar) {
-    numParallelLoops =
-        (!numParallelLoops ? nPar : std::min(numParallelLoops, nPar));
-  };
-  for (auto parallelLoop : body.front().getOps<scf::ParallelOp>()) {
-    updateNumParallelLoops(parallelLoop.getNumLoops());
-  }
-  // If there are no parallel loops, there might be linalg ops that arent
-  // tiled. Use that to get the number of parallel loops.
-  for (auto linalgOp : body.front().getOps<linalg::LinalgOp>()) {
-    updateNumParallelLoops(getNumOuterParallelLoops(linalgOp));
-  }
-  workGroupSizeVec.resize(numParallelLoops);
-  LLVM_DEBUG({
-    llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";
-    llvm::dbgs() << "# workgroup sizes at end: [";
-    interleaveComma(workGroupSizeVec, llvm::dbgs());
-    llvm::dbgs() << "]\n";
-  });
-  MLIRContext *context = funcOp.getContext();
-  workGroupSizeVec.resize(3, 1);
-  funcOp.setAttr(spirv::getEntryPointABIAttrName(),
-                 spirv::getEntryPointABIAttr(workGroupSizeVec, context));
-  return success();
+/// Returns true if the linalg op has padding attribute, and that it has
+/// non-zero entries.
+template <typename OpTy>
+static bool hasPadding(OpTy op) {
+  Optional<DenseIntElementsAttr> padding = op.padding();
+  if (!padding) return false;
+  return llvm::any_of(padding.getValue(),
+                      [](APInt v) -> bool { return !v.isNullValue(); });
 }
 
 namespace {
@@ -119,7 +72,13 @@
 class TileSizeCalculator {
  public:
   TileSizeCalculator(FuncOp funcOp)
-      : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {}
+      : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {
+    if (DenseIntElementsAttr attr = spirv::lookupLocalWorkGroupSize(funcOp)) {
+      for (auto val : attr.getValues<APInt>())
+        workgroupSize.push_back(val.getSExtValue());
+    }
+    workgroupSize.resize(3, 1);
+  }
 
   /// Compute the tile sizes based on workgroup size specified.
   LogicalResult setTileSizesBasedOnWorkgroupSize(
@@ -139,21 +98,10 @@
   /// Get the current tile size computed.
   ArrayRef<int64_t> getTileSizes() const { return tileSizes; }
 
-  /// Linalg convention is to use 0 for no tiling. If any of the tile dimensions
-  /// is set to 1 make it 0.
-  SmallVector<int64_t, 3> getTileSizesForLinalg() const {
-    return llvm::to_vector<3>(llvm::map_range(
-        tileSizes, [](int64_t v) -> int64_t { return v == 1 ? 0 : v; }));
-  }
-
   /// Returns the workgroup size to use based on the tile sizes.
   ArrayRef<int64_t> getWorkGroupSize() const { return workgroupSize; }
 
  private:
-  /// Get the default tile sizes based on just number of dimensions, i.e., "x",
-  /// "y", and "z".
-  void setTileSizesBasedOnDimensions(unsigned numDims);
-
   /// Current tile size configuration.
   SmallVector<int64_t, 4> tileSizes;
 
@@ -165,67 +113,72 @@
 };
 }  // namespace
 
-void TileSizeCalculator::setTileSizesBasedOnDimensions(unsigned numDims) {
-  tileSizes.clear();
-  workgroupSize.clear();
-  tileSizes.reserve(3);
-  if (numDims == 0) {
-    // Scalar case.
-    workgroupSize = {1, 1, 1};
-    return;
-  }
-  unsigned maxWorkGroupSize =
-      resourceLimits.max_compute_workgroup_invocations().getInt();
-
-  // Make the tile size 32 along the x-dimension, and then split the remaining
-  // maxWorkGroupSize threads amongst the y-dimension or z-dimension.
-  unsigned tileSizeX = llvm::PowerOf2Floor(std::min(maxWorkGroupSize, 32u));
-  maxWorkGroupSize /= tileSizeX;
-  if (numDims == 1) {
-    tileSizes = {tileSizeX};
-    workgroupSize = {tileSizeX, 1, 1};
-    return;
-  }
-  if (numDims == 2) {
-    unsigned tileSizeY = llvm::PowerOf2Floor(maxWorkGroupSize);
-    tileSizes = {tileSizeY, tileSizeX};
-    workgroupSize = {tileSizeX, tileSizeY, 1};
-    return;
-  }
-  unsigned tileSizeYZ =
-      llvm::PowerOf2Floor(static_cast<unsigned>(std::sqrt(maxWorkGroupSize)));
-  tileSizes = {tileSizeYZ, tileSizeYZ, tileSizeX};
-  workgroupSize = {tileSizeX, tileSizeYZ, tileSizeYZ};
-}
-
 LogicalResult TileSizeCalculator::setTileSizesBasedOnOps(
     ArrayRef<linalg::LinalgOp> linalgOps) {
   tileSizes.clear();
+  if (linalgOps.empty()) {
+    tileSizes = {1, 1, 1};
+    workgroupSize = {1, 1, 1};
+    return success();
+  }
   // The tile size will be driven by operations like matmul, conv, etc. within
   // the list. So see what operation exists in the list to decide the tile size.
   // If there are two such operations in the list, return error.
-  bool hasMatmul = false;
-  unsigned numParallelLoops = kMaxWorkgroupRank;
-  for (linalg::LinalgOp op : linalgOps) {
-    // If there is no marker on this op (i.e. a marker to prevent tile), add an
-    // explicit marker to indicate that the op is to be tiled. Makes subsequent
-    // lowering simpler.
-    if (isa<linalg::MatmulOp>(op.getOperation())) {
-      if (hasMatmul)
-        return op.emitError(
-            "unhandled multiple matmuls within dispatch region");
-      hasMatmul = true;
-    }
-    numParallelLoops = std::min(numParallelLoops, getNumOuterParallelLoops(op));
+  enum OpInfo : uint32_t {
+    None = 0x0,
+    Convolution = 0x1,
+    Matmul = 0x2,
+    Pooling = 0x4,
+  };
+  uint32_t opInfo = OpInfo::None;
+  for (linalg::LinalgOp linalgOp : linalgOps) {
+    Operation *op = linalgOp.getOperation();
+    if (isa<linalg::ConvOp>(op)) opInfo |= OpInfo::Convolution;
+    if (isa<linalg::MatmulOp>(op)) opInfo |= OpInfo::Matmul;
+    if (isa<linalg::PoolingMaxOp>(op)) opInfo |= OpInfo::Pooling;
+    if (isa<linalg::PoolingMinOp>(op)) opInfo |= OpInfo::Pooling;
+    if (isa<linalg::PoolingSumOp>(op)) opInfo |= OpInfo::Pooling;
   }
-  if (hasMatmul) {
+  // If there are no tilable ops, there is nothing to do here.
+  if (!opInfo) return success();
+
+  Operation *linalgOp = *(linalgOps.begin());
+  if (llvm::countPopulation(opInfo) != 1)
+    return linalgOp->getParentOfType<FuncOp>().emitError(
+        "unhandled fusion of ops in dispatch function");
+
+  // TODO(ravishanarm, antiagainst): Only the maximum workgroup size is used
+  // here for computing tile sizes. In reality we also need the maximum
+  // workgroup memory size available (per workgroup) to compute the tile sizes
+  // effectively.
+  unsigned maxWorkgroupSize =
+      resourceLimits.max_compute_workgroup_invocations().getInt();
+  if (opInfo & OpInfo::Convolution) {
+    // TODO(ravishankarm): This tiling is meant to enable promotion to workgroup
+    // memory, but doesnt actually get us to a state where we can do this. The
+    // promotion is possible only when the subviews created are constant
+    // size. For now this doesnt really matter. Revisit this later.
+    int64_t tileSizeX = 32;
+    int64_t tileSizeY = maxWorkgroupSize / 32;
+    tileSizes = {1, tileSizeY, tileSizeX};
+    workgroupSize = {tileSizeX, tileSizeY, 1};
+    return success();
+  }
+  if (opInfo & OpInfo::Matmul) {
     // TODO: For now just hard wire this, but we can do better.
     tileSizes = {8, 8, 4};
     workgroupSize = {8, 8, 1};
     return success();
   }
-  setTileSizesBasedOnDimensions(numParallelLoops);
-  return success();
+  if (opInfo & OpInfo::Pooling) {
+    int64_t tileSizeX = 32;
+    int64_t tileSizeY = maxWorkgroupSize / 32;
+    tileSizes = {tileSizeY, tileSizeX};
+    workgroupSize = {tileSizeX, tileSizeY, 1};
+    return success();
+  }
+  return linalgOp->getParentOfType<FuncOp>().emitError(
+      "unable to find tile size for ops in this dispatch function");
 }
 
 //===----------------------------------------------------------------------===//
@@ -294,22 +247,41 @@
   SmallVector<int64_t, 3> workGroupSize;
 };
 
-/// Pattern to tile linalg operations if they have the workgroup marker.
-template <typename LinalgOp>
-struct TileLinalgOpPattern : public linalg::LinalgTilingPattern<LinalgOp> {
-  using linalg::LinalgTilingPattern<LinalgOp>::LinalgTilingPattern;
+/// Pattern for tiling operations. Updates the workgroup size in the surrounding
+/// function operation if tiling succeeds.
+template <typename OpTy>
+struct TilingPattern : public linalg::LinalgTilingPattern<OpTy> {
+  using Base = linalg::LinalgTilingPattern<OpTy>;
+  TilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+                ArrayRef<int64_t> workgroupSize,
+                linalg::LinalgMarker marker = linalg::LinalgMarker(),
+                PatternBenefit benefit = 1)
+      : Base(context, options, marker, benefit),
+        workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
+
+  virtual LogicalResult matchAndRewrite(Operation *op,
+                                        PatternRewriter &rewriter) const {
+    // Find the parent FuncOp before tiling. If tiling succeeds, the op will be
+    // erased.
+    FuncOp funcOp = op->getParentOfType<FuncOp>();
+    return failure(!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
+                   failed(updateWorkGroupSize(funcOp, workgroupSize)));
+  }
+
+  SmallVector<int64_t, 3> workgroupSize;
+};
+
+/// Pattern for tiling convolution and pooling operations. Currently is just a
+/// way to not tile when the operation has padding.
+template <typename OpTy>
+struct TileConvPoolPattern : public TilingPattern<OpTy> {
+  using Base = TilingPattern<OpTy>;
+  using Base::TilingPattern;
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!hasWorkGroupMarker(op)) return failure();
-    if (succeeded(linalg::LinalgTilingPattern<LinalgOp>::matchAndRewrite(
-            op, rewriter)))
-      return success();
-    // Update the marker to map to global invocation ID.
-    rewriter.startRootUpdate(op);
-    setNoTileMarker(op);
-    rewriter.finalizeRootUpdate(op);
-    return success();
+    if (hasPadding(cast<OpTy>(op))) return failure();
+    return Base::matchAndRewrite(op, rewriter);
   }
 };
 
@@ -348,14 +320,7 @@
   auto linalgOps = block.getOps<linalg::LinalgOp>();
   if (linalgOps.empty()) return;
 
-  // Go through all the Linalg ops and set the marker to trigger tiling./
-  // TODO(ravishankarm): Move this to HLOToLinalgOnBuffers so that it is added
-  // on op-creation.
-  for (auto op : linalgOps)
-    if (!hasMarker(op)) setWorkGroupMarker(op);
-
   TileSizeCalculator tileSizeCalculator(funcOp);
-
   if (workGroupSize.empty()) {
     // Get the tile sizes to use for the lowering.
     SmallVector<int64_t, 3> tileSizes;
@@ -376,20 +341,17 @@
   });
 
   OwningRewritePatternList tilingPatterns;
-  tilingPatterns.insert<TileLinalgOpPattern<linalg::ConvOp>,
-                        TileLinalgOpPattern<linalg::CopyOp>,
-                        TileLinalgOpPattern<linalg::FillOp>,
-                        TileLinalgOpPattern<linalg::GenericOp>,
-                        TileLinalgOpPattern<linalg::IndexedGenericOp>,
-                        TileLinalgOpPattern<linalg::MatmulOp>,
-                        TileLinalgOpPattern<linalg::PoolingMaxOp>,
-                        TileLinalgOpPattern<linalg::PoolingMinOp>,
-                        TileLinalgOpPattern<linalg::PoolingSumOp>>(
+  tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
+                        TilingPattern<linalg::MatmulOp>,
+                        TileConvPoolPattern<linalg::PoolingMaxOp>,
+                        TileConvPoolPattern<linalg::PoolingMinOp>,
+                        TileConvPoolPattern<linalg::PoolingSumOp>>(
       context,
       linalg::LinalgTilingOptions()
-          .setTileSizes(tileSizeCalculator.getTileSizesForLinalg())
+          .setTileSizes(tileSizeCalculator.getTileSizes())
           .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
-      linalg::LinalgMarker(Identifier::get(getWorkGroupMarker(), context),
+      tileSizeCalculator.getWorkGroupSize(),
+      linalg::LinalgMarker(ArrayRef<Identifier>(),
                            Identifier::get(getWorkItemMarker(), context)));
   applyPatternsAndFoldGreedily(getOperation(), tilingPatterns);
 
@@ -423,15 +385,6 @@
       insertBarrierAfter(builder, linalgOp.getLoc(), linalgOp);
     }
   });
-
-  // Update the workgroup size to be consistent with the tile sizes used. Note
-  // the tile sizes are ordered from outer most to inner most loops. The
-  // heuristic is to map the inner loops to x, the next outer (if it exists) to
-  // y, and the next outer (if it exists) to z. So tile sizes are reversed to
-  // get the workgroup size.
-  if (failed(
-          updateWorkGroupSize(funcOp, tileSizeCalculator.getWorkGroupSize())))
-    return signalPassFailure();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
similarity index 87%
rename from iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
rename to iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
index cf641e1..a285236 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
 
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/IR/Attributes.h"
@@ -31,11 +31,9 @@
 static bool checkMarkerValue(Operation *op, StringRef marker = "") {
   StringAttr attr = op->getAttrOfType<StringAttr>(
       linalg::LinalgTransforms::kLinalgTransformMarker);
-  return attr && (marker == "" || attr.getValue() == marker);
+  return attr && (marker.empty() || attr.getValue() == marker);
 }
 
-StringRef getNoTileMarker() { return "no-tile"; }
-
 StringRef getWorkGroupMarker() { return "workgroup"; }
 
 StringRef getWorkItemMarker() { return "workitem"; }
@@ -46,10 +44,6 @@
   return checkMarkerValue(op, marker);
 }
 
-bool hasNoTileMarker(Operation *op) {
-  return checkMarkerValue(op, getNoTileMarker());
-}
-
 bool hasWorkGroupMarker(Operation *op) {
   return checkMarkerValue(op, getWorkGroupMarker());
 }
@@ -69,8 +63,6 @@
               StringAttr::get(marker, op->getContext()));
 }
 
-void setNoTileMarker(Operation *op) { setMarker(op, getNoTileMarker()); }
-
 void setCooperativeMatrixMarker(Operation *op) {
   op->setAttr(VectorTransforms::kVectorTransformMarker,
               StringAttr::get(getCooperativeMatrixMarker(), op->getContext()));
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
similarity index 70%
rename from iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
rename to iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
index fa14263..633bca0 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
@@ -19,8 +19,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
-#define IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
 
 #include "llvm/ADT/StringRef.h"
 #include "mlir/Support/LLVM.h"
@@ -30,55 +30,34 @@
 class Operation;
 namespace iree_compiler {
 
-/// Marker to denote that do not tile the linalg operation.
-StringRef getNoTileMarker();
-
-/// Marker to denote that a linalg operation is to be partitioned to workgroups.
-StringRef getWorkGroupMarker();
-
 /// Marker to denote that a linalg operation is to be partitioned to workitems.
 StringRef getWorkItemMarker();
 
-/// Returns true if an operation has the specified `marker`. When `marker` is
-/// empty, returns true if the operation has any marker.
-bool hasMarker(Operation *, StringRef marker = "");
-
-/// Returns true if an operation has marker to denote that it is not to be
-/// tiled.
-bool hasNoTileMarker(Operation *);
-
-/// Returns true if an operation has marker to denote that it is to be
-/// partitioned to workgroups.
-bool hasWorkGroupMarker(Operation *);
-
-/// Returns true if an operation has marker to denote that it is to be
-/// partitioned to workitems.
-bool hasWorkItemMarker(Operation *);
-
 /// Returns true if an operation has a marker to denote that it will be mapped
 /// to cooperative matrix operations. Markers need to be consistent as
 /// cooperative matrices have their own type and load/store operations.
 bool hasCooperativeMatrixMarker(Operation *);
 
-/// Sets a given marker on an operation.
-void setMarker(Operation *, StringRef);
+/// Returns true if an operation has the specified `marker`. When `marker` is
+/// empty, returns true if the operation has any marker.
+bool hasMarker(Operation *, StringRef marker = "");
 
-/// Sets marker to prevent tiling of a linalg operation.
-void setNoTileMarker(Operation *);
-
-/// Sets marker to denote that a linalg operation is to be partitioned to
-/// workgroups.
-void setWorkGroupMarker(Operation *);
-
-/// Sets marker to denote that a linalg operation is to be partitioned to
-/// workitems.
-void setWorkItemMarker(Operation *);
+/// Returns true if an operation has marker to denote that it is to be
+/// partitioned to workitems.
+bool hasWorkItemMarker(Operation *);
 
 /// Sets marker to denote that a vector operation is to be execute on a
 /// cooperative matrix.
 void setCooperativeMatrixMarker(Operation *);
 
+/// Sets a given marker on an operation.
+void setMarker(Operation *, StringRef);
+
+/// Sets marker to denote that a linalg operation is to be partitioned to
+/// workitems.
+void setWorkItemMarker(Operation *);
+
 }  // namespace iree_compiler
 }  // namespace mlir
 
-#endif  // IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
+#endif  // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 4ca5f2f..b1bba5e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -15,11 +15,11 @@
 //===- SplitDispathFunctionPass.cpp ---------------------------------------===//
 //
 // This file implements a pass to split computation workload to multiple
-// sequential dispatch functions. This pass operates on Linalg ops and prepares
-// for lowering to GPU, where we need to tile the workload to workgroups and
-// workitems. If the workload involves computation A and B, where B is
-// dependent on A and A needs all workgroups to complete, then we need
-// to split A and B into different kernels because there is no mechanism
+// sequential dispatch functions. This pass operates on Linalg ops and
+// scf.parallel op and prepares for lowering to GPU, where we need to tile the
+// workload to workgroups and workitems. If the workload involves computation A
+// and B, where B is dependent on A and A needs all workgroups to complete, then
+// we need to split A and B into different kernels because there is no mechanism
 // to perform cross-workgroup synchronization within a single kernel.
 //
 //===----------------------------------------------------------------------===//
@@ -35,6 +35,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -51,24 +52,20 @@
 
 namespace {
 
-/// Returns true if the given `block` contains 0 or 1 Linalg structured ops.
-bool hasZeroOrOneLinalgOp(Block &block) {
-  auto ops = block.getOps<linalg::LinalgOp>();
-  return std::distance(ops.begin(), ops.end()) <= 1;
-}
-
 /// Returns true if the Linalg ops can be separated to multiple kernels.
-bool canSeparateLinalgOps(MutableArrayRef<linalg::LinalgOp> linalgOps) {
-  if (llvm::any_of(linalgOps, [](linalg::LinalgOp op) {
-        return !op.hasBufferSemantics();
+bool canSeparateOps(ArrayRef<Operation *> ops) {
+  if (llvm::any_of(ops, [](Operation *op) {
+        if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
+          return !linalgOp.hasBufferSemantics();
+        return false;
       }))
     return false;
 
   // Require no other ops interleave with Linalg structured ops for now. This is
   // the common case and it simplifies further analysis.
-  for (int i = 0, e = linalgOps.size() - 1; i < e; ++i) {
-    if (linalgOps[i].getOperation()->getNextNode() != linalgOps[i + 1])
-      return false;
+  for (auto currOp = ops.begin(), nextOp = std::next(ops.begin());
+       nextOp != ops.end(); ++currOp, ++nextOp) {
+    if ((*currOp)->getNextNode() != *nextOp) return false;
   }
 
   return true;
@@ -144,15 +141,20 @@
     return oldFn.emitError("expected only one block");
   }
 
-  // The dispatch function should have more than one Linalg structured ops.
-  // Otherwise there is nothing to do.
-  if (hasZeroOrOneLinalgOp(oldFn.getBlocks().front())) return success();
+  // The dispatch function should have more than one separable ops. Otherwise
+  // there is nothing to do.
+  Block &fnBody = oldFn.getBlocks().front();
 
-  // Collect all Linalg ops for distributing.
-  SmallVector<linalg::LinalgOp, 4> linalgOps =
-      llvm::to_vector<4>(oldFn.getBlocks().front().getOps<linalg::LinalgOp>());
-  if (!canSeparateLinalgOps(linalgOps)) {
-    return oldFn.emitError("cannot separate Linalg ops into multiple kernels");
+  // Collect all Linalg and scf.parallel ops for distributing.
+  SmallVector<Operation *, 4> separableOps;
+  for (Operation &op : fnBody)
+    if (isa<linalg::LinalgOp>(op) || isa<scf::ParallelOp>(op))
+      separableOps.push_back(&op);
+
+  if (separableOps.size() <= 1) return success();
+  if (!canSeparateOps(separableOps)) {
+    return oldFn.emitError(
+        "cannot separate Linalg/Parallel ops into multiple kernels");
   }
 
   ModuleOp moduleOp = cast<ModuleOp>(oldFn.getParentOp());
@@ -160,13 +162,13 @@
   Location loc = oldFn.getLoc();
 
   SmallVector<std::string, 4> splitKernels;
-  splitKernels.reserve(linalgOps.size());
+  splitKernels.reserve(separableOps.size());
   llvm::SmallPtrSet<Operation *, 16> closure;
 
-  for (const auto &linalgOp : llvm::enumerate(linalgOps)) {
-    // Create a new function for hosting this Linalg op.
-    splitKernels.emplace_back(
-        llvm::formatv("{0}_dispatch_{1}", oldFn.getName(), linalgOp.index()));
+  for (const auto &separableOp : llvm::enumerate(separableOps)) {
+    // Create a new function for hosting this op.
+    splitKernels.emplace_back(llvm::formatv("{0}_dispatch_{1}", oldFn.getName(),
+                                            separableOp.index()));
     StringRef newFnName = splitKernels.back();
     builder.setInsertionPointToStart(moduleOp.getBody());
     auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType(),
@@ -181,7 +183,7 @@
 
     // Collect the closure for the current Linalg op.
     closure.clear();
-    collectAllReferencedOps(linalgOp.value(), closure);
+    collectAllReferencedOps(separableOp.value(), closure);
 
     // Clone all ops in the closure to the new function.
     Block *newFnBlock = newFn.addEntryBlock();
@@ -190,14 +192,14 @@
     for (Operation &op : oldFnBlock) {
       if (closure.count(&op) == 0) continue;
       builder.insert(op.clone(remapper));
-      if (&op == linalgOp.value()) break;
+      if (&op == separableOp.value()) break;
     }
     builder.insert(oldFnBlock.getTerminator()->clone(remapper));
   }
 
   // Add the entry point schedule to the module op.
   SmallVector<Attribute, 4> entryPoints;
-  entryPoints.reserve(linalgOps.size());
+  entryPoints.reserve(separableOps.size());
   for (const std::string &kernel : splitKernels) {
     entryPoints.emplace_back(builder.getStringAttr(kernel));
   }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
new file mode 100644
index 0000000..6fed42e
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -0,0 +1,51 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Utils.cpp - Utility functions used in Linalg to SPIR-V lowering ----===//
+//
+// Implementaiton of utility functions used while lowering from Linalg to SPIRV.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
+
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+LogicalResult updateWorkGroupSize(FuncOp funcOp,
+                                  ArrayRef<int64_t> workGroupSize) {
+  // Need to update both the surrounding FuncOp that has the spv.entry_point_abi
+  // attribute, and the hal.executable.
+  Region &body = funcOp.getBody();
+  if (!llvm::hasSingleElement(body))
+    return funcOp.emitError("unhandled dispatch function with multiple blocks");
+
+  if (workGroupSize.size() != 3)
+    return funcOp.emitError("expected workgroup size to have three entries");
+  SmallVector<int32_t, 3> workGroupSizeVec = llvm::to_vector<3>(llvm::map_range(
+      workGroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
+
+  funcOp.setAttr(
+      spirv::getEntryPointABIAttrName(),
+      spirv::getEntryPointABIAttr(workGroupSizeVec, funcOp.getContext()));
+  return success();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
new file mode 100644
index 0000000..bdea68e
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -0,0 +1,38 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Utils.h - Utility functions used in Linalg to SPIR-V lowering ------===//
+//
+// Utility functions used while lowering from Linalg to SPIRV.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class FuncOp;
+struct LogicalResult;
+
+namespace iree_compiler {
+
+/// Updates the workgroup size used for the dispatch region.
+LogicalResult updateWorkGroupSize(FuncOp funcOp,
+                                  ArrayRef<int64_t> workGroupSize);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  //  IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index 9d81a75..55399f3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -1,205 +1,410 @@
-// RUN: iree-opt -iree-codegen-convert-to-gpu -canonicalize -split-input-file %s | IreeFileCheck %s
+// RUN: iree-opt -iree-codegen-convert-to-gpu -canonicalize -cse -split-input-file %s | IreeFileCheck %s
 
-#map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
-#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
-#map2 = affine_map<(d0, d1) -> (d0, d1)>
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.3,
+    [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @parallel_4D(%arg0: memref<?x?x?x?xf32>,
+                    %arg1 : memref<?x?x?x?xf32>,
+                    %arg2 : memref<?x?x?x?xf32>)
+  attributes {iree.dispatch_fn_name = "parallel_4D"} {
+    linalg.generic
+      {args_in = 2 : i64, args_out = 1 : i64,
+       indexing_maps = [#map0, #map0, #map0],
+       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      %arg0, %arg1, %arg2 {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %0 = addf %arg3, %arg4 : f32
+      linalg.yield %0 : f32
+    } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+    return
+  }
+}
+// CHECK-LABEL: func @parallel_4D
+//  CHECK-SAME:   local_size = dense<[32, 1, 1]>
+//   CHECK-DAG:     %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:     %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:     %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:     %[[C3:.+]] = constant 3 : index
+//   CHECK-DAG:     %[[UB0:.+]] = dim %{{.+}}, %[[C0]]
+//   CHECK-DAG:     %[[UB1:.+]] = dim %{{.+}}, %[[C1]]
+//   CHECK-DAG:     %[[UB2:.+]] = dim %{{.+}}, %[[C2]]
+//   CHECK-DAG:     %[[UB3:.+]] = dim %{{.+}}, %[[C3]]
+//       CHECK:     %[[T4:.+]] = muli %[[UB3]], %[[UB2]]
+//       CHECK:     %[[T5:.+]] = muli %[[T4]], %[[UB1]]
+//       CHECK:     %[[UB:.+]] = muli %[[T5]], %[[UB0]]
+//   CHECK-DAG:     %[[BID:.+]] = "gpu.block_id"() {dimension = "x"}
+//   CHECK-DAG:     %[[BDIM:.+]] = "gpu.block_dim"() {dimension = "x"}
+//   CHECK-DAG:     %[[TID:.+]] = "gpu.thread_id"() {dimension = "x"}
+//       CHECK:     %[[BOFFSET:.+]] = muli %[[BID]], %[[BDIM]]
+//       CHECK:     %[[IV:.+]] = addi %[[BOFFSET]], %[[TID]]
+//       CHECK:     %[[COND:.+]] = cmpi "slt", %[[IV]], %[[UB]]
+//       CHECK:     scf.if %[[COND]]
+//       CHECK:       %[[IV0:.+]] = divi_signed %[[IV]], %[[T5]]
+//       CHECK:       %[[T14:.+]] = remi_signed %[[IV]], %[[T5]]
+//       CHECK:       %[[IV1:.+]] = divi_signed %[[T14]], %[[T4]]
+//       CHECK:       %[[T16:.+]] = remi_signed %[[T14]], %[[T4]]
+//       CHECK:       %[[IV2:.+]] = divi_signed %[[T16]], %[[UB3]]
+//       CHECK:       %[[IV3:.+]] = remi_signed %[[T16]], %[[UB3]]
+//       CHECK:       load %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+//       CHECK:       load %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+//       CHECK:       store %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+
+
+// -----
+
+#map0 = affine_map<() -> ()>
+#accesses = [#map0, #map0, #map0]
+#trait = {
+  args_in = 2 : i64,
+  args_out = 1 : i64,
+  indexing_maps = #accesses,
+  iterator_types = []
+}
+
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.3,
+    [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @scalar_add(%arg0 : memref<f32>, %arg1 : memref<f32>,
+                   %arg2 : memref<f32>)
+  {
+    linalg.generic #trait %arg0, %arg1, %arg2 {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %0 = addf %arg3, %arg4 : f32
+      linalg.yield %0 : f32
+     } : memref<f32>, memref<f32>, memref<f32>
+     return
+  }
+}
+// CHECK-LABEL: func @scalar_add
+//  CHECK-SAME:   local_size = dense<1> : vector<3xi32>
+//  CHECK-NEXT:     load
+//  CHECK-NEXT:     load
+//  CHECK-NEXT:     addf
+//  CHECK-NEXT:     store
+//  CHECK-NEXT:     return
+
+// -----
 
 module {
-  func @pw_add(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>,
-               %arg2: memref<4x8xi32>)
-  attributes {iree.dispatch_fn_name = "pw_add"} {
-    %c32 = constant 32 : index
+  func @reduce_sum(%arg0: memref<?x?x?xf32>, %arg1: memref<f32>, %arg2: memref<?xf32>)
+   attributes {iree.dispatch_fn_name = "reduce_sum"} {
+    linalg.indexed_generic
+      {args_in = 2 : i64, args_out = 1 : i64,
+       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>,
+                        affine_map<(d0, d1, d2) -> (d0)>],
+       iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg1, %arg2 {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index,
+         %arg6: f32, %arg7: f32, %arg8: f32):   // no predecessors
+      %c0 = constant 0 : index
+      %cst = constant true
+      %0 = cmpi "eq", %arg5, %c0 : index
+      %1 = and %cst, %0 : i1
+      %2 = select %1, %arg7, %arg8 : f32
+      %3 = addf %arg6, %2 : f32
+      linalg.yield %3 : f32
+    }: memref<?x?x?xf32>, memref<f32>, memref<?xf32>
+    return
+  }
+}
+
+// CHECK-LABEL: func @reduce_sum
+//  CHECK-SAME:   local_size = dense<[32, 1, 1]> : vector<3xi32>
+//   CHECK-DAG:     %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:     %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:     %[[C2:.+]] = constant 2 : index
+//       CHECK:     %[[UB0:.+]] = dim %{{.+}}, %[[C0]]
+//       CHECK:     %[[UB1:.+]] = dim %{{.+}}, %[[C1]]
+//       CHECK:     %[[UB2:.+]] = dim %{{.+}}, %[[C2]]
+//       CHECK:     %[[UB:.+]] = muli %[[UB1]], %[[UB0]]
+//       CHECK:     %[[COND:.+]] = cmpi "slt", %{{.+}}, %[[UB]]
+//       CHECK:     scf.if %[[COND]]
+//       CHECK:       %[[IV0:.+]] = divi_signed %{{.+}}, %[[UB1]]
+//       CHECK:       %[[IV1:.+]] = remi_signed %{{.+}}, %[[UB1]]
+//       CHECK:       scf.for %[[IV:.+]] = %{{.+}} to %[[UB2]]
+//       CHECK:         %[[ISZERO:.+]] = cmpi "eq", %[[IV]], %[[C0]]
+
+// -----
+
+#map0 = affine_map<(d0)[s0] -> (8, -d0 + s0)>
+#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
     %c0 = constant 0 : index
+    %c1 = constant 1 : index
     %c4 = constant 4 : index
     %c8 = constant 8 : index
-    %c1 = constant 1 : index
-    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c4, %c8) step (%c4, %c32) {
-      %0 = affine.min #map0(%c4, %c4, %arg3)
-      %1 = affine.min #map0(%c32, %c8, %arg4)
-      %2 = subview %arg0[%arg3, %arg4] [%0, %1] [%c1, %c1]
-             : memref<4x8xi32> to memref<?x?xi32, #map1>
-      %3 = affine.min #map0(%c4, %c4, %arg3)
-      %4 = affine.min #map0(%c32, %c8, %arg4)
-      %5 = subview %arg1[%arg3, %arg4] [%3, %4] [%c1, %c1]
-             : memref<4x8xi32> to memref<?x?xi32, #map1>
-      %6 = affine.min #map0(%c4, %c4, %arg3)
-      %7 = affine.min #map0(%c32, %c8, %arg4)
-      %8 = subview %arg2[%arg3, %arg4] [%6, %7] [%c1, %c1]
-             : memref<4x8xi32> to memref<?x?xi32, #map1>
-      linalg.generic
-        {args_in = 2 : i64, args_out = 1 : i64,
-         indexing_maps = [#map2, #map2, #map2],
-         iterator_types = ["parallel", "parallel"]}
-      {__internal_linalg_transform__ = "workitem"} %2, %5, %8 {
-      ^bb0(%arg5: i32, %arg6: i32, %arg7: i32): // no predecessors
-        %9 = addi %arg5, %arg6 : i32
-        linalg.yield %9 : i32
-      } : memref<?x?xi32, #map1>, memref<?x?xi32, #map1>, memref<?x?xi32, #map1>
+    %0 = dim %arg0, %c0 : memref<?x?xf32>
+    %1 = dim %arg0, %c1 : memref<?x?xf32>
+    %2 = dim %arg1, %c1 : memref<?x?xf32>
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %2) step (%c8, %c8) {
+      scf.for %arg5 = %c0 to %1 step %c4 {
+        %3 = affine.min #map0(%arg3)[%0]
+        %4 = affine.min #map1(%arg5)[%1]
+        %5 = subview %arg0[%arg3, %arg5] [%3, %4] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
+        %6 = dim %arg1, %c0 : memref<?x?xf32>
+        %7 = affine.min #map1(%arg5)[%6]
+        %8 = affine.min #map0(%arg4)[%2]
+        %9 = subview %arg1[%arg5, %arg4] [%7, %8] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
+        %10 = dim %arg2, %c0 : memref<?x?xf32>
+        %11 = affine.min #map0(%arg3)[%10]
+        %12 = dim %arg2, %c1 : memref<?x?xf32>
+        %13 = affine.min #map0(%arg4)[%12]
+        %14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
+        linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workitem"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
+      }
       scf.yield
     }
     return
   }
 }
-//   CHECK-DAG:   %[[STEPY:.+]] = constant 4 : index
-//   CHECK-DAG:   %[[STEPX:.+]] = constant 32 : index
+
+// CHECK-LABEL: func @matmul
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//   CHECK-DAG:   %[[C8:.+]] = constant 8 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[UB2:.+]] = dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//   CHECK-DAG:   %[[GDIMX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+//   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//   CHECK-DAG:   %[[GDIMY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+//       CHECK:   %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C8]]
+//       CHECK:   %[[BSTEPY:.+]] = muli %[[GDIMY]], %[[C8]]
+//       CHECK:   %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C8]]
+//       CHECK:   %[[BSTEPX:.+]] = muli %[[GDIMX]], %[[C8]]
+//       CHECK:   scf.for %[[BIV0:.+]] = %[[BOFFSETY]] to %[[UB0]] step %[[BSTEPY]]
+//       CHECK:     scf.for %[[BIV1:.+]] = %[[BOFFSETX]] to %[[UB1]] step %[[BSTEPX]]
+//       CHECK:       scf.for %[[BIV2:.+]] = %[[C0]] to %[[UB2]] step %[[C4]]
+//   CHECK-DAG:         %[[VIEWUB0:.+]] = affine.min #{{.+}}(%[[BIV0]])[%[[UB0]]]
+//   CHECK-DAG:         %[[VIEWUB1:.+]] = affine.min #{{.+}}(%[[BIV1]])[%[[UB1]]]
+//   CHECK-DAG:         %[[VIEWUB2:.+]] = affine.min #{{.+}}(%[[BIV2]])[%[[UB2]]]
+//   CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//   CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//       CHECK:         %[[INBOUNDY:.+]] = cmpi "slt", %[[TIDY]], %[[VIEWUB0]]
+//       CHECK:         %[[INBOUNDX:.+]] = cmpi "slt", %[[TIDX]], %[[VIEWUB1]]
+//       CHECK:         %[[COND:.+]] = and %[[INBOUNDY]], %[[INBOUNDX]]
+//       CHECK:         scf.if %[[COND]]
+//       CHECK:           scf.for %{{.+}} = %[[C0]] to %[[VIEWUB2]] step %[[C1]]
+
+// -----
+
+
+#map0 = affine_map<(d0)[s0] -> (1, -d0 + s0)>
+#map1 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
+#map2 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
+#map3 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+#map4 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map5 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @conv_no_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+    %c4 = constant 4 : index
+    %c32 = constant 32 : index
+    %c2 = constant 2 : index
+    %c0 = constant 0 : index
+    %c3 = constant 3 : index
+    %c1 = constant 1 : index
+    %0 = dim %arg1, %c0 : memref<?x?x?x?xf32>
+    %1 = dim %arg1, %c1 : memref<?x?x?x?xf32>
+    %2 = dim %arg1, %c2 : memref<?x?x?x?xf32>
+    %3 = dim %arg2, %c1 : memref<?x?x?x?xf32>
+    %4 = dim %arg2, %c2 : memref<?x?x?x?xf32>
+    scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%0, %3, %4) step (%c1, %c4, %c32) {
+      %5 = affine.min #map0(%arg3)[%0]
+      %6 = affine.min #map1(%arg4)[%1, %1]
+      %7 = affine.min #map2(%arg5)[%2, %2]
+      %8 = dim %arg1, %c3 : memref<?x?x?x?xf32>
+      %9 = subview %arg1[%arg3, %arg4, %arg5, 0] [%5, %6, %7, %8] [1, 1, 1, 1]  : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
+      %10 = dim %arg2, %c0 : memref<?x?x?x?xf32>
+      %11 = affine.min #map0(%arg3)[%10]
+      %12 = affine.min #map4(%arg4)[%3]
+      %13 = affine.min #map5(%arg5)[%4]
+      %14 = dim %arg2, %c3 : memref<?x?x?x?xf32>
+      %15 = subview %arg2[%arg3, %arg4, %arg5, 0] [%11, %12, %13, %14] [1, 1, 1, 1]  : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
+      linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
+      scf.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func @conv_no_padding
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//   CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[UB0:.+]] = dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[UB2:.+]] = dim %[[ARG1]], %[[C2]]
+//   CHECK-DAG:   %[[UB3:.+]] = dim %[[ARG2]], %[[C1]]
+//   CHECK-DAG:   %[[UB4:.+]] = dim %[[ARG2]], %[[C2]]
 //   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
 //   CHECK-DAG:   %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
 //   CHECK-DAG:   %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
-//       CHECK:   %[[NEWLBY:.+]] = muli %[[BIDY]], %[[STEPY]]
-//       CHECK:   %[[NEWSTEPY:.+]] = muli %[[NBLOCKSY]], %[[STEPY]]
-//       CHECK:   %[[NEWLBX:.+]] = muli %[[BIDX]], %[[STEPX]]
-//       CHECK:   %[[NEWSTEPX:.+]] = muli %[[NBLOCKSX]], %[[STEPX]]
-//       CHECK:   scf.for %{{.+}} = %[[NEWLBY]] to %{{.+}} step %[[NEWSTEPY]]
-//       CHECK:     scf.for %{{.+}} = %[[NEWLBX]] to %{{.+}} step %[[NEWSTEPX]]
-//   CHECK-DAG:       %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//   CHECK-DAG:       %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
-//   CHECK-DAG:       %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-//   CHECK-DAG:       %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"}
-//       CHECK:       scf.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
-//       CHECK:         scf.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
+//   CHECK-DAG:   %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+//   CHECK-DAG:   %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"}
+//   CHECK-DAG:   %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C4]]
+//   CHECK-DAG:   %[[BSTEPY:.+]] = muli %[[NBLOCKSY]], %[[C4]]
+//   CHECK-DAG:   %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C32]]
+//   CHECK-DAG:   %[[BSTEPX:.+]] = muli %[[NBLOCKSX]], %[[C32]]
+//       CHECK:   scf.for %[[IV3:.+]] = %[[BIDZ]] to %[[UB0]] step %[[NBLOCKSZ]]
+//       CHECK:     scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[UB3]] step %[[BSTEPY]]
+//       CHECK:       scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[UB4]] step %[[BSTEPX]]
+//       CHECK:         %[[BOUNDSZ:.+]] = affine.min #{{.+}}(%[[IV3]])
+//       CHECK:         %[[SV1:.+]] = subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+//       CHECK:         %[[BOUNDSY:.+]] = affine.min #{{.+}}(%[[IV4]])
+//       CHECK:         %[[BOUNDSX:.+]] = affine.min #{{.+}}(%[[IV5]])
+//       CHECK:         %[[SV2:.+]] = subview %[[ARG2]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+//   CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//   CHECK-DAG:         %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
+//   CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//   CHECK-DAG:         %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"}
+//   CHECK-DAG:         %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+//   CHECK-DAG:         %[[NTHREADSZ:.+]] = "gpu.block_dim"() {dimension = "z"}
+//       CHECK:         scf.for %[[IV3:.+]] = %[[TIDZ]] to %[[BOUNDSZ]] step %[[NTHREADSZ]]
+//       CHECK:           scf.for %[[IV4:.+]] = %[[TIDY]] to %[[BOUNDSY]] step %[[NTHREADSY]]
+//       CHECK:             scf.for %[[IV5:.+]] = %[[TIDX]] to %[[BOUNDSX]] step %[[NTHREADSX]]
+//       CHECK:               scf.for
+//       CHECK:                 scf.for
+//       CHECK:                   scf.for
+//       CHECK:                      scf.for
+//   CHECK-NOT:                        linalg.conv
 
 // -----
 
-module {
-  func @reduce_sum(%arg0: memref<4xf32>, %arg1: memref<f32>, %arg2: memref<f32>)
-   attributes {iree.dispatch_fn_name = "reduce_sum"} {
-    linalg.indexed_generic
-      {args_in = 2 : i64, args_out = 1 : i64,
-       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>,
-                        affine_map<(d0) -> ()>],
-       iterator_types = ["reduction"]} %arg0, %arg1, %arg2 {
-    ^bb0(%arg3: index, %arg4: f32, %arg5: f32, %arg6: f32):   // no predecessors
-      %c0 = constant 0 : index
-      %cst = constant true
-      %0 = cmpi "eq", %arg3, %c0 : index
-      %1 = and %cst, %0 : i1
-      %2 = select %1, %arg5, %arg6 : f32
-      %3 = addf %arg4, %2 : f32
-      linalg.yield %3 : f32
-    }: memref<4xf32>, memref<f32>, memref<f32>
+#map0 = affine_map<(d0, d1, d2) -> (32, d1 - d2)>
+#map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
+
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @conv_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+    linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
     return
   }
 }
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-//     CHECK:   scf.for %{{.+}} = %[[C0]] to %[[C4]] step %[[C1]]
-// CHECK-NOT:   scf
+
+// CHECK-LABEL: func @conv_padding
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = constant 3 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[UB1:.+]] = dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[UB2:.+]] = dim %[[ARG0]], %[[C2]]
+//   CHECK-DAG:   %[[UB3:.+]] = dim %[[ARG0]], %[[C3]]
+//   CHECK-DAG:   %[[UB4:.+]] = dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[UB5:.+]] = dim %[[ARG2]], %[[C1]]
+//   CHECK-DAG:   %[[UB6:.+]] = dim %[[ARG2]], %[[C2]]
+//       CHECK:   %[[T7:.+]] = muli %[[UB3]], %[[UB6]]
+//       CHECK:   %[[T8:.+]] = muli %[[T7]], %[[UB5]]
+//       CHECK:   %[[UB:.+]] = muli %[[T8]], %[[UB4]]
+//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//   CHECK-DAG:   %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
+//   CHECK-DAG:   %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//       CHECK:   %[[T13:.+]] = muli %[[BIDX]], %[[NTHREADSX]]
+//       CHECK:   %[[PROCID:.+]] = addi %[[T13]], %[[TIDX]]
+//       CHECK:   %[[COND:.+]] = cmpi "slt", %[[PROCID]], %[[UB]]
+//       CHECK:   scf.if %[[COND]]
+//       CHECK:     %[[IV0:.+]] = divi_signed %[[PROCID]], %[[T8]]
+//       CHECK:     %[[T17:.+]] = remi_signed %[[PROCID]], %[[T8]]
+//       CHECK:     %[[IV1:.+]] = divi_signed %[[T17]], %[[T7]]
+//       CHECK:     %[[T19:.+]] = remi_signed %[[T17]], %[[T7]]
+//       CHECK:     %[[IV2:.+]] = divi_signed %[[T19]], %[[UB3]]
+//       CHECK:     %[[T21:.+]] = remi_signed %[[T19]], %[[UB3]]
+//       CHECK:     scf.for %[[IV3:.+]] = %[[C0]] to %[[UB2]] step %[[C1]]
+//       CHECK:       scf.for %[[IV4:.+]] = %[[C0]] to %[[UB0]] step %[[C1]]
+//       CHECK:         scf.for %[[IV5:.+]]= %[[C0]] to %[[UB1]] step %[[C1]]
+//   CHECK-NOT:           linalg.conv
+
 
 // -----
 
-#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
-#map2 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
-#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map0 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
+#map1 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
+#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+#map3 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map4 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
 
-module {
-  func @parallel_4D(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {iree.dispatch_fn_name = "parallel_4D", spv.entry_point_abi = {local_size = dense<[32, 2, 2]> : vector<3xi32>}} {
-    %c2 = constant 2 : index
-    %c32 = constant 32 : index
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @pooling_no_padding(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+    %c4 = constant 4 : index
     %c0 = constant 0 : index
+    %c32 = constant 32 : index
     %c1 = constant 1 : index
-    %c3 = constant 3 : index
-    %0 = dim %arg0, %c0 : memref<?x?x?x?xf32>
-    %1 = dim %arg0, %c1 : memref<?x?x?x?xf32>
-    %2 = dim %arg0, %c2 : memref<?x?x?x?xf32>
-    %3 = dim %arg0, %c3 : memref<?x?x?x?xf32>
-    scf.parallel (%arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0) to (%0, %1, %2, %3) step (%c2, %c2, %c2, %c32) {
-      %12 = affine.min #map0(%arg3)[%0]
-      %13 = affine.min #map0(%arg4)[%1]
-      %14 = affine.min #map0(%arg5)[%2]
-      %15 = affine.min #map1(%arg6)[%3]
-      %16 = subview %arg0[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
-      %17 = subview %arg1[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
-      %18 = subview %arg2[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
-      linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
-        indexing_maps = [#map3, #map3, #map3],
-        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      {__internal_linalg_transform__ = "workitem"}
-      %16, %17, %18
-      {
-        ^bb0(%arg7: f32, %arg8: f32, %arg9: f32): // no predecessors
-          %19 = addf %arg7, %arg8 : f32
-          linalg.yield %19 : f32
-      } : memref<?x?x?x?xf32, #map2>, memref<?x?x?x?xf32, #map2>, memref<?x?x?x?xf32, #map2>
+    %0 = dim %arg1, %c0 : memref<?x?xf32>
+    %1 = dim %arg1, %c1 : memref<?x?xf32>
+    %2 = dim %arg2, %c0 : memref<?x?xf32>
+    %3 = dim %arg2, %c1 : memref<?x?xf32>
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%2, %3) step (%c4, %c32) {
+      %4 = dim %arg0, %c0 : memref<?x?xf32>
+      %5 = affine.min #map0(%arg3)[%0, %4]
+      %6 = dim %arg0, %c1 : memref<?x?xf32>
+      %7 = affine.min #map1(%arg4)[%1, %6]
+      %8 = subview %arg0[%arg3, %arg4] [%5, %7] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
+      %9 = affine.min #map3(%arg3)[%2]
+      %10 = affine.min #map4(%arg4)[%3]
+      %11 = subview %arg2[%arg3, %arg4] [%9, %10] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
+      linalg.pooling_max(%8, %arg1, %11) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?xf32, #map2>, memref<?x?xf32>, memref<?x?xf32, #map2>
       scf.yield
     }
     return
   }
 }
 
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C32:.+]] = constant 32 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[SERIALDIMOUTER:.+]] = dim %{{.+}}, %[[C3]]
-// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"} : () -> index
-// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"} : () -> index
-// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"} : () -> index
-// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"} : () -> index
-// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"} : () -> index
-// CHECK-DAG: %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"} : () -> index
-// CHECK-DAG: %[[LB0:.+]] = muli %[[BIDZ]], %[[C2]]
-// CHECK-DAG: %[[STEP0:.+]] = muli %[[NBLOCKSZ]], %[[C2]]
-// CHECK-DAG: %[[LB1:.+]] = muli %[[BIDY]], %[[C2]]
-// CHECK-DAG: %[[STEP1:.+]] = muli %[[NBLOCKSY]], %[[C2]]
-// CHECK-DAG: %[[LB2:.+]] = muli %[[BIDX]], %[[C2]]
-// CHECK-DAG: %[[STEP2:.+]] = muli %[[NBLOCKSX]], %[[C2]]
-//     CHECK: scf.for %{{.+}} = %[[LB0]] to %{{.+}} step %[[STEP0]]
-//     CHECK:   scf.for %{{.+}} = %[[LB1]] to %{{.+}} step %[[STEP1]]
-//     CHECK:     scf.for %{{.+}} = %[[LB2]] to %{{.+}} step %[[STEP2]]
-//     CHECK:       scf.for %{{.+}} = %[[C0]] to %[[SERIALDIMOUTER]] step %[[C32]]
-// CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"} : () -> index
-// CHECK-DAG:         %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"} : () -> index
-// CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"} : () -> index
-// CHECK-DAG:         %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"} : () -> index
-// CHECK-DAG:         %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"} : () -> index
-// CHECK-DAG:         %[[NTHREADSZ:.+]] = "gpu.block_dim"() {dimension = "z"} : () -> index
-//     CHECK:         scf.for %{{.+}} = %[[TIDZ]] to %{{.+}} step %[[NTHREADSZ]]
-//     CHECK:           scf.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
-//     CHECK:             scf.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
-//     CHECK:               scf.for %{{.+}} = %[[C0]] to %{{.+}} step %[[C1]]
-
-// -----
-
-module {
-  func @no_tile(%arg0: memref<?x?xf32>, %arg1 : memref<?x?xf32>,
-                %arg2 : memref<?x?xf32>)
-  attributes {iree.dispatch_fn_name = "reduce_sum"} {
-    linalg.generic
-      {args_in = 2 : i64, args_out = 1 : i64,
-       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                        affine_map<(d0, d1) -> (d0, d1)>,
-                        affine_map<(d0, d1) -> (d0, d1)>],
-       iterator_types = ["parallel", "parallel"]}
-      {__internal_linalg_tranform__ = "no-tile"} %arg0, %arg1, %arg2 {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
-      %0 = addf %arg3, %arg4 : f32
-      linalg.yield %0 : f32
-    }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
-    return
-  }
-}
-
-// CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = constant 1 : index
-// CHECK-DAG: %[[UBY:.+]] =  dim %{{.*}}, %[[C0]]
-// CHECK-DAG: %[[UBX:.+]] =  dim %{{.*}}, %[[C1]]
-// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
-// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK-DAG: %[[BLOCKSIZEX:.+]] = "gpu.block_dim"() {dimension = "x"}
-// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//     CHECK: %[[T6:.+]] =  muli %[[BIDX]], %[[BLOCKSIZEX]]
-//     CHECK: %[[GIDX:.+]] =  addi %[[T6]], %[[TIDX]]
-//     CHECK: %[[NPROCSX:.+]] =  muli %[[BLOCKSIZEX]], %[[NBLOCKSX]]
-// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-// CHECK-DAG: %[[BLOCKSIZEY:.+]] = "gpu.block_dim"() {dimension = "y"}
-// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-//     CHECK: %[[T6:.+]] =  muli %[[BIDY]], %[[BLOCKSIZEY]]
-//     CHECK: %[[GIDY:.+]] =  addi %[[T6]], %[[TIDY]]
-//     CHECK: %[[NPROCSY:.+]] =  muli %[[BLOCKSIZEY]], %[[NBLOCKSY]]
-//     CHECK: scf.for %{{.+}} = %[[GIDY]] to %[[UBY]] step %[[NPROCSY]]
-//     CHECK:   scf.for %{{.+}} = %[[GIDX]] to %[[UBX]] step %[[NPROCSX]]
+// CHECK-LABEL: func @pooling_no_padding
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//   CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[UB0:.+]] = dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[UB2:.+]] = dim %[[ARG2]], %[[C0]]
+//   CHECK-DAG:   %[[UB3:.+]] = dim %[[ARG2]], %[[C1]]
+//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//   CHECK-DAG:   %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+//   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//   CHECK-DAG:   %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+//   CHECK-DAG:   %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C4]]
+//   CHECK-DAG:   %[[BSTEPY:.+]] = muli %[[NBLOCKSY]], %[[C4]]
+//   CHECK-DAG:   %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C32]]
+//   CHECK-DAG:   %[[BSTEPX:.+]] = muli %[[NBLOCKSX]], %[[C32]]
+//       CHECK:   scf.for %[[IV3:.+]] = %[[BOFFSETY]] to %[[UB2]] step %[[BSTEPY]]
+//       CHECK:     scf.for %[[IV4:.+]] = %[[BOFFSETX]] to %[[UB3]] step %[[BSTEPX]]
+//       CHECK:       %[[SV1:.+]] = subview %[[ARG0]][%[[IV3]], %[[IV4]]]
+//       CHECK:       %[[BOUNDSY:.+]] = affine.min #{{.+}}(%[[IV3]])
+//       CHECK:       %[[BOUNDSX:.+]] = affine.min #{{.+}}(%[[IV4]])
+//       CHECK:       %[[SV2:.+]] = subview %[[ARG2]][%[[IV3]], %[[IV4]]]
+//   CHECK-DAG:       %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//   CHECK-DAG:       %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
+//   CHECK-DAG:       %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//   CHECK-DAG:       %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"}
+//       CHECK:       scf.for %[[IV5:.+]] = %[[TIDY]] to %[[BOUNDSY]] step %[[NTHREADSY]]
+//       CHECK:         scf.for %[[IV6:.+]] = %[[TIDX]] to %[[BOUNDSX]] step %[[NTHREADSX]]
+//       CHECK:           scf.for
+//       CHECK:             scf.for
+//   CHECK-NOT:               linalg.pooling_max
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
new file mode 100644
index 0000000..be33748
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
@@ -0,0 +1,87 @@
+// RUN: iree-opt -iree-codegen-convert-to-gpu -iree-codegen-use-legacy-conv-lowering=false -canonicalize -cse -split-input-file %s | IreeFileCheck %s
+
+#map0 = affine_map<(d0)[s0] -> (1, -d0 + s0)>
+#map1 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
+#map2 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
+#map3 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+#map4 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map5 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @conv_no_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+    %c4 = constant 4 : index
+    %c32 = constant 32 : index
+    %c2 = constant 2 : index
+    %c0 = constant 0 : index
+    %c3 = constant 3 : index
+    %c1 = constant 1 : index
+    %0 = dim %arg1, %c0 : memref<?x?x?x?xf32>
+    %1 = dim %arg1, %c1 : memref<?x?x?x?xf32>
+    %2 = dim %arg1, %c2 : memref<?x?x?x?xf32>
+    %3 = dim %arg2, %c1 : memref<?x?x?x?xf32>
+    %4 = dim %arg2, %c2 : memref<?x?x?x?xf32>
+    scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%0, %3, %4) step (%c1, %c4, %c32) {
+      %5 = affine.min #map0(%arg3)[%0]
+      %6 = affine.min #map1(%arg4)[%1, %1]
+      %7 = affine.min #map2(%arg5)[%2, %2]
+      %8 = dim %arg1, %c3 : memref<?x?x?x?xf32>
+      %9 = subview %arg1[%arg3, %arg4, %arg5, 0] [%5, %6, %7, %8] [1, 1, 1, 1]  : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
+      %10 = dim %arg2, %c0 : memref<?x?x?x?xf32>
+      %11 = affine.min #map0(%arg3)[%10]
+      %12 = affine.min #map4(%arg4)[%3]
+      %13 = affine.min #map5(%arg5)[%4]
+      %14 = dim %arg2, %c3 : memref<?x?x?x?xf32>
+      %15 = subview %arg2[%arg3, %arg4, %arg5, 0] [%11, %12, %13, %14] [1, 1, 1, 1]  : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
+      linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
+      scf.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func @conv_no_padding
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//   CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[UB0:.+]] = dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[UB2:.+]] = dim %[[ARG1]], %[[C2]]
+//   CHECK-DAG:   %[[UB3:.+]] = dim %[[ARG2]], %[[C1]]
+//   CHECK-DAG:   %[[UB4:.+]] = dim %[[ARG2]], %[[C2]]
+//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//   CHECK-DAG:   %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+//   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//   CHECK-DAG:   %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+//   CHECK-DAG:   %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+//   CHECK-DAG:   %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"}
+//   CHECK-DAG:   %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C4]]
+//   CHECK-DAG:   %[[BSTEPY:.+]] = muli %[[NBLOCKSY]], %[[C4]]
+//   CHECK-DAG:   %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C32]]
+//   CHECK-DAG:   %[[BSTEPX:.+]] = muli %[[NBLOCKSX]], %[[C32]]
+//       CHECK:   scf.for %[[IV3:.+]] = %[[BIDZ]] to %[[UB0]] step %[[NBLOCKSZ]]
+//       CHECK:     scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[UB3]] step %[[BSTEPY]]
+//       CHECK:       scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[UB4]] step %[[BSTEPX]]
+//       CHECK:         %[[BOUNDSZ:.+]] = affine.min #{{.+}}(%[[IV3]])
+//       CHECK:         %[[SV1:.+]] = subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+//       CHECK:         %[[BOUNDSY:.+]] = affine.min #{{.+}}(%[[IV4]])
+//       CHECK:         %[[BOUNDSX:.+]] = affine.min #{{.+}}(%[[IV5]])
+//       CHECK:         %[[SV2:.+]] = subview %[[ARG2]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+//   CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//   CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//   CHECK-DAG:         %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+//       CHECK:         %[[INBOUNDSZ:.+]] = cmpi "slt", %[[TIDZ]], %[[BOUNDSZ]]
+//       CHECK:         %[[INBOUNDSY:.+]] = cmpi "slt", %[[TIDY]], %[[BOUNDSY]]
+//       CHECK:         %[[T35:.+]] = and %[[INBOUNDSZ]], %[[INBOUNDSY]]
+//       CHECK:         %[[INBOUNDSX:.+]] = cmpi "slt", %[[TIDX]], %[[BOUNDSX]]
+//       CHECK:         %[[INBOUNDS:.+]] = and %[[T35]], %[[INBOUNDSX]]
+//       CHECK:         scf.if %[[INBOUNDS]]
+//       CHECK:           scf.for
+//       CHECK:             scf.for
+//       CHECK:               scf.for
+//       CHECK:                 scf.for
+//   CHECK-NOT:                   linalg.conv
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 70f3d17..1b5ddda 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -1,65 +1,14 @@
 // RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse %s | IreeFileCheck %s
 
+// Test to check that convolution with padding is not tiled.
 module attributes {
   spv.target_env =
     #spv.target_env<#spv.vce<v1.3,
     [Shader], [SPV_KHR_storage_buffer_storage_class]>,
     {max_compute_workgroup_invocations = 128 : i32,
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  // CHECK-LABEL: func @tile_only
-  //  CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
-  //  CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
-  //  CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
-  //  CHECK-SAME: local_size = dense<[32, 4, 1]>
-  //       CHECK: scf.parallel
-  //       CHECK:   %[[VIEW0:.+]] = subview %[[ARG0]]
-  //       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]]
-  //       CHECK:   %[[VIEW2:.+]] = subview %[[ARG2]]
-  //       CHECK:   linalg.generic
-  //  CHECK-SAME:     "workitem"
-  //  CHECK-SAME:     %[[VIEW0]]
-  //  CHECK-SAME:     %[[VIEW1]]
-  //  CHECK-SAME:     %[[VIEW2]]
-  func @tile_only(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>,
-                  %arg2: memref<4x8xi32>) {
-    linalg.generic
-      {args_in = 2 : i64, args_out = 1 : i64,
-       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                        affine_map<(d0, d1) -> (d0, d1)>,
-                        affine_map<(d0, d1) -> (d0, d1)>],
-       iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
-    ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
-      %0 = addi %arg3, %arg4 : i32
-      linalg.yield %0 : i32
-    }: memref<4x8xi32>, memref<4x8xi32>, memref<4x8xi32>
-    return
-  }
-}
-
-// -----
-
-module attributes {
-  spv.target_env =
-    #spv.target_env<#spv.vce<v1.3,
-    [Shader], [SPV_KHR_storage_buffer_storage_class]>,
-    {max_compute_workgroup_invocations = 128 : i32,
-     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  // CHECK-LABEL: func @conv_padding
-  //  CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-  //  CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-  //  CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-  //  CHECK-SAME: local_size = dense<[32, 1, 1]>
-  //       CHECK: scf.parallel (%{{.+}})
-  //       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]]
-  //       CHECK:   %[[VIEW2:.+]] = subview %[[ARG2]]
-  //       CHECK:   linalg.conv
-  //  CHECK-SAME:     %[[VIEW1]]
-  //  CHECK-SAME:     %[[VIEW2]]
-  //  CHECK-SAME:     "workitem"
   func @conv_padding(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
-                     %arg2 : memref<?x?x?x?xf32>)
-    attributes
-      {iree.dispatch_fn_name = "conv_padding"} {
+                     %arg2 : memref<?x?x?x?xf32>) {
     linalg.conv(%arg0, %arg1, %arg2)
       {dilations = [1, 1],
        padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} :
@@ -67,6 +16,14 @@
     return
   }
 }
+// CHECK-LABEL: func @conv_padding
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//       CHECK:   linalg.conv
+//  CHECK-SAME:     %[[ARG0]]
+//  CHECK-SAME:     %[[ARG1]]
+//  CHECK-SAME:     %[[ARG2]]
 
 // -----
 
@@ -76,55 +33,24 @@
     [Shader], [SPV_KHR_storage_buffer_storage_class]>,
     {max_compute_workgroup_invocations = 128 : i32,
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  // CHECK-LABEL: func @conv_no_padding
-  //  CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-  //  CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-  //  CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-  //  CHECK-SAME: local_size = dense<[32, 2, 2]>
-  //       CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
-  //       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]]
-  //       CHECK:   %[[VIEW2:.+]] = subview %[[ARG2]]
-  //       CHECK:   linalg.conv
-  //  CHECK-SAME:     %[[VIEW1]]
-  //  CHECK-SAME:     %[[VIEW2]]
-  //  CHECK-SAME:     "workitem"
   func @conv_no_padding(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
-                        %arg2 : memref<?x?x?x?xf32>)
-    attributes
-      {iree.dispatch_fn_name = "conv_no_padding"} {
+                        %arg2 : memref<?x?x?x?xf32>) {
     linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} :
       memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
     return
   }
 }
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-module attributes {
-  spv.target_env =
-    #spv.target_env<#spv.vce<v1.3,
-    [Shader], [SPV_KHR_storage_buffer_storage_class]>,
-    {max_compute_workgroup_invocations = 128 : i32,
-     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  // CHECK-LABEL: func @parallel_4D
-  //       CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
-  func @parallel_4D(%arg0: memref<?x?x?x?xf32>,
-                    %arg1 : memref<?x?x?x?xf32>,
-                    %arg2 : memref<?x?x?x?xf32>)
-  attributes {iree.dispatch_fn_name = "parallel_4D"} {
-    linalg.generic
-      {args_in = 2 : i64, args_out = 1 : i64,
-       indexing_maps = [#map0, #map0, #map0],
-       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      %arg0, %arg1, %arg2 {
-    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
-      %0 = addf %arg3, %arg4 : f32
-      linalg.yield %0 : f32
-    } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
-    return
-  }
-}
+// CHECK-LABEL: func @conv_no_padding
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   local_size = dense<[32, 4, 1]>
+//       CHECK:   scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
+//       CHECK:     %[[VIEW1:.+]] = subview %[[ARG1]]
+//       CHECK:     %[[VIEW2:.+]] = subview %[[ARG2]]
+//       CHECK:     linalg.conv
+//  CHECK-SAME:       %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
+//  CHECK-SAME:       "workitem"
 
 // -----
 
@@ -134,54 +60,52 @@
     [Shader], [SPV_KHR_storage_buffer_storage_class]>,
     {max_compute_workgroup_invocations = 128 : i32,
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @no_tile(%arg0: memref<?x?xf32>,
+  func @matmul(%arg0: memref<?x?xf32>,
                 %arg1: memref<?x?xf32>,
                 %ret0: memref<?x?xf32>) {
-    linalg.matmul %arg0, %arg1, %ret0 {__internal_linalg_transform__ = "no-tile"} :
+    linalg.matmul %arg0, %arg1, %ret0 :
       (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
     return
   }
 }
-// CHECK-LABEL: func @no_tile
+
+// CHECK-LABEL: func @matmul
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
 //  CHECK-SAME:   local_size = dense<[8, 8, 1]>
-//   CHECK-NOT:   scf
-//       CHECK:   linalg.matmul
-//   CHECK-NOT:   scf
-//       CHECK:   return
+//       CHECK:   scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
+//       CHECK:     %[[VIEW0:.+]] = subview %[[ARG0]]
+//       CHECK:     %[[VIEW1:.+]] = subview %[[ARG1]]
+//       CHECK:     %[[VIEW2:.+]] = subview %[[ARG2]]
+//       CHECK:     linalg.matmul
+//  CHECK-SAME:       "workitem"
+//  CHECK-SAME:       %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
 
 // -----
 
-#map0 = affine_map<() -> ()>
-#accesses = [#map0, #map0]
-#trait = {
-  args_in = 2 : i64,
-  args_out = 1 : i64,
-  indexing_maps = #accesses,
-  iterator_types = []
-}
-
 module attributes {
   spv.target_env =
     #spv.target_env<#spv.vce<v1.3,
     [Shader], [SPV_KHR_storage_buffer_storage_class]>,
     {max_compute_workgroup_invocations = 128 : i32,
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @scalar_add(%arg0 : memref<f32>, %arg1 : memref<f32>,
-                   %arg2 : memref<f32>)
-  {
-    linalg.generic #trait %arg0, %arg1, %arg2 {
-    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
-      %0 = addf %arg3, %arg4 : f32
-      linalg.yield %0 : f32
-     } : memref<f32>, memref<f32>, memref<f32>
-     return
+  func @pooling_sum_no_padding(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+                               %arg2 : memref<?x?xf32>) {
+    linalg.pooling_max(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} :
+      memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+    return
   }
 }
-// CHECK-LABEL: func @scalar_add
-//   CHECK-NOT:   scf.parallel
-//   CHECK-NOT:   scf.for
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     "no-tile"
-//   CHECK-NOT:   scf.parallel
-//   CHECK-NOT:   scf.for
-//       CHECK:   return
+
+// CHECK-LABEL: func @pooling_sum_no_padding
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+//  CHECK-SAME:   local_size = dense<[32, 4, 1]>
+//       CHECK:   scf.parallel (%{{.+}}, %{{.+}})
+//       CHECK:     %[[VIEW0:.+]] = subview %[[ARG0]]
+//       CHECK:     %[[VIEW2:.+]] = subview %[[ARG2]]
+//       CHECK:     linalg.pooling_max
+//  CHECK-SAME:       %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
+//  CHECK-SAME:       "workitem"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 81db628..637fd7b 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -46,6 +46,65 @@
 
 // -----
 
+// CHECK: module attributes {vkspv.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1", "kernel_dispatch_2"]}
+module {
+  // CHECK: func @kernel_dispatch_2()
+  // CHECK:   %[[DIM:.+]] = hal.interface.load.constant
+  // CHECK:   %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
+  // CHECK:   %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
+  // CHECK:   %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+  // CHECK:   %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
+  // CHECK:   %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
+  // CHECK:   %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+  // CHECK:   %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
+  // CHECK:   linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+  // CHECK:   return
+
+  // CHECK: func @kernel_dispatch_1() {
+  // CHECK:   %[[C0:.+]] = constant 0 : index
+  // CHECK:   %[[C1:.+]] = constant 1 : index
+  // CHECK:   scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[C1]]) step (%[[C1]])
+  // CHECK:     scf.yield
+  // CHECK:   return
+
+  // CHECK: func @kernel_dispatch_0()
+  // CHECK:   %[[ZERO:.+]] = constant
+  // CHECK:   %[[DIM:.+]] = hal.interface.load.constant
+  // CHECK:   %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
+  // CHECK:   %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+  // CHECK:   %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
+  // CHECK:   linalg.fill(%[[TS]], %[[ZERO]])
+  // CHECK:   return
+
+  func @kernel() {
+    %cst = constant 0.000000e+00 : f32
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %dim = hal.interface.load.constant offset = 0 : index
+    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+    %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
+    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+    %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
+    linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
+    scf.parallel (%iv) = (%c0) to (%c1) step (%c1) {
+      scf.yield
+    }
+    linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+    return
+  }
+  hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
+
+
+// -----
+
 // Nothing to do if there is just one Linalg op.
 
 // CHECK-NOT: vkspv.entry_point_schedule
@@ -71,7 +130,7 @@
 // Do not split when Linalg and non-Linalg ops are interleaving each other.
 
 module {
-  // expected-error @+1 {{cannot separate Linalg ops into multiple kernels}}
+  // expected-error @+1 {{cannot separate Linalg/Parallel ops into multiple kernels}}
   func @kernel() {
     %cst = constant 0.000000e+00 : f32
     %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index 060dc5a..76cfcb8 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -5,7 +5,7 @@
     %arg0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<96x96xf32>
     %arg1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<96x96xf32>
     %arg2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<96x96xf32>
-    linalg.matmul %arg0, %arg1, %arg2 {__internal_linalg_transform__ = "workgroup"} :
+    linalg.matmul %arg0, %arg1, %arg2 :
       (memref<96x96xf32>, memref<96x96xf32>, memref<96x96xf32>)
     return
   }
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index ac18c5f..c2367a7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -32,13 +32,23 @@
 static llvm::cl::opt<bool> extractPadFromConv(
     "iree-extract-pad-from-conv",
     llvm::cl::desc("Extract padding attributes from conv op"),
-    llvm::cl::init(false));
+    llvm::cl::init(true));
 
 static bool isAllZero(DenseIntElementsAttr attr) {
   if (!attr.isSplat()) return false;
   return attr.getSplatValue<IntegerAttr>().getInt() == 0;
 }
 
+/// Returns true if the linalg op has padding attribute, and that it has
+/// non-zero entries.
+template <typename OpTy>
+static bool hasPadding(OpTy op) {
+  Optional<DenseIntElementsAttr> padding = op.padding();
+  if (!padding) return false;
+  return llvm::any_of(padding.getValue(),
+                      [](APInt v) -> bool { return !v.isNullValue(); });
+}
+
 class ExtractConvOpPaddingAttributes
     : public OpRewritePattern<xla_hlo::ConvOp> {
  public:
@@ -46,7 +56,7 @@
 
   LogicalResult matchAndRewrite(xla_hlo::ConvOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!op.padding()) return failure();
+    if (!hasPadding(op)) return failure();
     auto inputType = op.lhs().getType().cast<ShapedType>();
     int rank = inputType.getRank();
     SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape;
diff --git a/iree/hal/vmla/op_kernels_ruy.h b/iree/hal/vmla/op_kernels_ruy.h
index 981ee72..003fb80 100644
--- a/iree/hal/vmla/op_kernels_ruy.h
+++ b/iree/hal/vmla/op_kernels_ruy.h
@@ -15,10 +15,13 @@
 #ifndef IREE_HAL_VMLA_OP_KERNELS_RUY_H_
 #define IREE_HAL_VMLA_OP_KERNELS_RUY_H_
 
+#include <type_traits>
+
 #include "absl/base/thread_annotations.h"
 #include "absl/memory/memory.h"
 #include "iree/base/status.h"
 #include "ruy/context.h"
+#include "ruy/mul_params.h"
 #include "ruy/ruy.h"
 
 namespace iree {
@@ -37,6 +40,56 @@
   return absl::make_unique<RuntimeState>();
 }
 
+// Floating-point case.
+template <typename ACC, typename T>
+struct MakeRuyMulParamsImpl {
+  static_assert(std::is_floating_point<ACC>::value, "");
+  static_assert(std::is_floating_point<T>::value, "");
+  static void Run(const MatMul::Buffers<T, ACC>& buffers,
+                  ruy::MulParams<ACC, T>* mul_params) {
+    mul_params->set_bias(buffers.bias_buffer.data());
+  }
+};
+
+// Integer quantized case with downquantization to a destination T narrower than
+// int32.
+template <typename T>
+struct MakeRuyMulParamsImpl<std::int32_t, T> {
+  static_assert(std::is_integral<T>::value, "");
+  static_assert(sizeof(T) < sizeof(std::int32_t), "");
+  static void Run(const MatMul::Buffers<T, std::int32_t>& buffers,
+                  ruy::MulParams<std::int32_t, T>* mul_params) {
+    mul_params->set_bias(buffers.bias_buffer.data());
+    if (buffers.multiplier_mantissa_buffer.size() == 1) {
+      mul_params->set_multiplier_fixedpoint(
+          buffers.multiplier_mantissa_buffer[0]);
+      mul_params->set_multiplier_exponent(
+          buffers.multiplier_exponent_buffer[0]);
+    } else {
+      mul_params->set_multiplier_fixedpoint_perchannel(
+          buffers.multiplier_mantissa_buffer.data());
+      mul_params->set_multiplier_exponent_perchannel(
+          buffers.multiplier_exponent_buffer.data());
+    }
+  }
+};
+
+// Raw integer case with int32 destination. This case does not support any
+// output operation besides bias-addition.
+template <>
+struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t> {
+  static void Run(const MatMul::Buffers<std::int32_t, std::int32_t>& buffers,
+                  ruy::MulParams<std::int32_t, std::int32_t>* mul_params) {
+    mul_params->set_bias(buffers.bias_buffer.data());
+  }
+};
+
+template <typename ACC, typename T>
+void MakeRuyMulParams(const MatMul::Buffers<T, ACC>& buffers,
+                      ruy::MulParams<ACC, T>* mul_params) {
+  MakeRuyMulParamsImpl<ACC, T>::Run(buffers, mul_params);
+}
+
 template <typename T, typename ACC>
 Status MatMul::Execute(RuntimeState* runtime_state,
                        const Buffers<T, ACC>& buffers) {
@@ -56,17 +109,7 @@
                         ruy::Order::kColMajor, dst.mutable_layout());
 
   ruy::MulParams<ACC, T> mul_params;
-  mul_params.set_bias(buffers.bias_buffer.data());
-
-  if (buffers.multiplier_mantissa_buffer.size() == 1) {
-    mul_params.set_multiplier_fixedpoint(buffers.multiplier_mantissa_buffer[0]);
-    mul_params.set_multiplier_exponent(buffers.multiplier_exponent_buffer[0]);
-  } else {
-    mul_params.set_multiplier_fixedpoint_perchannel(
-        buffers.multiplier_mantissa_buffer.data());
-    mul_params.set_multiplier_exponent_perchannel(
-        buffers.multiplier_exponent_buffer.data());
-  }
+  MakeRuyMulParams(buffers, &mul_params);
 
   ruy::Mul(lhs, rhs, mul_params, &runtime_state->context, &dst);
 
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index d937456..a3abc81 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -29,14 +29,37 @@
     target_backend = "vulkan-spirv",
 )
 
-# TODO(hanchung): Merge two tests into one single file.
+# TODO(#2345): Merge two tests into one single file.
 iree_check_single_backend_test_suite(
-    name = "check_vulkan-spirv-pad-conv_vulkan",
+    name = "check_vulkan-spirv-split-pad-conv_vulkan",
     srcs = [
         "convolution1.mlir",
         "convolution2.mlir",
     ],
-    compiler_flags = ["-iree-extract-pad-from-conv"],
+    driver = "vulkan",
+    target_backend = "vulkan-spirv",
+)
+
+# TODO(#2345): Merge two tests into one single file.
+iree_check_single_backend_test_suite(
+    name = "check_vulkan-spirv-nosplit-pad-conv_vulkan",
+    srcs = [
+        "convolution1.mlir",
+        "convolution2.mlir",
+    ],
+    compiler_flags = ["-iree-extract-pad-from-conv=false"],
+    driver = "vulkan",
+    target_backend = "vulkan-spirv",
+)
+
+# TODO(#2345): Merge two tests into one single file.
+iree_check_single_backend_test_suite(
+    name = "check_vulkan-spirv-conv-nocontrol_vulkan",
+    srcs = [
+        "convolution1.mlir",
+        "convolution2.mlir",
+    ],
+    compiler_flags = ["-iree-codegen-use-legacy-conv-lowering=false"],
     driver = "vulkan",
     target_backend = "vulkan-spirv",
 )
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index 27e65fc..cca6c58 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -28,7 +28,19 @@
 
 iree_check_single_backend_test_suite(
   NAME
-    check_vulkan-spirv-pad-conv_vulkan
+    check_vulkan-spirv-split-pad-conv_vulkan
+  SRCS
+    "convolution1.mlir"
+    "convolution2.mlir"
+  TARGET_BACKEND
+    vulkan-spirv
+  DRIVER
+    vulkan
+)
+
+iree_check_single_backend_test_suite(
+  NAME
+    check_vulkan-spirv-nosplit-pad-conv_vulkan
   SRCS
     "convolution1.mlir"
     "convolution2.mlir"
@@ -37,5 +49,19 @@
   DRIVER
     vulkan
   COMPILER_FLAGS
-    "-iree-extract-pad-from-conv"
+    "-iree-extract-pad-from-conv=false"
+)
+
+iree_check_single_backend_test_suite(
+  NAME
+    check_vulkan-spirv-conv-nocontrol_vulkan
+  SRCS
+    "convolution1.mlir"
+    "convolution2.mlir"
+  TARGET_BACKEND
+    vulkan-spirv
+  DRIVER
+    vulkan
+  COMPILER_FLAGS
+    "-iree-codegen-use-legacy-conv-lowering=false"
 )
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 6f95110..31967b0 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -71,8 +71,7 @@
         "reduce_window.mlir",
         "remainder.mlir",
         "reshape.mlir",
-        # TODO(#1699): Enable after xla_hlo.reverse can be lowered to linalg.
-        # "reverse.mlir",
+        "reverse.mlir",
         "rsqrt.mlir",
         "select.mlir",
         "sine.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index b2eec93..7f65b06 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -56,6 +56,7 @@
     "reduce_window.mlir"
     "remainder.mlir"
     "reshape.mlir"
+    "reverse.mlir"
     "rsqrt.mlir"
     "select.mlir"
     "sine.mlir"
diff --git a/iree/test/e2e/xla_ops/convolution.mlir b/iree/test/e2e/xla_ops/convolution.mlir
index 183fc10..f1880fd 100644
--- a/iree/test/e2e/xla_ops/convolution.mlir
+++ b/iree/test/e2e/xla_ops/convolution.mlir
@@ -65,47 +65,51 @@
   return
 }
 
-func @conv2d_2451x2311_same() attributes { iree.module.export } {
-  %inputs = iree.unfoldable_constant dense<[
-      [[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0]],
-       [[ 6.0], [ 7.0], [ 8.0], [ 9.0], [10.0]],
-       [[11.0], [12.0], [13.0], [14.0], [15.0]],
-       [[16.0], [17.0], [18.0], [19.0], [20.0]]],
-      [[[21.0], [22.0], [23.0], [24.0], [25.0]],
-       [[26.0], [27.0], [28.0], [29.0], [30.0]],
-       [[31.0], [32.0], [33.0], [34.0], [35.0]],
-       [[36.0], [37.0], [38.0], [39.0], [40.0]]]]> : tensor <2x4x5x1xf32>
-  %weights = iree.unfoldable_constant dense<[
-      [[[1.0]], [[2.0]], [[3.0]]],
-      [[[4.0]], [[5.0]], [[6.0]]]]> : tensor <2x3x1x1xf32>
-  %res = "xla_hlo.convolution"(%inputs, %weights) {
-       batch_group_count = 1 : i64,
-       dimension_numbers = {
-         input_batch_dimension = 0 : i64,
-         input_feature_dimension = 3 : i64,
-         input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-         kernel_input_feature_dimension = 2 : i64,
-         kernel_output_feature_dimension = 3 : i64,
-         kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-         output_batch_dimension = 0 : i64,
-         output_feature_dimension = 3 : i64,
-         output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
-       feature_group_count = 1 : i64,
-       padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
-       rhs_dilation = dense<1> : tensor<2xi64>,
-       window_strides = dense<1> : tensor<2xi64>} :
-       (tensor<2x4x5x1xf32>, tensor<2x3x1x1xf32>) -> tensor<2x4x5x1xf32>
-  check.expect_almost_eq_const(%res, dense<[
-    [[[ 80.0], [121.0], [142.0], [163.0], [100.0]],
-     [[160.0], [226.0], [247.0], [268.0], [160.0]],
-     [[240.0], [331.0], [352.0], [373.0], [220.0]],
-     [[ 83.0], [104.0], [110.0], [116.0], [ 59.0]]],
-    [[[400.0], [541.0], [562.0], [583.0], [340.0]],
-     [[480.0], [646.0], [667.0], [688.0], [400.0]],
-     [[560.0], [751.0], [772.0], [793.0], [460.0]],
-     [[183.0], [224.0], [230.0], [236.0], [119.0]]]]> : tensor<2x4x5x1xf32>) : tensor<2x4x5x1xf32>
-  return
-}
+// TODO(#2345): This test seems to fail when executed with another
+// test from this file, but passes as a standalone test. Needs further
+// investigation
+
+// func @conv2d_2451x2311_same() attributes { iree.module.export } {
+//   %inputs = iree.unfoldable_constant dense<[
+//       [[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0]],
+//        [[ 6.0], [ 7.0], [ 8.0], [ 9.0], [10.0]],
+//        [[11.0], [12.0], [13.0], [14.0], [15.0]],
+//        [[16.0], [17.0], [18.0], [19.0], [20.0]]],
+//       [[[21.0], [22.0], [23.0], [24.0], [25.0]],
+//        [[26.0], [27.0], [28.0], [29.0], [30.0]],
+//        [[31.0], [32.0], [33.0], [34.0], [35.0]],
+//        [[36.0], [37.0], [38.0], [39.0], [40.0]]]]> : tensor <2x4x5x1xf32>
+//   %weights = iree.unfoldable_constant dense<[
+//       [[[1.0]], [[2.0]], [[3.0]]],
+//       [[[4.0]], [[5.0]], [[6.0]]]]> : tensor <2x3x1x1xf32>
+//   %res = "xla_hlo.convolution"(%inputs, %weights) {
+//        batch_group_count = 1 : i64,
+//        dimension_numbers = {
+//          input_batch_dimension = 0 : i64,
+//          input_feature_dimension = 3 : i64,
+//          input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+//          kernel_input_feature_dimension = 2 : i64,
+//          kernel_output_feature_dimension = 3 : i64,
+//          kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+//          output_batch_dimension = 0 : i64,
+//          output_feature_dimension = 3 : i64,
+//          output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+//        feature_group_count = 1 : i64,
+//        padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
+//        rhs_dilation = dense<1> : tensor<2xi64>,
+//        window_strides = dense<1> : tensor<2xi64>} :
+//        (tensor<2x4x5x1xf32>, tensor<2x3x1x1xf32>) -> tensor<2x4x5x1xf32>
+//   check.expect_almost_eq_const(%res, dense<[
+//     [[[ 80.0], [121.0], [142.0], [163.0], [100.0]],
+//      [[160.0], [226.0], [247.0], [268.0], [160.0]],
+//      [[240.0], [331.0], [352.0], [373.0], [220.0]],
+//      [[ 83.0], [104.0], [110.0], [116.0], [ 59.0]]],
+//     [[[400.0], [541.0], [562.0], [583.0], [340.0]],
+//      [[480.0], [646.0], [667.0], [688.0], [400.0]],
+//      [[560.0], [751.0], [772.0], [793.0], [460.0]],
+//      [[183.0], [224.0], [230.0], [236.0], [119.0]]]]> : tensor<2x4x5x1xf32>) : tensor<2x4x5x1xf32>
+//   return
+// }
 
 func @conv2d_no_padding() attributes { iree.module.export } {
   %inputs = iree.unfoldable_constant dense<[
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 1becd29..e34523c 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 1becd298b82ed2f1a8ba5e61c5ad2ce7fe32d812
+Subproject commit e34523c87c3f1cfabcf741568dede026bbb12d3a