[spirv] Migrate ConvertToGPUPass' invocation tiling logic (#5814)

`ConvertToGPUPass` is a sink for lowering away all Linalg ops
in the SPIR-V pipeline: it can distribute a Linalg op to both
global invocation IDs (if just having one-level distribution) or
local invocation IDs (if having two-level distribution). 

This is mostly historical; now we have more proper layering
in the pipeline where we perform the first-level tiling and
distribution at flow level and the second/third level at
`TileAndVectorizeInOneWorkgroupPass`, the functionality
in `ConvertToGPUPass` is overlapping with that, although
in a complementary way: the tiling in the former pass does
not handle the cases where we cannot do imperfect tiling;
it requires perfectly tiled cases, as we are assuming number
of processors equal to number of iterations to avoid generating
`scf.for` loop from the start. `ConvertToGPUPass` is more
generic and can handle all cases.  This is all opaque and
complicated. 

This commit relaxes the `TileAndVectorizeInOneWorkgroupPass`
to not assume the number of processors equal to the number
of iterations. Now we just tile and cyclically distribute using
`scf.for` loops. This causes issues for perfectly tiled cases
as we need to canonicalize the `affine.min` and one-trip
`scf.for` away to expose static sizes for further vectorization.
That can be done by pulling in additional canonicalization
patterns.

Also in order to utilize `TileAndVectorizeInOneWorkgroupPass`
for the second/third level tiling we need to have the corresponding
scheme in launch configuration for them. This commit also
adds the default second/third level tiling for all now supported
Linalg ops: no tiling on subgroup and tiling to 1 for invocations.
This at the same time helps to clean up a bunch of unwieldy
templated configurations for different ops..

Together, the above migrates the invocation tiling logic
in `ConvertToGPUPass` to their proper places. This is the
first step as reining in `ConvertToGPUPass` and launch
configurations.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index ef56ea1..a28943e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -389,24 +389,6 @@
 
 /// Distributes scf.parallel to processors where `IdOp` is used to get the
 /// processor ID and `DimOp` is used to get the number of processors along a
-/// dimension.
-template <typename GPUIdOp, typename GPUCountOp>
-static LogicalResult distributeCyclicallyToProcessors(
-    ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp) {
-  unsigned numLoops = pLoopOp.getNumLoops();
-  if (numLoops > 3) {
-    pLoopOp =
-        cast<scf::ParallelOp>(serializeDimensionsFrom(rewriter, pLoopOp, 3));
-    numLoops = 3;
-  }
-  SmallVector<linalg::ProcInfo, 2> procInfo =
-      getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(
-          rewriter, pLoopOp.getLoc(), numLoops);
-  return distributeCyclicallyToProcessors(rewriter, pLoopOp, procInfo);
-}
-
-/// Distributes scf.parallel to processors where `IdOp` is used to get the
-/// processor ID and `DimOp` is used to get the number of processors along a
 /// dimension. Assumes that the number of processors will be less than equal to
 /// the number of iterations of the pLoopOp along all dimensions.
 template <typename GPUIdOp, typename GPUCountOp>
@@ -425,34 +407,6 @@
                                                generateGuard);
 }
 
-/// Distribute the scf.parallel to workgroups.
-static LogicalResult mapToWorkgroups(ConversionPatternRewriter &rewriter,
-                                     scf::ParallelOp pLoopOp,
-                                     bool useCyclicDistribution = false) {
-  if (useCyclicDistribution) {
-    return distributeCyclicallyToProcessors<gpu::BlockIdOp, gpu::GridDimOp>(
-        rewriter, pLoopOp);
-  }
-  return distributeSingleIterationPerProcessor<gpu::BlockIdOp, gpu::GridDimOp>(
-      rewriter, pLoopOp, false);
-}
-
-/// Distributes scf.parallel to workitems using local invocation ID.
-static LogicalResult mapToLocalInvocationId(ConversionPatternRewriter &rewriter,
-                                            scf::ParallelOp pLoopOp) {
-  return distributeCyclicallyToProcessors<gpu::ThreadIdOp, gpu::BlockDimOp>(
-      rewriter, pLoopOp);
-}
-
-/// Distributes scf.parallel to workitems using global invocation ID. The GPU
-/// dialect doesn't have a direct operation to do this. This could be done using
-/// id = blockIdx * blockDim + gridIdx. count = blockDim * gridDim.
-static LogicalResult mapToGlobalInvocationId(
-    ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp) {
-  return distributeSingleIterationPerProcessor<GPUGlobalId, GPUGlobalCount>(
-      rewriter, pLoopOp);
-}
-
 /// Returns the number of bytes copied when loading to/storing from workgorup
 /// memory. It is approximated to be the size of the underlying allocation being
 /// copied into/from.
@@ -502,34 +456,6 @@
   void runOnOperation() override;
 };
 
-struct SerializeParallelLoopPattern
-    : public OpConversionPattern<scf::ParallelOp> {
-  using OpConversionPattern<scf::ParallelOp>::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      scf::ParallelOp pLoopOp, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    return success(serializeDimensionsFrom(rewriter, pLoopOp, 0) != nullptr);
-  }
-};
-
-/// Implementation of the mapping of tiled linalg op to workitems within a
-/// workgroup.
-template <typename LinalgOpTy>
-static LogicalResult mapLinalgOpToLocalInvocationIdImpl(
-    LinalgOpTy linalgOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) {
-  // Check for marker that specifies that the linalg op is to be partitioned
-  // across threads within a workgroup.
-  if (!hasMarker(linalgOp)) return failure();
-  Optional<linalg::LinalgLoops> loops =
-      linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
-  if (!loops) return failure();
-  if (loops.getValue().empty()) return success();
-
-  auto pLoopOp = cast<scf::ParallelOp>(loops.getValue()[0]);
-  return mapToLocalInvocationId(rewriter, pLoopOp);
-}
-
 static LogicalResult distributeCopyOp(linalg::CopyOp copyOp,
                                       scf::ParallelOp pLoopOp,
                                       ConversionPatternRewriter &rewriter) {
@@ -562,52 +488,36 @@
 // in mods/divs in the collapsed loop body. This can be removed by reshaping the
 // copy to be a 1D copy. This seems to be hitting an error in reshape
 // canonicalization. Investigate this further.
-template <>
-LogicalResult mapLinalgOpToLocalInvocationIdImpl<linalg::CopyOp>(
-    linalg::CopyOp copyOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) {
-  if (!hasMarker(copyOp,
-                 {getCopyToWorkgroupMemoryMarker(), getWorkgroupMarker()}))
-    return failure();
-  Optional<linalg::LinalgLoops> loops =
-      linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, copyOp);
-  if (!loops) return failure();
-  if (loops.getValue().empty()) return success();
-
-  auto pLoopOp = cast<scf::ParallelOp>(loops.getValue()[0]);
-  if (hasMarker(copyOp, getWorkgroupMarker())) {
-    return mapToLocalInvocationId(rewriter, pLoopOp);
-  }
-  return distributeCopyOp(copyOp, pLoopOp, rewriter);
-}
-
-/// Map tiled linalg op to workitems by lowering it to scf.parallel and
-/// partitioning it to workitems.
-template <typename LinalgOpTy>
-struct MapLinalgOpToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
-  MapLinalgOpToLocalInvocationId(MLIRContext *context,
-                                 PatternBenefit benefit = 1)
-      : OpConversionPattern<LinalgOpTy>(context, benefit) {}
+struct SerializeAndDistributeCopy : public OpConversionPattern<linalg::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      LinalgOpTy linalgOp, ArrayRef<Value> operands,
+      linalg::CopyOp copyOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    if (failed(
-            mapLinalgOpToLocalInvocationIdImpl(linalgOp, operands, rewriter)))
+    if (!hasMarker(copyOp, {getCopyToWorkgroupMemoryMarker()}))
       return failure();
 
-    // If the `linalgOp` writes to workgroup memory insert barrier after the
+    Optional<linalg::LinalgLoops> loops =
+        linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, copyOp);
+    if (!loops) return failure();
+    if (!loops.getValue().empty()) {
+      auto pLoopOp = cast<scf::ParallelOp>(loops.getValue()[0]);
+      if (failed(distributeCopyOp(copyOp, pLoopOp, rewriter))) return failure();
+    }
+
+    // If the `copyOp` writes to workgroup memory insert barrier after the
     // op.
-    if (llvm::any_of(linalgOp.getOperands(), [](Value output) {
+    if (llvm::any_of(copyOp.getOperands(), [](Value output) {
           MemRefType outputType = output.getType().dyn_cast<MemRefType>();
           return outputType &&
                  outputType.getMemorySpaceAsInt() == getWorkgroupMemorySpace();
         })) {
       rewriter.create<spirv::ControlBarrierOp>(
-          linalgOp.getLoc(), spirv::Scope::Workgroup, spirv::Scope::Workgroup,
+          copyOp.getLoc(), spirv::Scope::Workgroup, spirv::Scope::Workgroup,
           spirv::MemorySemantics::AcquireRelease);
     }
-    rewriter.eraseOp(linalgOp);
+
+    rewriter.eraseOp(copyOp);
     return success();
   }
 };
@@ -654,9 +564,12 @@
       if (pLoopOp) {
         pLoopOp = collapseParallelLoops(rewriter, pLoopOp);
         if (!pLoopOp) return failure();
-        if (failed(mapToGlobalInvocationId(rewriter, pLoopOp)))
+        if (failed(distributeSingleIterationPerProcessor<GPUGlobalId,
+                                                         GPUGlobalCount>(
+                rewriter, pLoopOp))) {
           return rewriter.notifyMatchFailure(
               linalgOp, "mapping to GlobalInvocationID failed");
+        }
         workgroupSize = {32, 1, 1};
       }
     }
@@ -678,22 +591,11 @@
   }
 };
 
-/// Remove the linalg.range operation created when lowering to loops.
-struct RemoveLinalgRange : public OpConversionPattern<linalg::RangeOp> {
-  using OpConversionPattern<linalg::RangeOp>::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      linalg::RangeOp rangeOp, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    if (!rangeOp.getResult().use_empty()) return failure();
-    rewriter.eraseOp(rangeOp);
-    return success();
-  }
-};
 }  // namespace
 
 // Applies tiling followed to load/store optimized size then distribute on
 // incovations.
-static LogicalResult linalgCopyTileAndDistribute(
+static LogicalResult tileAndDistributeCopy(
     linalg::CopyOp copyOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) {
   linalg::LinalgTilingOptions options;
@@ -705,8 +607,8 @@
   unsigned numElement = vecLoadBits / elementBits;
   options.setTileSizes({1, numElement})
       .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops);
-  Optional<linalg::TiledLinalgOp> tiledOp = linalg::tileLinalgOp(
-      rewriter, cast<linalg::LinalgOp>(copyOp.getOperation()), options);
+  Optional<linalg::TiledLinalgOp> tiledOp =
+      linalg::tileLinalgOp(rewriter, copyOp, options);
   if (!tiledOp) return failure();
   if (tiledOp->loops.empty()) return success();
   setMarker(tiledOp->op, getVectorizeMarker());
@@ -724,7 +626,7 @@
     if (!hasMarker(linalgOp, getCopyToWorkgroupMemoryMarker())) {
       return failure();
     }
-    if (failed(linalgCopyTileAndDistribute(linalgOp, operands, rewriter))) {
+    if (failed(tileAndDistributeCopy(linalgOp, operands, rewriter))) {
       return failure();
     }
 
@@ -743,7 +645,7 @@
 };
 }  // namespace
 
-void populateLinalgTileAndDistributePatterns(
+void populateTileAndDistributeLinalgCopyPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
   patterns.insert<TileAndDistributeCopyOp>(context);
 }
@@ -763,24 +665,11 @@
 
   OwningRewritePatternList patterns(&getContext());
 
-  patterns.insert<
-      MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
-      MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
-      MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
-      MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
-      MapLinalgOpToLocalInvocationId<linalg::ConvInputNWCFilterWCFOp>,
-      MapLinalgOpToLocalInvocationId<linalg::ConvInputNHWCFilterHWCFOp>,
-      MapLinalgOpToLocalInvocationId<linalg::ConvInputNDHWCFilterDHWCFOp>,
-      MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
-      MapLinalgOpToLocalInvocationId<linalg::FillOp>,
-      MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
-      MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
-      MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
-      MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
-      MapLinalgOpToLocalInvocationId<linalg::PoolingNHWCMaxFOp>,
-      MapLinalgOpToLocalInvocationId<linalg::PoolingNHWCMinFOp>,
-      MapLinalgOpToLocalInvocationId<linalg::PoolingNHWCSumFOp>,
-      RemoveLinalgRange, SerializeParallelLoopPattern>(context);
+  patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
+                  MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
+                  MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
+                  MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
+                  SerializeAndDistributeCopy>(context);
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
 
   for (FuncOp funcOp : getOperation().getInnerModule().getOps<FuncOp>()) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
index 53de3b4..c9aa605 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
@@ -101,7 +101,7 @@
       break;
     }
   }
-  if (!entryPointOp) return {};
+  if (!entryPointOp || !entryPointOp.getBody()) return {};
 
   Operation *terminator = entryPointOp.getBlock()->getTerminator();
   auto retOp = dyn_cast<IREE::HAL::ReturnOp>(terminator);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 05a6b9e..23daa25 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -63,6 +63,47 @@
   return std::min(shape, tileSize);
 }
 
+/// Sets the `tileSizes` and `workgroupSize` for an Linalg `op` to the default,
+/// where at most 3 inner parallel dimensions of `op` are tiled and distributed,
+/// and each invocation handles one scalar elements.
+// TODO(#5852): revisit the default here: they were chosen to get started and
+// not very good.
+static LogicalResult setDefaultTilingScheme(
+    const spirv::TargetEnv &targetEnv, linalg::LinalgOp op,
+    TileSizesListType &tileSizes, std::array<int64_t, 3> &workgroupSize) {
+  auto maxWorkgroupSize =
+      targetEnv.getResourceLimits().max_compute_workgroup_invocations();
+
+  const int64_t tileSizeX = 32;
+  const int64_t tileSizeY = maxWorkgroupSize.getInt() / tileSizeX;
+
+  unsigned numParallelDims = getNumOuterParallelLoops(op);
+
+  SmallVector<int64_t, 4> workgroupLevel(numParallelDims, 0);
+  SmallVector<int64_t, 4> invocationLevel(numParallelDims, 0);
+
+  if (numParallelDims >= 1) {
+    workgroupLevel.back() = tileSizeX;
+    invocationLevel.back() = 1;
+  }
+  if (numParallelDims >= 2) {
+    workgroupLevel[numParallelDims - 2] = tileSizeY;
+    invocationLevel[numParallelDims - 2] = 1;
+  }
+  if (numParallelDims >= 3) {
+    workgroupLevel[numParallelDims - 3] = 1;
+    invocationLevel[numParallelDims - 3] = 1;
+  }
+
+  tileSizes.emplace_back(std::move(workgroupLevel));
+  tileSizes.emplace_back();  // Subgroup level
+  tileSizes.emplace_back(std::move(invocationLevel));
+
+  workgroupSize = {tileSizeX, tileSizeY, 1};
+
+  return success();
+}
+
 /// Fills `inputTypes` and `outputTypes` with the original input/output types
 /// for all tiles for `op`.
 static std::tuple<SmallVector<ShapedType>, SmallVector<ShapedType>>
@@ -105,7 +146,7 @@
                                        const SPIRVCodegenOptions &options,
                                        TileSizesListType &tileSizes,
                                        LaunchConfigInfo &config) {
-  return op.emitError("undefined launch config for tiled operation");
+  return setDefaultTilingScheme(targetEnv, op, tileSizes, config.workgroupSize);
 }
 
 static void getMaliBestMatMulTileSizes(
@@ -225,11 +266,15 @@
     // available (maybe). For now, just hard-wire it.
     tileSizeK = 32;
   }
-  assert(tileSizes.empty());
-  SmallVector<int64_t, 4> ts = {
+  SmallVector<int64_t, 4> workgroupLevel = {
       nBatchesPerWorkitem, nRowsPerWorkitem * config.workgroupSize[1],
       nColsPerWorkitem * config.workgroupSize[0], tileSizeK};
-  tileSizes.emplace_back(std::move(ts));
+  SmallVector<int64_t, 4> invocationLevel = {
+      nBatchesPerWorkitem, nRowsPerWorkitem, nColsPerWorkitem, 0};
+
+  tileSizes.emplace_back(std::move(workgroupLevel));
+  tileSizes.emplace_back();  // subgroup level
+  tileSizes.emplace_back(std::move(invocationLevel));
   return success();
 }
 
@@ -366,23 +411,21 @@
     break;
   }
   unsigned numLoops = getNumOuterParallelLoops(linalgOp);
-  SmallVector<int64_t, 4> ts;
-  ts.resize(numLoops, 1);
+  SmallVector<int64_t, 4> ts(numLoops, 1);
   ts.back() = lowerTs;
-  tileSizes.emplace_back(ts);  // Workgroup level.
-  // If the shape is not exactly aligned on the tile size skip the second level
-  // of tiling as it expect the number of iteration to be exactly equal to the
-  // number of processors.
-  if (!vectorize || outputShape.getShape().back() % lowerTs != 0) {
-    config.vectorize = false;
-    return success();
-  }
+  tileSizes.emplace_back(ts);  // Workgroup level
+  tileSizes.emplace_back();    // Subgroup level
 
-  tileSizes.emplace_back();  // Subgroup level.
-  ts.back() = lowerTs / subgroupSize;
-  tileSizes.emplace_back(ts);  // Thread level.
-  // Vectorize only if we are processing more than one element per thread.
-  config.vectorize = vectorize && (ts.back() > 1);
+  if (!vectorize || outputShape.getShape().back() % lowerTs != 0) {
+    ts.back() = 1;
+    tileSizes.emplace_back(ts);  // Thread level
+    config.vectorize = false;
+  } else {
+    ts.back() = lowerTs / subgroupSize;
+    tileSizes.emplace_back(ts);  // Thread level
+    // Vectorize only if we are processing more than one element per thread.
+    config.vectorize = vectorize && (ts.back() > 1);
+  }
   return success();
 }
 
@@ -486,22 +529,21 @@
   int64_t N = inputTypes[1].getShape()[1];
   int64_t K = inputTypes[0].getShape()[1];
 
-  SmallVector<int64_t, 4> ts = {
+  SmallVector<int64_t, 4> workgroupLevel = {
       getMinIfShapeStatic(M, nRowsPerWorkitem * config.workgroupSize[1]),
       getMinIfShapeStatic(N, nColsPerWorkitem * config.workgroupSize[0]),
       getMinIfShapeStatic(K, tileSizeK)};
-  assert(tileSizes.empty());
-  tileSizes.emplace_back(std::move(ts));
+  SmallVector<int64_t, 4> invocationLevel = {1, 1, 0};
+
+  tileSizes.emplace_back(std::move(workgroupLevel));
+  tileSizes.emplace_back();  // subgroup level
+  tileSizes.emplace_back(std::move(invocationLevel));
   return success();
 }
 
-template <typename ConvOpTy>
-static LogicalResult getMaliSpecificConfig(ConvOpTy op,
+static LogicalResult getMaliSpecificConfig(linalg::ConvInputNHWCFilterHWCFOp op,
                                            TileSizesListType &tileSizes,
                                            LaunchConfigInfo &config) {
-  Operation *operation = op.getOperation();
-  if (!isa<linalg::ConvInputNHWCFilterHWCFOp>(operation)) return failure();
-
   SmallVector<ShapedType> inputTypes, outputTypes;
   std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
 
@@ -565,43 +607,20 @@
   return failure();
 }
 
-template <typename T>
-LogicalResult getConvOpLaunchConfig(T op, const spirv::TargetEnv &targetEnv,
-                                    const SPIRVCodegenOptions &options,
-                                    TileSizesListType &tileSizes,
-                                    LaunchConfigInfo &config) {
+template <>
+LogicalResult getOpLaunchConfig(linalg::ConvInputNHWCFilterHWCFOp op,
+                                const spirv::TargetEnv &targetEnv,
+                                const SPIRVCodegenOptions &options,
+                                TileSizesListType &tileSizes,
+                                LaunchConfigInfo &config) {
   if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
       succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
     return success();
   }
 
-  unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
-                                  .max_compute_workgroup_invocations()
-                                  .getInt();
-  const int64_t tileSizeX = 32;
-  int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
-  SmallVector<int64_t, 4> ts;
-  ts.assign({0, 1, tileSizeY, tileSizeX});
-  tileSizes.emplace_back(std::move(ts));
-  config.workgroupSize = {tileSizeX, tileSizeY, 1};
-  return success();
+  return setDefaultTilingScheme(targetEnv, op, tileSizes, config.workgroupSize);
 }
 
-#define GET_CONV_LAUNCH_CONFIG(opType)                                       \
-  template <>                                                                \
-  LogicalResult getOpLaunchConfig(                                           \
-      opType op, const spirv::TargetEnv &targetEnv,                          \
-      const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,      \
-      LaunchConfigInfo &config) {                                            \
-    return getConvOpLaunchConfig(op, targetEnv, options, tileSizes, config); \
-  }
-
-GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNWCFilterWCFOp)
-GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNHWCFilterHWCFOp)
-GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNDHWCFilterDHWCFOp)
-
-#undef GET_CONV_LAUNCH_CONFIG
-
 static LogicalResult getMaliSpecificConfig(
     linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes,
     LaunchConfigInfo &config) {
@@ -673,77 +692,9 @@
     return success();
   }
 
-  unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
-                                  .max_compute_workgroup_invocations()
-                                  .getInt();
-  const int64_t tileSizeX = 32;
-  int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
-  SmallVector<int64_t, 4> ts;
-  ts.assign({0, 1, tileSizeY, tileSizeX});
-  tileSizes.emplace_back(std::move(ts));
-  config.workgroupSize = {tileSizeX, tileSizeY, 1};
-  return success();
+  return setDefaultTilingScheme(targetEnv, op, tileSizes, config.workgroupSize);
 }
 
-template <>
-LogicalResult getOpLaunchConfig(linalg::DepthwiseConvInputNHWCFilterHWCFOp op,
-                                const spirv::TargetEnv &targetEnv,
-                                const SPIRVCodegenOptions &options,
-                                TileSizesListType &tileSizes,
-                                LaunchConfigInfo &config) {
-  unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
-                                  .max_compute_workgroup_invocations()
-                                  .getInt();
-  const int64_t tileSizeX = 32;
-  int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
-  SmallVector<int64_t, 4> ts;
-  // There are five parallel loops in depthwise_conv_2d_input_nhwc_filter_hwcf
-  ts.assign({0, 0, 1, tileSizeY, tileSizeX});
-  tileSizes.emplace_back(std::move(ts));
-  config.workgroupSize = {tileSizeX, tileSizeY, 1};
-  return success();
-}
-
-template <typename PoolingOpTy>
-static LogicalResult getPoolingOpLaunchConfig(
-    PoolingOpTy op, const spirv::TargetEnv &targetEnv,
-    const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,
-    LaunchConfigInfo &config) {
-  unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
-                                  .max_compute_workgroup_invocations()
-                                  .getInt();
-  // Pooling op seems to be rank polymorphic but is not well specified enough to
-  // be able to figure out which dimensions of the output correspond to the
-  // pooled dimension and which are not. Need to fix that, but for now just use
-  // a working heuristic.
-  const int64_t tileSizeX = 32;
-  int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
-  SmallVector<int64_t, 4> ts;
-  ts.assign({0, tileSizeY, tileSizeX, 1});
-  tileSizes.emplace_back(std::move(ts));
-  config.workgroupSize = {tileSizeX, tileSizeY, 1};
-  return success();
-}
-
-#define DEFINE_POOLING_OP_CONFIG(opName)                                \
-  template <>                                                           \
-  LogicalResult getOpLaunchConfig(                                      \
-      opName op, const spirv::TargetEnv &targetEnv,                     \
-      const SPIRVCodegenOptions &options, TileSizesListType &tileSizes, \
-      LaunchConfigInfo &config) {                                       \
-    return getPoolingOpLaunchConfig(op, targetEnv, options, tileSizes,  \
-                                    config);                            \
-  }
-
-DEFINE_POOLING_OP_CONFIG(linalg::PoolingNHWCMaxI8Op)
-DEFINE_POOLING_OP_CONFIG(linalg::PoolingNHWCMaxI16Op)
-DEFINE_POOLING_OP_CONFIG(linalg::PoolingNHWCMaxI32Op)
-DEFINE_POOLING_OP_CONFIG(linalg::PoolingNHWCMaxFOp)
-DEFINE_POOLING_OP_CONFIG(linalg::PoolingNHWCMinFOp)
-DEFINE_POOLING_OP_CONFIG(linalg::PoolingNHWCSumFOp)
-
-#undef DEFINE_POOLINGOP_CONFIG
-
 Optional<LaunchConfig> initGPULaunchConfig(
     MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
     const SPIRVCodegenOptions &options, ArrayRef<linalg::LinalgOp> linalgOps) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 82e5fca..96fdda2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -90,8 +90,8 @@
 // Patterns
 //===----------------------------------------------------------------------===//
 
-/// Populates patterns to tile and distribute linalg operations.
-void populateLinalgTileAndDistributePatterns(
+/// Populates patterns to tile and distribute linalg.copy operations.
+void populateTileAndDistributeLinalgCopyPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns);
 
 /// Populates patterns to fold processor ID uses by using processor counts
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
index 10c0237..c48de73 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
@@ -74,6 +74,11 @@
       markers, Identifier::get(replaceMarker, context));
 }
 
+/// Converts a symbolic GPU processor dimension to its numeric one.
+static unsigned dimToIndex(StringRef dim) {
+  return StringSwitch<unsigned>(dim).Case("x", 0).Case("y", 1).Case("z", 2);
+}
+
 //===----------------------------------------------------------------------===//
 // Main pass
 //===----------------------------------------------------------------------===//
@@ -270,20 +275,27 @@
   };
   linalg::LinalgLoopDistributionOptions invocationDistributionOptions = {
       getThreadProcInfoFn,
-      {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
-       linalg::DistributionMethod::CyclicNumProcsEqNumIters,
-       linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
+      {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+       linalg::DistributionMethod::Cyclic}};
 
   auto tilingOptions =
       linalg::LinalgTilingOptions()
-          .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+          .setLoopType(linalg::LinalgTilingLoopType::Loops)
           .setTileSizeComputationFunction(getInnerTileSizeFn)
           .setDistributionOptions(invocationDistributionOptions);
 
-  patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
-                  linalg::LinalgTilingPattern<linalg::FillOp>,
-                  linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
-                  linalg::LinalgTilingPattern<linalg::GenericOp>>(
+  patterns.insert<
+      linalg::LinalgTilingPattern<linalg::MatmulOp>,
+      linalg::LinalgTilingPattern<linalg::FillOp>,
+      linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
+      linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
+      linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>,
+      linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
+      linalg::LinalgTilingPattern<linalg::GenericOp>,
+      linalg::LinalgTilingPattern<linalg::IndexedGenericOp>,
+      linalg::LinalgTilingPattern<linalg::PoolingNHWCMaxFOp>,
+      linalg::LinalgTilingPattern<linalg::PoolingNHWCMinFOp>,
+      linalg::LinalgTilingPattern<linalg::PoolingNHWCSumFOp>>(
       context, tilingOptions,
       getLinalgMatchAndReplaceMarker(
           {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
@@ -298,6 +310,27 @@
           getConvFilterTileMarker(), context));
 }
 
+/// Returns the corresponding range for the given `processorValue` is a GPU
+/// thread id or block dim.
+static Optional<std::pair<AffineExpr, AffineExpr>> getThreadRange(
+    Value processorValue, SmallVectorImpl<Value> & /*dims*/,
+    SmallVectorImpl<Value> & /*symbols*/, ArrayRef<int64_t> workgroupSize) {
+  if (auto idOp = processorValue.getDefiningOp<gpu::ThreadIdOp>()) {
+    OpBuilder builder(processorValue.getContext());
+    unsigned index = dimToIndex(idOp.dimension());
+    AffineExpr zero = builder.getAffineConstantExpr(0);
+    AffineExpr ubExpr = builder.getAffineConstantExpr(workgroupSize[index]);
+    return std::make_pair(zero, ubExpr - 1);
+  }
+  if (auto dimOp = processorValue.getDefiningOp<gpu::BlockDimOp>()) {
+    OpBuilder builder(processorValue.getContext());
+    unsigned index = dimToIndex(dimOp.dimension());
+    AffineExpr bound = builder.getAffineConstantExpr(workgroupSize[index]);
+    return std::make_pair(bound, bound);
+  }
+  return llvm::None;
+}
+
 //====---------------------------------------------------------------------===//
 // Patterns for vectorization
 //====---------------------------------------------------------------------===//
@@ -434,7 +467,7 @@
       }
     }
     LLVM_DEBUG({
-      llvm::dbgs() << "--- After Vector Unroll ---\n";
+      llvm::dbgs() << "--- After unrolling vector ---\n";
       funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
       llvm::dbgs() << "\n\n";
     });
@@ -444,7 +477,7 @@
     linalg::hoistRedundantVectorTransfers(funcOp);
 
     LLVM_DEBUG({
-      llvm::dbgs() << "--- After Hoisting ---\n";
+      llvm::dbgs() << "--- After hoisting vector transfers ---\n";
       funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
       llvm::dbgs() << "\n\n";
     });
@@ -483,6 +516,30 @@
 }
 
 //====---------------------------------------------------------------------===//
+// Patterns to lower linalg ops to loops
+//====---------------------------------------------------------------------===//
+
+template <typename OpTy>
+struct LowerToLoops final : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Only handle the cases where tiling to invocations was done, where tiling
+    // convolution filters or vectorization is expected.
+    if (!hasMarker(op, {getConvFilterTileMarker(), getVectorizeMarker()}))
+      return failure();
+
+    if (succeeded(linalg::linalgOpToLoops(rewriter, op))) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+//====---------------------------------------------------------------------===//
 // Main pass implementation
 //====---------------------------------------------------------------------===//
 
@@ -513,7 +570,7 @@
     LaunchConfig &launchConfig = *launchConfigOpt;
 
     LLVM_DEBUG({
-      llvm::dbgs() << "\n--- IREE Linalg tile configuration ---\n";
+      llvm::dbgs() << "\n--- Linalg tile configuration ---\n";
       llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
       interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs());
       llvm::dbgs() << "]\n";
@@ -542,7 +599,77 @@
       applyCanonicalizationPatternsForTiling(context, funcOp);
 
       LLVM_DEBUG({
-        llvm::dbgs() << "--- After Promotion  ---\n";
+        llvm::dbgs() << "--- After workgroup memory promotion  ---\n";
+        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+        llvm::dbgs() << "\n\n";
+      });
+    }
+
+    // TODO(thomasraoux, antiagainst): Tiling to subgroups shouldn't be
+    // controlled by vectorization. This is needed due to historical reasons.
+    // Change the second level tiling to cyclic to loops and remove this.
+    if (launchConfig.useVectorize()) {
+      OwningRewritePatternList secondLevelTilingPatterns(&getContext());
+      populateTilingToSubgroupPatterns(context, launchConfig,
+                                       secondLevelTilingPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(secondLevelTilingPatterns));
+      applyCanonicalizationPatternsForTiling(context, funcOp);
+      promoteSingleIterationLoops(funcOp);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "--- After tiling to subgroups ---\n";
+        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+        llvm::dbgs() << "\n\n";
+      });
+    }
+
+    {
+      OwningRewritePatternList thirdLevelTilingPatterns(&getContext());
+      populateTilingToInvocationPatterns(context, launchConfig,
+                                         thirdLevelTilingPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(thirdLevelTilingPatterns));
+
+      // Remove trip-one loops created during cyclic loop distribution if we can
+      // prove the tiling was perfect.
+      RewritePatternSet canoncalizationPatterns(context);
+      populateAffineMinSCFCanonicalizationPattern(canoncalizationPatterns);
+      ArrayRef<int64_t> workgroupSize = launchConfig.getWorkgroupSize();
+      auto getThreadRangeFn = [workgroupSize](Value processorValue,
+                                              SmallVectorImpl<Value> &dims,
+                                              SmallVectorImpl<Value> &symbols) {
+        return getThreadRange(processorValue, dims, symbols, workgroupSize);
+      };
+      populateRemoveSingleIterationLoopPattern(canoncalizationPatterns,
+                                               getThreadRangeFn);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(canoncalizationPatterns));
+
+      // Perform generic canonicalization.
+      applyCanonicalizationPatternsForTiling(context, funcOp);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "--- After tiling to invocations ---\n";
+        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+        llvm::dbgs() << "\n\n";
+      });
+    }
+
+    {
+      OwningRewritePatternList tilingPatterns(&getContext());
+      auto marker = getLinalgMatchAndReplaceMarker(
+          getConvFilterTileMarker(), getVectorizeMarker(), context);
+      populateTilingConvFilterPatterns(context, tilingPatterns, launchConfig,
+                                       marker);
+      populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
+      tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(
+          context);
+      (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
+      applyCanonicalizationPatternsForTiling(context, funcOp);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "--- After tiling convolution filter  ---\n";
         funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
         llvm::dbgs() << "\n\n";
       });
@@ -550,57 +677,6 @@
 
     if (launchConfig.useVectorize()) {
       {
-        OwningRewritePatternList secondLevelTilingPatterns(&getContext());
-        populateTilingToSubgroupPatterns(context, launchConfig,
-                                         secondLevelTilingPatterns);
-        (void)applyPatternsAndFoldGreedily(
-            funcOp, std::move(secondLevelTilingPatterns));
-        applyCanonicalizationPatternsForTiling(context, funcOp);
-        promoteSingleIterationLoops(funcOp);
-
-        LLVM_DEBUG({
-          llvm::dbgs() << "--- After Second level Tiling  ---\n";
-          funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-          llvm::dbgs() << "\n\n";
-        });
-      }
-
-      {
-        OwningRewritePatternList thirdLevelTilingPatterns(&getContext());
-        populateTilingToInvocationPatterns(context, launchConfig,
-                                           thirdLevelTilingPatterns);
-        (void)applyPatternsAndFoldGreedily(funcOp,
-                                           std::move(thirdLevelTilingPatterns));
-        applyCanonicalizationPatternsForTiling(context, funcOp);
-        promoteSingleIterationLoops(funcOp);
-
-        LLVM_DEBUG({
-          llvm::dbgs() << "--- After Third level Tiling  ---\n";
-          funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-          llvm::dbgs() << "\n\n";
-        });
-      }
-
-      {
-        OwningRewritePatternList tilingPatterns(&getContext());
-        auto marker = getLinalgMatchAndReplaceMarker(
-            getConvFilterTileMarker(), getVectorizeMarker(), context);
-        populateTilingConvFilterPatterns(context, tilingPatterns, launchConfig,
-                                         marker);
-        populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
-        tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(
-            context);
-        (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
-        applyCanonicalizationPatternsForTiling(context, funcOp);
-
-        LLVM_DEBUG({
-          llvm::dbgs() << "--- After tiling convolution filter  ---\n";
-          funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-          llvm::dbgs() << "\n\n";
-        });
-      }
-
-      {
         OwningRewritePatternList vectorizationPatterns(&getContext());
         populateVectorizationPatterns(context, launchConfig,
                                       vectorizationPatterns);
@@ -608,7 +684,7 @@
         (void)applyPatternsAndFoldGreedily(funcOp,
                                            std::move(vectorizationPatterns));
         LLVM_DEBUG({
-          llvm::dbgs() << "--- After Vectorization ---\n";
+          llvm::dbgs() << "--- After vectorization ---\n";
           funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
           llvm::dbgs() << "\n\n";
         });
@@ -625,32 +701,27 @@
       applyVectorTransformation(funcOp);
     }
 
-    // Invoke patterns to generalize linalg.depthwise_conv_2d_nhwc ops to Linalg
-    // generic ops. This can handle those cases that failed tiling and
-    // vectorization in the above.
-    // TODO(antiagainst): remove this once we have depthwise convolution
-    // vectorization applicable everywhere.
+    // Lower ops that were tiled to invocations but not vectorized to loops.
+    // TODO(antiagainst): This is here now to simplify the interaction with
+    // ConvertToGPUPass, where we finally lower away all Linalg ops. Once that
+    // pass is cleaned up, we can invoke createConvertLinalgToLoopsPass
+    // directly.
     {
-      // Carry over the Linalg marker because it is load-bearing and affects
-      // later passes.
-      linalg::LinalgTransformationFilter marker =
-          getLinalgMatchAndReplaceMarker({getWorkgroupMarker()},
-                                         getWorkgroupMarker(), context);
-      marker.addFilter([](Operation *op) -> LogicalResult {
-        return success(isa<linalg::DepthwiseConvInputNHWCFilterHWCFOp,
-                           linalg::DepthwiseConvInputNHWCFilterHWCOp>(op));
-      });
-
-      OwningRewritePatternList patterns(&getContext());
-      linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns, marker);
-
+      RewritePatternSet patterns(context);
+      patterns
+          .add<LowerToLoops<linalg::BatchMatmulOp>,
+               LowerToLoops<linalg::ConvInputNWCFilterWCFOp>,
+               LowerToLoops<linalg::ConvInputNHWCFilterHWCFOp>,
+               LowerToLoops<linalg::ConvInputNDHWCFilterDHWCFOp>,
+               LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
+               LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCOp>,
+               LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>,
+               LowerToLoops<linalg::IndexedGenericOp>,
+               LowerToLoops<linalg::MatmulOp>,
+               LowerToLoops<linalg::PoolingNHWCMaxFOp>,
+               LowerToLoops<linalg::PoolingNHWCMinFOp>,
+               LowerToLoops<linalg::PoolingNHWCSumFOp>>(context);
       (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-
-      LLVM_DEBUG({
-        llvm::dbgs() << "--- After generalization ---\n";
-        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-        llvm::dbgs() << "\n\n";
-      });
     }
 
     launchConfig.finalize(funcOp);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
index 871dbbd..481d20d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -105,9 +105,10 @@
 static linalg::ProcInfo getGPUProcessorIdAndCountImpl(OpBuilder &builder,
                                                       Location loc,
                                                       unsigned dim) {
-  std::array<StringRef, kNumGPUDims> dimAttr{"x", "y", "z"};
-  StringAttr attr =
-      builder.getStringAttr(dimAttr[std::min<unsigned>(dim, kNumGPUDims)]);
+  assert(dim < kNumGPUDims && "processor index out of range!");
+
+  std::array<const char *, kNumGPUDims> dimAttr{"x", "y", "z"};
+  StringAttr attr = builder.getStringAttr(dimAttr[dim]);
   Type indexType = builder.getIndexType();
   return {builder.create<GPUIdOp>(loc, indexType, attr),
           builder.create<GPUCountOp>(loc, indexType, attr)};
@@ -116,9 +117,10 @@
 template <>
 linalg::ProcInfo getGPUProcessorIdAndCountImpl<GPUGlobalId, GPUGlobalCount>(
     OpBuilder &builder, Location loc, unsigned dim) {
-  std::array<StringRef, kNumGPUDims> dimAttr{"x", "y", "z"};
-  StringAttr attr =
-      builder.getStringAttr(dimAttr[std::min<unsigned>(dim, kNumGPUDims)]);
+  assert(dim < kNumGPUDims && "processor index out of range!");
+
+  std::array<const char *, kNumGPUDims> dimAttr{"x", "y", "z"};
+  StringAttr attr = builder.getStringAttr(dimAttr[dim]);
   Type indexType = builder.getIndexType();
   Value gridDim = builder.create<gpu::GridDimOp>(loc, indexType, attr);
   Value blockId = builder.create<gpu::BlockIdOp>(loc, indexType, attr);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index fb46d67..2ec44c4 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -89,7 +89,8 @@
   });
   target->markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   OwningRewritePatternList tileAndDistributePattern(&getContext());
-  populateLinalgTileAndDistributePatterns(context, tileAndDistributePattern);
+  populateTileAndDistributeLinalgCopyPatterns(context,
+                                              tileAndDistributePattern);
   if (failed(applyPartialConversion(funcOp, *target,
                                     std::move(tileAndDistributePattern)))) {
     return signalPassFailure();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
index 9d5f405..9fc7636 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
@@ -40,6 +40,7 @@
             "promote_workgroup_memory.mlir",
             "tile_and_vectorize_batch_matmul.mlir",
             "tile_and_vectorize_conv.mlir",
+            "tile_and_vectorize_in_one_workgroup.mlir",
             "tile_and_vectorize_matmul.mlir",
             "vector_to_cooperative_matrix.mlir",
             "vector_to_gpu.mlir",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
index 41b78b1..dedf5f1 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
@@ -27,6 +27,7 @@
     "promote_workgroup_memory.mlir"
     "tile_and_vectorize_batch_matmul.mlir"
     "tile_and_vectorize_conv.mlir"
+    "tile_and_vectorize_in_one_workgroup.mlir"
     "tile_and_vectorize_matmul.mlir"
     "vector_to_cooperative_matrix.mlir"
     "vector_to_gpu.mlir"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index edd7e9a..8ae6c06 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -275,388 +275,3 @@
 //       CHECK:         scf.for %[[IV1:.+]] = %{{.+}} to %[[C75]]
 //   CHECK-DAG:           %[[ISZERO0:.+]] = cmpi eq, %[[IV0]], %[[C0]]
 //   CHECK-DAG:           %[[ISZERO1:.+]] = cmpi eq, %[[IV1]], %[[C0]]
-
-// -----
-
-#map0 = affine_map<()[s0] -> (s0 * 8)>
-#map1 = affine_map<()[s0, s1] -> (8, s1 - s0 * 8)>
-#map2 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-#map4 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map5 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map6 = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-hal.executable @matmul attributes {sym_visibility = "private"} {
-  hal.interface @io {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-  }
-  hal.executable.target @vulkan, filter="vulkan*" {
-    hal.executable.entry_point @matmul attributes {
-      interface = @io, ordinal = 0 : index,
-      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
-        !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
-    module attributes {
-      spv.target_env =
-        #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
-                        {max_compute_workgroup_invocations = 128 : i32,
-                         max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-      func @matmul() {
-        %c0 = constant 0 : index
-        %arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x?xf32>
-        %arg1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<?x?xf32>
-        %arg2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<?x?xf32>
-        %c4 = constant 4 : index
-        %c1 = constant 1 : index
-        %0 = memref.dim %arg0, %c1 : memref<?x?xf32>
-        %1 = "gpu.block_id"() {dimension = "x"} : () -> index
-        %2 = "gpu.block_id"() {dimension = "y"} : () -> index
-        scf.for %arg3 = %c0 to %0 step %c4 {
-          %3 = affine.apply #map0()[%2]
-          %4 = memref.dim %arg0, %c0 : memref<?x?xf32>
-          %5 = affine.min #map1()[%2, %4]
-          %6 = affine.min #map2(%arg3)[%0]
-          %7 = memref.subview %arg0[%3, %arg3] [%5, %6] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map3>
-          %8 = memref.dim %arg1, %c0 : memref<?x?xf32>
-          %9 = affine.min #map2(%arg3)[%8]
-          %10 = affine.apply #map0()[%1]
-          %11 = memref.dim %arg1, %c1 : memref<?x?xf32>
-          %12 = affine.min #map1()[%1, %11]
-          %13 = memref.subview %arg1[%arg3, %10] [%9, %12] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map3>
-          %14 = memref.dim %arg2, %c0 : memref<?x?xf32>
-          %15 = affine.min #map1()[%2, %14]
-          %16 = memref.dim %arg2, %c1 : memref<?x?xf32>
-          %17 = affine.min #map1()[%1, %16]
-          %18 = memref.subview %arg2[%3, %10] [%15, %17] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map3>
-          linalg.matmul {__internal_linalg_transform__ = "workgroup"}
-            ins(%7, %13 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
-           outs(%18 : memref<?x?xf32, #map3>)
-        }
-        return
-      }
-      hal.interface @io attributes {sym_visibility = "private"} {
-        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-      }
-    }
-  }
-}
-// CHECK-LABEL: func @matmul
-//   CHECK-DAG:   %[[C0:.+]] = constant 0
-//   CHECK-DAG:   %[[C1:.+]] = constant 1
-//       CHECK:   scf.for
-//   CHECK-DAG:     %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//   CHECK-DAG:     %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-//   CHECK-DAG:     %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
-//   CHECK-DAG:     %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
-//       CHECK:     scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
-//       CHECK:       scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
-//       CHECK:         scf.for %{{.+}} = %[[C0]] to %{{.*}} step %[[C1]]
-//   CHECK-NOT:           linalg.matmul
-
-// -----
-
-
-hal.executable @conv_1d attributes {sym_visibility = "private"} {
-  hal.interface @io {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-  }
-  hal.executable.target @vulkan_spirv, filter="vulkan*" {
-    hal.executable.entry_point @conv_1d attributes {interface = @io, ordinal = 0 : index, signature = (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>}
-    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>}  {
-      func @conv_1d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
-        %cst = constant 0.000000e+00 : f32
-        %c0 = constant 0 : index
-        %0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<3x6x1xf32>
-        %1 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<3x8x1xf32>
-        %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x1x1xf32>
-        %3 = "gpu.block_id"() {dimension = "x"} : () -> index
-        %4 = "gpu.block_id"() {dimension = "y"} : () -> index
-        %5 = "gpu.block_id"() {dimension = "z"} : () -> index
-        %6 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
-        %7 = affine.min affine_map<()[s0] -> (6, s0 * -4 + 8)>()[%4]
-        %8 = memref.subview %1[%5, %6, 0] [1, %7, 1] [1, 1, 1] : memref<3x8x1xf32> to memref<1x?x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 + d2)>>
-        %9 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
-        %10 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 1)>()[%3]
-        %11 = memref.subview %2[0, 0, %9] [3, 1, %10] [1, 1, 1] : memref<3x1x1xf32> to memref<3x1x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>
-        %12 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
-        %13 = affine.min affine_map<()[s0] -> (4, s0 * -4 + 6)>()[%4]
-        %14 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
-        %15 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 1)>()[%3]
-        %16 = memref.subview %0[%5, %12, %14] [1, %13, %15] [1, 1, 1] : memref<3x6x1xf32> to memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>
-        %17 = memref.subview %0[%5, %12, %9] [1, %13, %10] [1, 1, 1] : memref<3x6x1xf32> to memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>
-        linalg.conv_1d_input_nwc_filter_wcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} ins(%8, %11 : memref<1x?x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 + d2)>>, memref<3x1x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>) outs(%16 : memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>)
-        return
-      }
-      hal.interface @io attributes {sym_visibility = "private"} {
-        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-      }
-    }
-  }
-}
-
-//         CHECK: func @conv_1d
-// CHECK-COUNT-4:   scf.for
-//     CHECK-NOT:     linalg.conv_1d_input_nwc_filter_wcf
-
-// -----
-
-#map0 = affine_map<()[s0] -> (s0 * 4)>
-#map1 = affine_map<()[s0] -> (s0 * 32)>
-#map2 = affine_map<(d0)[s0] -> (1, -d0 + s0)>
-#map3 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
-#map4 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
-#map5 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-#map6 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map7 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
-
-
-hal.executable @conv_no_padding attributes {sym_visibility = "private"} {
-  hal.interface @io {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-  }
-  hal.executable.target @vulkan, filter="vulkan*" {
-    hal.executable.entry_point @conv_no_padding attributes {
-      interface = @io, ordinal = 0 : index,
-      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
-        !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
-    module attributes {
-      spv.target_env =
-        #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
-                        {max_compute_workgroup_invocations = 128 : i32,
-                         max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-      func @conv_no_padding() {
-        %c0 = constant 0 : index
-        %arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x?x?x?xf32>
-        %arg1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<?x?x?x?xf32>
-        %arg2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<?x?x?x?xf32>
-        %c2 = constant 2 : index
-        %c3 = constant 3 : index
-        %c1 = constant 1 : index
-        %0 = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
-        %1 = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
-        %2 = memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
-        %3 = memref.dim %arg2, %c1 : memref<?x?x?x?xf32>
-        %4 = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>
-        %5 = "gpu.block_id"() {dimension = "x"} : () -> index
-        %6 = "gpu.grid_dim"() {dimension = "x"} : () -> index
-        %7 = "gpu.block_id"() {dimension = "y"} : () -> index
-        %8 = "gpu.grid_dim"() {dimension = "y"} : () -> index
-        %9 = "gpu.block_id"() {dimension = "z"} : () -> index
-        %10 = "gpu.grid_dim"() {dimension = "z"} : () -> index
-        %11 = affine.apply #map0()[%7]
-        %12 = affine.apply #map0()[%8]
-        %13 = affine.apply #map1()[%5]
-        %14 = affine.apply #map1()[%6]
-        scf.parallel (%arg3, %arg4, %arg5) = (%9, %11, %13) to (%2, %3, %4) step (%10, %12, %14) {
-          %15 = affine.min #map2(%arg3)[%2]
-          %16 = memref.dim %arg1, %c1 : memref<?x?x?x?xf32>
-          %17 = affine.min #map3(%arg4)[%0, %16]
-          %18 = memref.dim %arg1, %c2 : memref<?x?x?x?xf32>
-          %19 = affine.min #map4(%arg5)[%1, %18]
-          %20 = memref.dim %arg1, %c3 : memref<?x?x?x?xf32>
-          %21 = memref.subview %arg1[%arg3, %arg4, %arg5, 0] [%15, %17, %19, %20] [1, 1, 1, 1]
-                  : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
-          %22 = memref.dim %arg2, %c0 : memref<?x?x?x?xf32>
-          %23 = affine.min #map2(%arg3)[%22]
-          %24 = affine.min #map6(%arg4)[%3]
-          %25 = affine.min #map7(%arg5)[%4]
-          %26 = memref.dim %arg2, %c3 : memref<?x?x?x?xf32>
-          %27 = memref.subview %arg2[%arg3, %arg4, %arg5, 0] [%23, %24, %25, %26] [1, 1, 1, 1]
-                  : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
-          linalg.conv_2d_input_nhwc_filter_hwcf {
-            __internal_linalg_transform__ = "workgroup",
-            dilations = dense<1> : tensor<2xi64>,
-            strides = dense<2> : tensor<2xi64>}
-             ins(%21, %arg0 : memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32>)
-            outs(%27 : memref<?x?x?x?xf32, #map5>)
-          scf.yield
-        }
-        return
-      }
-      hal.interface @io attributes {sym_visibility = "private"} {
-        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-      }
-    }
-  }
-}
-//     CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
-//     CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
-//         CHECK: func @conv_no_padding
-//     CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
-//     CHECK-DAG:   %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
-//     CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
-//     CHECK-DAG:   %[[C1:.+]] = constant 1
-//     CHECK-DAG:   %[[C2:.+]] = constant 2
-//     CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C0]]
-//     CHECK-DAG:   %[[P:.+]] = memref.dim %[[RET0]], %[[C1]]
-//     CHECK-DAG:   %[[Q:.+]] = memref.dim %[[RET0]], %[[C2]]
-//     CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-//     CHECK-DAG:   %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
-//     CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-//     CHECK-DAG:   %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
-//     CHECK-DAG:   %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
-//     CHECK-DAG:   %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"}
-//         CHECK:   %[[BOFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//         CHECK:   %[[BSTEPY:.+]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
-//         CHECK:   %[[BOFFSETX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-//         CHECK:   %[[BSTEPX:.+]] = affine.apply #[[MAP1]]()[%[[NBLOCKSX]]]
-//         CHECK:   scf.for %[[IV3:.+]] = %[[BIDZ]] to %[[N]] step %[[NBLOCKSZ]]
-//         CHECK:     scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[P]] step %[[BSTEPY]]
-//         CHECK:       scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[Q]] step %[[BSTEPX]]
-//         CHECK:         %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
-//         CHECK:         %[[SV2:.+]] = memref.subview %[[RET0]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
-//     CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//     CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-//     CHECK-DAG:         %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
-//     CHECK-DAG:         %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
-//     CHECK-DAG:         %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
-//     CHECK-DAG:         %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
-//         CHECK:         scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
-//         CHECK:           scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
-//         CHECK:             scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
-// CHECK-COUNT-4:               scf.for
-//     CHECK-NOT:               linalg.conv_2d_input_nhwc_filter_hwcf
-
-// -----
-
-hal.executable @conv_3d attributes {sym_visibility = "private"} {
-  hal.interface @io {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-  }
-  hal.executable.target @vulkan_spirv, filter="vulkan*" {
-    hal.executable.entry_point @conv_3d attributes {interface = @io, ordinal = 0 : index, signature = (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>}
-    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>}  {
-      func @conv_3d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
-        %cst = constant 0.000000e+00 : f32
-        %c0 = constant 0 : index
-        %0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<2x7x7x7x2xf32>
-        %1 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<2x8x8x8x3xf32>
-        %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<2x2x2x3x2xf32>
-        %3 = "gpu.block_id"() {dimension = "x"} : () -> index
-        %4 = "gpu.block_id"() {dimension = "y"} : () -> index
-        %5 = "gpu.block_id"() {dimension = "z"} : () -> index
-        %6 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
-        %7 = affine.min affine_map<()[s0] -> (5, s0 * -4 + 8)>()[%4]
-        %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
-        %9 = affine.min affine_map<()[s0] -> (33, s0 * -32 + 8)>()[%3]
-        %10 = memref.subview %1[%5, %6, %8, 0, 0] [1, %7, %9, 8, 3] [1, 1, 1, 1, 1] : memref<2x8x8x8x3xf32> to memref<1x?x?x8x3xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 1536 + s0 + d1 * 192 + d2 * 24 + d3 * 3 + d4)>>
-        %11 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
-        %12 = affine.min affine_map<()[s0] -> (4, s0 * -4 + 7)>()[%4]
-        %13 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
-        %14 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 7)>()[%3]
-        %15 = memref.subview %0[%5, %11, %13, 0, 0] [1, %12, %14, 7, 2] [1, 1, 1, 1, 1] : memref<2x7x7x7x2xf32> to memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>
-        %16 = memref.subview %0[%5, %11, %13, 0, 0] [1, %12, %14, 7, 2] [1, 1, 1, 1, 1] : memref<2x7x7x7x2xf32> to memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>
-        linalg.conv_3d_input_ndhwc_filter_dhwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%10, %2 : memref<1x?x?x8x3xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 1536 + s0 + d1 * 192 + d2 * 24 + d3 * 3 + d4)>>, memref<2x2x2x3x2xf32>) outs(%15 : memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>)
-        return
-      }
-      hal.interface @io attributes {sym_visibility = "private"} {
-        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-      }
-    }
-  }
-}
-
-//         CHECK: func @conv_3d
-//     CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//     CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-//     CHECK-DAG:         %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
-//     CHECK-DAG:         %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
-//     CHECK-DAG:         %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
-//     CHECK-DAG:         %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
-//         CHECK:         scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
-//         CHECK:           scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
-//         CHECK:             scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
-// CHECK-COUNT-6:               scf.for
-//     CHECK-NOT:               linalg.conv_3d_input_ndhwc_filter_dhwcf
-
-// -----
-
-#map0 = affine_map<()[s0] -> (s0 * 4)>
-#map1 = affine_map<()[s0] -> (6, s0 * -4 + 16)>
-#map2 = affine_map<()[s0] -> (s0 * 32)>
-#map3 = affine_map<()[s0] -> (35, s0 * -32 + 16)>
-#map4 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 96 + d2 * 6 + d3)>
-#map5 = affine_map<()[s0] -> (4, s0 * -4 + 14)>
-#map6 = affine_map<()[s0] -> (32, s0 * -32 + 13)>
-#map7 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1092 + s0 + d1 * 78 + d2 * 6 + d3)>
-module  {
-  hal.executable @pooling_nhwc_max attributes {sym_visibility = "private"} {
-    hal.interface @io {
-      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-    }
-    hal.executable.target @vulkan, filter="vulkan*" {
-      hal.executable.entry_point @pooling_nhwc_max attributes {interface = @io, ordinal = 0 : index, signature = (!flow.dispatch.tensor<readonly:2x16x16x6xf32>, !flow.dispatch.tensor<readonly:1x3x4x2xf32>, !flow.dispatch.tensor<writeonly:2x14x13x5xf32>) -> ()} {
-      ^bb0(%arg0: index, %arg1: index, %arg2: index):  // no predecessors
-        %c4 = constant 4 : index
-        %c1 = constant 1 : index
-        hal.return %c1, %c4, %c1 : index, index, index
-      }
-      module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>}  {
-        func @pooling_nhwc_max() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
-          %c0 = constant 0 : index
-          %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<2x16x16x6xf32>
-          %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x4xf32>
-          %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<2x14x13x6xf32>
-          %3 = "gpu.block_id"() {dimension = "x"} : () -> index
-          %4 = "gpu.block_id"() {dimension = "y"} : () -> index
-          %5 = affine.apply #map0()[%4]
-          %6 = affine.min #map1()[%4]
-          %7 = affine.apply #map2()[%3]
-          %8 = affine.min #map3()[%3]
-          %9 = memref.subview %0[0, %5, %7, 0] [2, %6, %8, 6] [1, 1, 1, 1] : memref<2x16x16x6xf32> to memref<2x?x?x6xf32, #map4>
-          %10 = affine.min #map5()[%4]
-          %11 = affine.min #map6()[%3]
-          %12 = memref.subview %2[0, %5, %7, 0] [2, %10, %11, 6] [1, 1, 1, 1] : memref<2x14x13x6xf32> to memref<2x?x?x6xf32, #map7>
-          linalg.pooling_nhwc_max {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%9, %1 : memref<2x?x?x6xf32, #map4>, memref<3x4xf32>) outs(%12 : memref<2x?x?x6xf32, #map7>)
-          return
-        }
-        hal.interface @io attributes {sym_visibility = "private"} {
-          hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-          hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-          hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-        }
-      }
-    }
-  }
-}
-
-//     CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
-//     CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 32)>
-//         CHECK: func @pooling_nhwc_max
-//     CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
-//     CHECK-DAG:   %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
-//     CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
-//     CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-//     CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-//         CHECK:   %[[IV1:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//         CHECK:   %[[IV2:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
-//         CHECK:   %[[SV1:.+]] = memref.subview %[[ARG0]][0, %[[IV1]], %[[IV2]], 0]
-//         CHECK:   %[[SV2:.+]] = memref.subview %[[RET0]][0, %[[IV1]], %[[IV2]], 0]
-//     CHECK-DAG:   %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//     CHECK-DAG:   %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-//     CHECK-DAG:   %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
-//     CHECK-DAG:   %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
-//     CHECK-DAG:   %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
-//     CHECK-DAG:   %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
-//         CHECK:   scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
-//         CHECK:     scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
-//         CHECK:       scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
-// CHECK-COUNT-3:         scf.for
-//     CHECK-NOT:           linalg.pooling_nhwc_max
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/promote_workgroup_memory.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/promote_workgroup_memory.mlir
index a6367c6..8ffe4ec 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/promote_workgroup_memory.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/promote_workgroup_memory.mlir
@@ -67,10 +67,11 @@
 //  CHECK-SAME:       "copy_to_workgroup_memory"
 //       CHECK:     linalg.copy(%[[ARG1SV]], %[[SUBVIEW2]])
 //  CHECK-SAME:       "copy_to_workgroup_memory"
-//       CHECK:     linalg.matmul
-//  CHECK-SAME:       "workgroup_memory"
-//  CHECK-SAME:       ins(%[[SUBVIEW1]], %[[SUBVIEW2]]
-//  CHECK-SAME:      outs(%[[RET0SV]]
+//       CHECK:     scf.for
+//       CHECK:       scf.for
+//   CHECK-DAG:         memref.subview %[[SUBVIEW1]]
+//   CHECK-DAG:         memref.subview %[[SUBVIEW2]]
+//   CHECK-DAG:         memref.subview %[[RET0SV]]
 
 // -----
 
@@ -129,7 +130,9 @@
 //       CHECK:   %[[SUBVIEW1:.+]] = memref.subview %[[ALLOC1]]
 //       CHECK:   linalg.copy(%[[ARG1SV]], %[[SUBVIEW1]])
 //  CHECK-SAME:      "copy_to_workgroup_memory"
-//       CHECK:   linalg.conv_2d_input_nhwc_filter_hwcf
-//  CHECK-SAME:     "workgroup_memory"
-//  CHECK-SAME:     ins(%[[SUBVIEW1]], %[[ARG0]]
-//  CHECK-SAME:    outs(%[[RET0SV]]
+//       CHECK:   scf.for
+//       CHECK:     scf.for
+//       CHECK:       scf.for
+//   CHECK-DAG:         memref.subview %[[SUBVIEW1]]
+//   CHECK-DAG:         memref.subview %[[ARG0]]
+//   CHECK-DAG:         memref.subview %[[RET0SV]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_in_one_workgroup.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_in_one_workgroup.mlir
new file mode 100644
index 0000000..58bf934
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_in_one_workgroup.mlir
@@ -0,0 +1,408 @@
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" %s | IreeFileCheck %s
+
+#map0 = affine_map<()[s0] -> (s0 * 8)>
+#map1 = affine_map<()[s0, s1] -> (8, s1 - s0 * 8)>
+#map2 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+#map4 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map5 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map6 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+hal.executable @matmul attributes {sym_visibility = "private"} {
+  hal.interface @io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan, filter="vulkan*" {
+    hal.executable.entry_point @matmul attributes {
+      interface = @io, ordinal = 0 : index,
+      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
+        !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
+    module attributes {
+      spv.target_env =
+        #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+                        {max_compute_workgroup_invocations = 128 : i32,
+                         max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+      func @matmul() {
+        %c0 = constant 0 : index
+        %arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x?xf32>
+        %arg1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<?x?xf32>
+        %arg2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<?x?xf32>
+        %c4 = constant 4 : index
+        %c1 = constant 1 : index
+        %0 = memref.dim %arg0, %c1 : memref<?x?xf32>
+        %1 = "gpu.block_id"() {dimension = "x"} : () -> index
+        %2 = "gpu.block_id"() {dimension = "y"} : () -> index
+        scf.for %arg3 = %c0 to %0 step %c4 {
+          %3 = affine.apply #map0()[%2]
+          %4 = memref.dim %arg0, %c0 : memref<?x?xf32>
+          %5 = affine.min #map1()[%2, %4]
+          %6 = affine.min #map2(%arg3)[%0]
+          %7 = memref.subview %arg0[%3, %arg3] [%5, %6] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map3>
+          %8 = memref.dim %arg1, %c0 : memref<?x?xf32>
+          %9 = affine.min #map2(%arg3)[%8]
+          %10 = affine.apply #map0()[%1]
+          %11 = memref.dim %arg1, %c1 : memref<?x?xf32>
+          %12 = affine.min #map1()[%1, %11]
+          %13 = memref.subview %arg1[%arg3, %10] [%9, %12] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map3>
+          %14 = memref.dim %arg2, %c0 : memref<?x?xf32>
+          %15 = affine.min #map1()[%2, %14]
+          %16 = memref.dim %arg2, %c1 : memref<?x?xf32>
+          %17 = affine.min #map1()[%1, %16]
+          %18 = memref.subview %arg2[%3, %10] [%15, %17] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map3>
+          linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%7, %13 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
+           outs(%18 : memref<?x?xf32, #map3>)
+        }
+        return
+      }
+      hal.interface @io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+// CHECK-LABEL: func @matmul
+//   CHECK-DAG:   %[[C0:.+]] = constant 0
+//   CHECK-DAG:   %[[C1:.+]] = constant 1
+//       CHECK:   scf.for
+//   CHECK-DAG:     %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//   CHECK-DAG:     %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//   CHECK-DAG:     %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
+//   CHECK-DAG:     %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
+//       CHECK:     scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
+//       CHECK:       scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
+//       CHECK:         scf.for %{{.+}} = %[[C0]] to %{{.*}} step %[[C1]]
+//   CHECK-NOT:           linalg.matmul
+
+// -----
+
+hal.executable @conv_1d attributes {sym_visibility = "private"} {
+  hal.interface @io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
+    hal.executable.entry_point @conv_1d attributes {interface = @io, ordinal = 0 : index, signature = (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>}
+    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>}  {
+      func @conv_1d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+        %cst = constant 0.000000e+00 : f32
+        %c0 = constant 0 : index
+        %0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<3x6x1xf32>
+        %1 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<3x8x1xf32>
+        %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x1x1xf32>
+        %3 = "gpu.block_id"() {dimension = "x"} : () -> index
+        %4 = "gpu.block_id"() {dimension = "y"} : () -> index
+        %5 = "gpu.block_id"() {dimension = "z"} : () -> index
+        %6 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
+        %7 = affine.min affine_map<()[s0] -> (6, s0 * -4 + 8)>()[%4]
+        %8 = memref.subview %1[%5, %6, 0] [1, %7, 1] [1, 1, 1] : memref<3x8x1xf32> to memref<1x?x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 + d2)>>
+        %9 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
+        %10 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 1)>()[%3]
+        %11 = memref.subview %2[0, 0, %9] [3, 1, %10] [1, 1, 1] : memref<3x1x1xf32> to memref<3x1x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>
+        %12 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
+        %13 = affine.min affine_map<()[s0] -> (4, s0 * -4 + 6)>()[%4]
+        %14 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
+        %15 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 1)>()[%3]
+        %16 = memref.subview %0[%5, %12, %14] [1, %13, %15] [1, 1, 1] : memref<3x6x1xf32> to memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>
+        %17 = memref.subview %0[%5, %12, %9] [1, %13, %10] [1, 1, 1] : memref<3x6x1xf32> to memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>
+        linalg.conv_1d_input_nwc_filter_wcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} ins(%8, %11 : memref<1x?x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 + d2)>>, memref<3x1x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>) outs(%16 : memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>)
+        return
+      }
+      hal.interface @io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func @conv_1d
+//   CHECK-DAG: %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG: %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG: %[[C3:.+]] = constant 3 : index
+//       CHECK: %[[RET:.+]] = hal.interface.binding.subspan @io::@ret0
+//       CHECK: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
+//       CHECK: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
+//       CHECK: %[[ARG0SV1:.+]] = memref.subview %[[ARG0]]
+//       CHECK: %[[ARG1SV1:.+]] = memref.subview %[[ARG1]]
+//       CHECK: %[[RETSV1:.+]] = memref.subview %[[RET]]
+//       CHECK: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//       CHECK: %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
+//       CHECK: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//       CHECK: %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
+//       CHECK: %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+//       CHECK: scf.for %[[IV0:.+]] = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
+//       CHECK:   scf.for %[[IV1:.+]] = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
+//       CHECK:     %[[ARG0SV2:.+]] = memref.subview %[[ARG0SV1]][%[[TIDZ]], %[[IV0]], 0] [1, %{{.+}}, 1]
+//       CHECK:     %[[ARG1SV2:.+]] = memref.subview %[[ARG1SV1]][0, 0, %[[IV1]]] [3, 1, 1]
+//       CHECK:     %[[RETSV2:.+]] = memref.subview %[[RETSV1]][%[[TIDZ]], %[[IV0]], %[[IV1]]] [1, 1, 1]
+//       CHECK:     scf.for %[[IV2:.+]] = %[[C0]] to %[[C3]] step %[[C1]]
+//       CHECK:       memref.load %[[ARG0SV2]][%[[C0]], %[[IV2]], %[[C0]]]
+//       CHECK:       memref.load %[[ARG1SV2]][%[[IV2]], %[[C0]], %[[C0]]]
+//       CHECK:       memref.load %[[RETSV2]][%[[C0]], %[[C0]], %[[C0]]]
+//       CHECK:       memref.store %{{.+}}, %[[RETSV2]][%[[C0]], %[[C0]], %[[C0]]]
+
+
+// -----
+
+#map0 = affine_map<()[s0] -> (s0 * 4)>
+#map1 = affine_map<()[s0] -> (s0 * 32)>
+#map2 = affine_map<(d0)[s0] -> (1, -d0 + s0)>
+#map3 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
+#map4 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
+#map5 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+#map6 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map7 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+
+hal.executable @conv_no_padding attributes {sym_visibility = "private"} {
+  hal.interface @io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan, filter="vulkan*" {
+    hal.executable.entry_point @conv_no_padding attributes {
+      interface = @io, ordinal = 0 : index,
+      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
+        !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
+    module attributes {
+      spv.target_env =
+        #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+                        {max_compute_workgroup_invocations = 128 : i32,
+                         max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+      func @conv_no_padding() {
+        %c0 = constant 0 : index
+        %arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x?x?x?xf32>
+        %arg1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<?x?x?x?xf32>
+        %arg2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<?x?x?x?xf32>
+        %c2 = constant 2 : index
+        %c3 = constant 3 : index
+        %c1 = constant 1 : index
+        %0 = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
+        %1 = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
+        %2 = memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
+        %3 = memref.dim %arg2, %c1 : memref<?x?x?x?xf32>
+        %4 = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>
+        %5 = "gpu.block_id"() {dimension = "x"} : () -> index
+        %6 = "gpu.grid_dim"() {dimension = "x"} : () -> index
+        %7 = "gpu.block_id"() {dimension = "y"} : () -> index
+        %8 = "gpu.grid_dim"() {dimension = "y"} : () -> index
+        %9 = "gpu.block_id"() {dimension = "z"} : () -> index
+        %10 = "gpu.grid_dim"() {dimension = "z"} : () -> index
+        %11 = affine.apply #map0()[%7]
+        %12 = affine.apply #map0()[%8]
+        %13 = affine.apply #map1()[%5]
+        %14 = affine.apply #map1()[%6]
+        scf.for %arg3 = %9 to %2 step %10 {
+          scf.for %arg4 = %11 to %3 step %12 {
+            scf.for %arg5 = %13 to %4 step %14 {
+              %15 = affine.min #map2(%arg3)[%2]
+              %16 = memref.dim %arg1, %c1 : memref<?x?x?x?xf32>
+              %17 = affine.min #map3(%arg4)[%0, %16]
+              %18 = memref.dim %arg1, %c2 : memref<?x?x?x?xf32>
+              %19 = affine.min #map4(%arg5)[%1, %18]
+              %20 = memref.dim %arg1, %c3 : memref<?x?x?x?xf32>
+              %21 = memref.subview %arg1[%arg3, %arg4, %arg5, 0] [%15, %17, %19, %20] [1, 1, 1, 1]
+                      : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
+              %22 = memref.dim %arg2, %c0 : memref<?x?x?x?xf32>
+              %23 = affine.min #map2(%arg3)[%22]
+              %24 = affine.min #map6(%arg4)[%3]
+              %25 = affine.min #map7(%arg5)[%4]
+              %26 = memref.dim %arg2, %c3 : memref<?x?x?x?xf32>
+              %27 = memref.subview %arg2[%arg3, %arg4, %arg5, 0] [%23, %24, %25, %26] [1, 1, 1, 1]
+                      : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
+              linalg.conv_2d_input_nhwc_filter_hwcf {
+                __internal_linalg_transform__ = "workgroup",
+                dilations = dense<1> : tensor<2xi64>,
+                strides = dense<2> : tensor<2xi64>}
+                 ins(%21, %arg0 : memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32>)
+                outs(%27 : memref<?x?x?x?xf32, #map5>)
+            }
+          }
+        }
+        return
+      }
+      hal.interface @io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+//     CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
+//     CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
+//         CHECK: func @conv_no_padding
+//     CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
+//     CHECK-DAG:   %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
+//     CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
+//     CHECK-DAG:   %[[C1:.+]] = constant 1
+//     CHECK-DAG:   %[[C2:.+]] = constant 2
+//     CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C0]]
+//     CHECK-DAG:   %[[P:.+]] = memref.dim %[[RET0]], %[[C1]]
+//     CHECK-DAG:   %[[Q:.+]] = memref.dim %[[RET0]], %[[C2]]
+//     CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//     CHECK-DAG:   %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+//     CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//     CHECK-DAG:   %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+//     CHECK-DAG:   %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+//     CHECK-DAG:   %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"}
+//         CHECK:   %[[BOFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//         CHECK:   %[[BSTEPY:.+]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
+//         CHECK:   %[[BOFFSETX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+//         CHECK:   %[[BSTEPX:.+]] = affine.apply #[[MAP1]]()[%[[NBLOCKSX]]]
+//         CHECK:   scf.for %[[IV3:.+]] = %[[BIDZ]] to %[[N]] step %[[NBLOCKSZ]]
+//         CHECK:     scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[P]] step %[[BSTEPY]]
+//         CHECK:       scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[Q]] step %[[BSTEPX]]
+//         CHECK:         %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+//         CHECK:         %[[SV2:.+]] = memref.subview %[[RET0]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+//     CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//     CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//     CHECK-DAG:         %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+//     CHECK-DAG:         %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
+//     CHECK-DAG:         %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
+//     CHECK-DAG:         %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
+//         CHECK:         scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
+//         CHECK:           scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
+//         CHECK:             scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
+// CHECK-COUNT-3:               scf.for
+//     CHECK-NOT:               linalg.conv_2d_input_nhwc_filter_hwcf
+
+// -----
+
+hal.executable @conv_3d attributes {sym_visibility = "private"} {
+  hal.interface @io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
+    hal.executable.entry_point @conv_3d attributes {interface = @io, ordinal = 0 : index, signature = (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>}
+    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>}  {
+      func @conv_3d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+        %cst = constant 0.000000e+00 : f32
+        %c0 = constant 0 : index
+        %0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<2x7x7x7x2xf32>
+        %1 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<2x8x8x8x3xf32>
+        %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<2x2x2x3x2xf32>
+        %3 = "gpu.block_id"() {dimension = "x"} : () -> index
+        %4 = "gpu.block_id"() {dimension = "y"} : () -> index
+        %5 = "gpu.block_id"() {dimension = "z"} : () -> index
+        %6 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
+        %7 = affine.min affine_map<()[s0] -> (5, s0 * -4 + 8)>()[%4]
+        %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
+        %9 = affine.min affine_map<()[s0] -> (33, s0 * -32 + 8)>()[%3]
+        %10 = memref.subview %1[%5, %6, %8, 0, 0] [1, %7, %9, 8, 3] [1, 1, 1, 1, 1] : memref<2x8x8x8x3xf32> to memref<1x?x?x8x3xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 1536 + s0 + d1 * 192 + d2 * 24 + d3 * 3 + d4)>>
+        %11 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
+        %12 = affine.min affine_map<()[s0] -> (4, s0 * -4 + 7)>()[%4]
+        %13 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
+        %14 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 7)>()[%3]
+        %15 = memref.subview %0[%5, %11, %13, 0, 0] [1, %12, %14, 7, 2] [1, 1, 1, 1, 1] : memref<2x7x7x7x2xf32> to memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>
+        %16 = memref.subview %0[%5, %11, %13, 0, 0] [1, %12, %14, 7, 2] [1, 1, 1, 1, 1] : memref<2x7x7x7x2xf32> to memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>
+        linalg.conv_3d_input_ndhwc_filter_dhwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%10, %2 : memref<1x?x?x8x3xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 1536 + s0 + d1 * 192 + d2 * 24 + d3 * 3 + d4)>>, memref<2x2x2x3x2xf32>) outs(%15 : memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>)
+        return
+      }
+      hal.interface @io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+//   CHECK-LABEL: func @conv_3d
+//     CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//     CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//     CHECK-DAG:         %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+//     CHECK-DAG:         %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
+//     CHECK-DAG:         %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
+//     CHECK-DAG:         %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
+//         CHECK:         scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
+//         CHECK:           scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
+//         CHECK:             scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
+// CHECK-COUNT-5:               scf.for
+//     CHECK-NOT:               linalg.conv_3d_input_ndhwc_filter_dhwcf
+
+// -----
+
+#map0 = affine_map<()[s0] -> (s0 * 4)>
+#map1 = affine_map<()[s0] -> (6, s0 * -4 + 16)>
+#map2 = affine_map<()[s0] -> (s0 * 32)>
+#map3 = affine_map<()[s0] -> (35, s0 * -32 + 16)>
+#map4 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 96 + d2 * 6 + d3)>
+#map5 = affine_map<()[s0] -> (4, s0 * -4 + 14)>
+#map6 = affine_map<()[s0] -> (32, s0 * -32 + 13)>
+#map7 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1092 + s0 + d1 * 78 + d2 * 6 + d3)>
+module  {
+  hal.executable @pooling_nhwc_max attributes {sym_visibility = "private"} {
+    hal.interface @io {
+      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+    }
+    hal.executable.target @vulkan, filter="vulkan*" {
+      hal.executable.entry_point @pooling_nhwc_max attributes {interface = @io, ordinal = 0 : index, signature = (!flow.dispatch.tensor<readonly:2x16x16x6xf32>, !flow.dispatch.tensor<readonly:1x3x4x2xf32>, !flow.dispatch.tensor<writeonly:2x14x13x5xf32>) -> ()} {
+      ^bb0(%arg0: index, %arg1: index, %arg2: index):  // no predecessors
+        %c4 = constant 4 : index
+        %c1 = constant 1 : index
+        hal.return %c1, %c4, %c1 : index, index, index
+      }
+      module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>}  {
+        func @pooling_nhwc_max() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+          %c0 = constant 0 : index
+          %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<2x16x16x6xf32>
+          %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x4xf32>
+          %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<2x14x13x6xf32>
+          %3 = "gpu.block_id"() {dimension = "x"} : () -> index
+          %4 = "gpu.block_id"() {dimension = "y"} : () -> index
+          %5 = affine.apply #map0()[%4]
+          %6 = affine.min #map1()[%4]
+          %7 = affine.apply #map2()[%3]
+          %8 = affine.min #map3()[%3]
+          %9 = memref.subview %0[0, %5, %7, 0] [2, %6, %8, 6] [1, 1, 1, 1] : memref<2x16x16x6xf32> to memref<2x?x?x6xf32, #map4>
+          %10 = affine.min #map5()[%4]
+          %11 = affine.min #map6()[%3]
+          %12 = memref.subview %2[0, %5, %7, 0] [2, %10, %11, 6] [1, 1, 1, 1] : memref<2x14x13x6xf32> to memref<2x?x?x6xf32, #map7>
+          linalg.pooling_nhwc_max {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%9, %1 : memref<2x?x?x6xf32, #map4>, memref<3x4xf32>) outs(%12 : memref<2x?x?x6xf32, #map7>)
+          return
+        }
+        hal.interface @io attributes {sym_visibility = "private"} {
+          hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+          hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+          hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+        }
+      }
+    }
+  }
+}
+
+//     CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
+//     CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 32)>
+//         CHECK: func @pooling_nhwc_max
+//     CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
+//     CHECK-DAG:   %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
+//     CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
+//     CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//     CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//         CHECK:   %[[IV1:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//         CHECK:   %[[IV2:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
+//         CHECK:   %[[SV1:.+]] = memref.subview %[[ARG0]][0, %[[IV1]], %[[IV2]], 0]
+//         CHECK:   %[[SV2:.+]] = memref.subview %[[RET0]][0, %[[IV1]], %[[IV2]], 0]
+//     CHECK-DAG:   %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//     CHECK-DAG:   %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//     CHECK-DAG:   %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+//     CHECK-DAG:   %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
+//     CHECK-DAG:   %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
+//     CHECK-DAG:   %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
+//         CHECK:   scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
+//         CHECK:     scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
+//         CHECK:       scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
+// CHECK-COUNT-3:         scf.for
+//     CHECK-NOT:           linalg.pooling_nhwc_max
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
index b0e3c2b..451f4a3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
@@ -57,7 +57,8 @@
 // transpose.
 // CHECK-LABEL: func @elementwise_transpose
 //   CHECK-NOT:   vector.transfer_read
-//       CHECK:   linalg.generic
+//       CHECK:   scf.for
+//       CHECK:     scf.for
 hal.executable @elementwise_transpose attributes {sym_visibility = "private"} {
   hal.interface @io {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"